#!/usr/bin/python3
# coding=utf8
#
# The MIT License (MIT)
#
# Copyright (c) 2015 Lorenz Hüdepohl
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

import os
import sys
import time
import json
import errno
import argparse
from typing import Iterator
from subprocess import Popen, PIPE, run

DEFAULT_TIME_FORMAT = "%Y-%m-%d-%H:%M"

# argparse %-formats every help string, so a literal '%' must be doubled
# before the default time format can be shown in help text.
_default_format_help = DEFAULT_TIME_FORMAT.replace("%", "%%")

parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description='''
Use with cron to make periodic ZFS snapshots, with the option of keeping
only a limited number of old snapshots.
''')

# Positional: one or more roots below which snapshots are made recursively.
parser.add_argument(
    'filesystem', metavar='FILESYSTEM', nargs="+",
    help='Snapshot recursively below FILESYSTEM')

# Time part of the snapshot name; used with strftime() to create names and
# with strptime() to read them back.
parser.add_argument(
    '--timeformat', metavar='TIMEFORMAT', default=DEFAULT_TIME_FORMAT,
    help='Use this time.strptime() format for the time part of the\n'
         f'snapshot name instead of the default {_default_format_help}')

parser.add_argument(
    '--snapname', metavar='SNAPNAME', required=True,
    help=f'Create a new snapshot autosnapshot-SNAPNAME-{_default_format_help}')

parser.add_argument(
    '-v', '--verbose', action="store_true",
    help='Echo the commands that are issued')

parser.add_argument(
    '-k', '--keep', metavar="NUM", type=int,
    help='Keep at most NUM snapshots, by potentially deleting some of the oldest')

parser.add_argument(
    '-m', '--minutes', metavar="MIN", type=int,
    help='Only make a snapshot if more than MIN minutes passed since last snapshot')

parser.add_argument(
    '-n', '--dry-run', action="store_true",
    help='Only echo the commands that would be issued,\ndo not actually do anything')

# Stored under the attribute name "except" (a Python keyword), hence the
# getattr(args, "except") lookups elsewhere. May be given repeatedly.
parser.add_argument(
    '-e', '--except', action='append', metavar='FILESYSTEM',
    help='Do not snapshot at or below filesystem FILESYSTEM')

args = parser.parse_args()


# Prefer a snapshots.lua located in the same directory as this script;
# otherwise use the system-wide copy.
_script_dir = os.path.dirname(os.path.abspath(__file__))
local_script = os.path.join(_script_dir, "snapshots.lua")
global_script = "/usr/share/zfs-auto-snapshot/snapshots.lua"
snapshots_lua = local_script if os.path.isfile(local_script) else global_script


# Snapshots made by this script are named
# autosnapshot-SNAPNAME-<current UTC time formatted with --timeformat>.
snap_prefix = f"autosnapshot-{args.snapname}"
now_format = time.strftime(args.timeformat, time.gmtime())
new_snap_name = f"{snap_prefix}-{now_format}"


def list_snapshots() -> Iterator[tuple[str, str, float]]:
    """Yield the existing automatic snapshots that belong to this SNAPNAME.

    For every FILESYSTEM given on the command line, runs
    ``zfs list -t snapshot FILESYSTEM -H -r -o name`` and yields
    ``(dataset, snapshot_name, timestamp)`` for each snapshot whose name
    starts with ``autosnapshot-SNAPNAME-``. ``timestamp`` is the time part of
    the name parsed with ``--timeformat`` and converted with ``time.mktime``.

    Skipped:
      * datasets equal to, or below, any ``--except`` FILESYSTEM;
      * snapshots whose remaining name does not parse with ``--timeformat``
        (e.g. ones made with a longer SNAPNAME that shares this prefix).

    NOTE(review): names are generated from time.gmtime() but parsed back
    with time.mktime() (local time). The --minutes check computes its
    reference time the same way, so ages agree except possibly across a
    DST change.
    """
    # Normalise trailing slashes so "pool/data/" and "pool/data" behave alike.
    excludes = [e.rstrip("/") for e in getattr(args, "except") or []]
    name_prefix = snap_prefix + "-"
    for root in args.filesystem:
        # The context manager closes the pipe and reaps the child process,
        # even when the caller stops iterating early.
        with Popen(["/usr/sbin/zfs", "list", "-t", "snapshot", root, "-H", "-r", "-o", "name"],
                   stdout=PIPE, text=True) as zfs_snapshots:
            if zfs_snapshots.stdout is None:
                print("Error getting list of snapshots")
                raise SystemExit(1)
            for line in zfs_snapshots.stdout:
                dataset, snap_name = line.rstrip("\n").split("@", 1)
                # Compare whole path components: excluding "pool/data" must
                # not also exclude a sibling such as "pool/database".
                if any(dataset == e or dataset.startswith(e + "/") for e in excludes):
                    continue
                if not snap_name.startswith(name_prefix):
                    continue
                try:
                    snap_time = time.strptime(snap_name.removeprefix(name_prefix), args.timeformat)
                except ValueError:
                    # Shares our prefix but is not one of our snapshots.
                    continue
                yield dataset, snap_name, time.mktime(snap_time)


if args.minutes:
    # strftime() truncated the current time to the resolution of
    # --timeformat; parsing it back gives `now` the same rounding as the
    # times recovered from existing snapshot names.
    now = time.mktime(time.strptime(now_format, args.timeformat))
    threshold = args.minutes
    for fs, snap_name, snap_time in list_snapshots():
        minutes_old = (now - snap_time) / 60
        if minutes_old >= threshold:
            # Old enough; look at the next one.
            continue
        if args.dry_run or args.verbose:
            print("Not doing any snapshots, last snapshot in {0} is only {1:.0f} minutes old, "
                  "less than the configured {2}".format(fs, minutes_old, args.minutes))
        # A sufficiently recent snapshot exists: exit successfully without
        # creating or deleting anything.
        raise SystemExit(0)

# Make snapshots
error = False

# Forward every --except FILESYSTEM to the Lua program as its own option.
excepts = ["--except=" + e for e in getattr(args, "except") or []]

# `zfs program` is invoked once per pool, so group the requested
# filesystems by their pool name (the first path component).
fs_pools: dict[str, list[str]] = {}
for fs in args.filesystem:
    pool, *_ = fs.split("/", 1)
    fs_pools.setdefault(pool, []).append(fs)

if args.keep is not None:
    # Pruning is restricted to snapshots carrying this SNAPNAME's prefix.
    keep = ["--keep=" + str(args.keep), "--keepprefix=@" + snap_prefix + "-"]
else:
    keep = []
dry_run = ["--dry-run"] if args.dry_run else []


def _report_errors(header: str, errors: dict) -> None:
    """Print a per-snapshot error listing from the Lua program to stderr.

    Each value of ``errors`` is read as a dict with an integer ``errno``
    entry and an optional ``details`` entry.
    """
    print(header, file=sys.stderr)
    for snapshot, snaperr in errors.items():
        code = snaperr["errno"]
        # errno.errorcode only has the platform's known codes; an unknown
        # one must not abort the rest of the report with a KeyError.
        symbol = errno.errorcode.get(code, "?")
        print(f' {snapshot}: {code} ({symbol}): {os.strerror(code)}', file=sys.stderr)
        if "details" in snaperr:
            print("", snaperr["details"], file=sys.stderr)
    print(file=sys.stderr)


def _report_snapshots(header: str, snapshots: dict) -> None:
    """Print the snapshot names held in the values of ``snapshots`` to stderr."""
    print(header, file=sys.stderr)
    for snapshot in snapshots.values():
        print(f" {snapshot}", file=sys.stderr)
    print(file=sys.stderr)


for pool, pool_filesystems in fs_pools.items():
    cmd = ["/usr/sbin/zfs", "program", "-j", pool, "--", snapshots_lua, "@" + new_snap_name] + \
        pool_filesystems + excepts + keep + dry_run
    if args.verbose or args.dry_run:
        print(" ".join(cmd))

    p = run(cmd, capture_output=True, text=True)
    if p.returncode != 0:
        if p.stdout:
            print(p.stdout)
        print(p.stderr, file=sys.stderr)
        raise SystemExit(p.returncode)

    # With -j, `zfs program` prints JSON whose "return" key holds the
    # channel program's result table.
    result = json.loads(p.stdout)["return"]

    if result["snapshot_errors"]:
        error = True
        # Reported on stderr, like delete errors (previously this went to stdout).
        _report_errors("Error making the following snapshots:", result["snapshot_errors"])

    if result["delete_errors"]:
        error = True
        _report_errors("Error making the following deletes:", result["delete_errors"])

    if args.verbose and result["created_snapshots"]:
        _report_snapshots("Created the following snapshots:", result["created_snapshots"])

    if args.verbose and result["deleted_snapshots"]:
        _report_snapshots("Deleted the following snapshots:", result["deleted_snapshots"])

# Any per-snapshot failure in any pool makes the whole run exit non-zero.
if error:
    raise SystemExit(1)
