#!/usr/bin/python3
#
# Copyright 2024 Lorenz Hüdepohl
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.

# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.

# You should have received a copy of the GNU General Public License along
# with this program. If not, see <https://www.gnu.org/licenses/>.
#
import re
import sys
import argparse
import textwrap
from typing import Iterable, List, Dict, Tuple, Set, Union, Optional, NoReturn
from subprocess import Popen, PIPE, run, check_output, CompletedProcess

sys.tracebacklimit = 0

# Command-line interface. Arguments may also be read from a file given as
# "@file" (fromfile_prefix_chars='@').
parser = argparse.ArgumentParser(
    description='This utility uses zfs send/receive to transfer snapshots from '
                'one zfs pool/filesystem to another. Snapshots that already exist on both sides '
                '(recognized by their ZFS guid, even if renamed) are used as a common base, so only '
                'incremental sends and receives are done to reduce the amount of data transferred.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    fromfile_prefix_chars='@'
)

parser.add_argument("origin", type=str,
                    help='Origin zfs filesystem, could be an SSH remote path, e.g. "user@host:pool/fs"')

parser.add_argument("destination", type=str,
                    help='Destination prefix, e.g. backup_pool/this_host. The ZFS snapshots from "origin" '
                         'are then stored below there. As with "origin", can be an SSH remote path')

parser.add_argument('-b', '--bidirectional', action="store_true",
                    help='Sync in both directions.')

# Snapshot / filesystem filters; each may be given several times (action="append")
parser.add_argument('--snapname', metavar="TAGNAME", type=str, action="append",
                    help='Only consider snapshot names starting with TAGNAME. '
                         'Can be specified more than once')

parser.add_argument('--ignore-snapname', metavar="TAGNAME", type=str, action="append",
                    help='Do not consider snapshot names starting with TAGNAME. '
                         'Can be specified more than once')

parser.add_argument('--ignore-filesystem', metavar="FSNAME", type=str, action="append",
                    help='Do not consider filesystem FSNAME (and its children). '
                         'Can be specified more than once')

parser.add_argument('-w', '--raw', action="store_true",
                    help='Pass -w/--raw to zfs send for encrypted source filesystems')

# Each -v appends a 1; the list is summed into a verbosity level after parsing
parser.add_argument('-v', '--verbose', action="append_const", const=1,
                    help='Echo the commands that are issued. Two -v pass a -v along to the zfs send/receive commands')

parser.add_argument('-n', '--dry-run', action="store_true",
                    help='Only echo the transfer commands that would be issued,\n'
                         'do not actually send anything.')

parser.add_argument('-F', '--force', action="store_true",
                    help='Pass `-F` to `zfs receive` to overwrite other '
                         'snapshots or diverged changes on the remote side\n')

parser.add_argument('-z', '--compression', action="store_true",
                    help='Pass "-C" to ssh')

args = parser.parse_args()
# -v uses append_const, so it arrives as a list of 1s (or None if never given)
args.verbose = 0 if args.verbose is None else sum(args.verbose)

# Extra options forwarded to the zfs send/receive and ssh invocations
verbose = ["-v"] if args.verbose > 1 else []
force = ["-F"] if args.force else []
raw = ["--raw"] if args.raw else []
compression = ["-C"] if args.compression else []


def ignore_filesystem(fs: str) -> bool:
    """Return True if *fs* is one of the --ignore-filesystem names or lies below one."""
    excluded = args.ignore_filesystem or ()
    return any(fs == root or fs.startswith(root + "/") for root in excluded)


def use_snapshot(snapname: str) -> bool:
    """Decide whether the snapshot called *snapname* takes part in the sync.

    A --ignore-snapname prefix match always excludes the snapshot. Otherwise,
    if any --snapname prefixes were given, the name must start with one of
    them; without --snapname every non-ignored snapshot is used.
    """
    ignored_prefixes = tuple(args.ignore_snapname or ())
    if ignored_prefixes and snapname.startswith(ignored_prefixes):
        return False

    allowed_prefixes = tuple(args.snapname or ())
    if not allowed_prefixes:
        return True
    return snapname.startswith(allowed_prefixes)


def exit(code: int) -> NoReturn:
    """Terminate with exit status *code*.

    When pdb has been imported, a RuntimeError is raised instead of
    SystemExit, so the debugger stops at the failure rather than the
    process simply ending.
    """
    if 'pdb' not in sys.modules:
        raise SystemExit(code)
    raise RuntimeError()


def check_returncode(proc: Union[CompletedProcess, Popen]) -> None:
    """Abort the program with exit status 2 if *proc* did not succeed.

    The failing command line is reported on stderr. If the process's stderr
    was captured, it is echoed to stderr as well. The captured stderr may be
    a readable pipe (``Popen(..., stderr=PIPE)``) or the already collected
    ``str``/``bytes`` of a ``CompletedProcess``.
    """
    if proc.returncode == 0:
        return

    assert isinstance(proc.args, list)
    print("Command \"{0}\" returned error {1}, aborting...".format(" ".join(proc.args), proc.returncode),
          file=sys.stderr)

    captured = proc.stderr
    if captured is not None:
        # Popen exposes a file object, CompletedProcess the collected data
        if hasattr(captured, "read"):
            captured = captured.read()
        if isinstance(captured, bytes):
            captured = captured.decode(errors="replace")
        if captured:
            print(captured, end="" if captured.endswith("\n") else "\n", file=sys.stderr)
    exit(2)


def do(cmd: List[str]) -> None:
    """Run *cmd* and abort the program if it fails.

    The command is echoed when --verbose or --dry-run is given; under
    --dry-run it is only echoed, never executed.
    """
    if args.dry_run:
        print(*cmd, flush=True)
        return
    if args.verbose:
        print(*cmd, flush=True)
    completed = run(cmd)
    check_returncode(completed)


class Snapshot(object):
    """One snapshot ``<fs.name>@<name>``, identified across pools by its guid."""

    def __init__(self, fs: "Filesystem", name: str, guid: int):
        self.fs = fs        # owning Filesystem
        self.name = name    # the part after '@'
        self.guid = guid

    def rename(self, newname: str):
        """Rename this snapshot to *newname* (via ``zfs rename``) and record the new name."""
        dataset = self.fs.name
        old_full = f"{dataset}@{self.name}"
        new_full = f"{dataset}@{newname}"
        do(self.fs.state.ssh + ["zfs", "rename", old_full, new_full])
        self.name = newname

    def __repr__(self):
        return f"Snapshot({self.fs}, '{self.name}', {self.guid})"


class Filesystem(object):
    """A ZFS filesystem or volume below a ZFSState's prefix, with its snapshots.

    Attributes:
        state: the ZFSState this filesystem belongs to.
        name: full dataset name (expected to start with ``state.prefix``).
        encryption: True unless the ``encryption`` value was None or "off".
        receive_resume_token: token of an interrupted receive, or None.
        receive_resume_token_info: named fields parsed from
            ``zfs send -Pnvt <token>``, or None.
        snapshots: snapshots in the order they were added. The sync logic
            relies on this being oldest first, because the previous entry is
            used as the base of an incremental send.
        snapshots_by_guid: the same Snapshot objects indexed by guid.
    """

    def __init__(self, state: "ZFSState", name: str, encryption: Optional[str] = None,
                 receive_resume_token: Optional[str] = None):
        self.state = state
        self.name = name
        self.encryption: bool = encryption is not None and encryption != "off"
        self.receive_resume_token: Optional[str] = None
        self.receive_resume_token_info: Optional[dict] = None
        # "-" stands for "no resume token"
        if receive_resume_token != "-":
            self.receive_resume_token = receive_resume_token
        self.snapshots: List[Snapshot] = []
        self.snapshots_by_guid: Dict[int, Snapshot] = {}
        if self.receive_resume_token is not None:
            # NOTE(review): this runs on the local host, not via self.state.ssh.
            # The token describes the *sending* side's snapshot ("toname"), so
            # this presumably only works when that snapshot is local -- confirm
            # before relying on resume with a remote origin.
            info = check_output(["zfs", "send", "-Pnvt", self.receive_resume_token]).decode()
            token_info_format = textwrap.dedent(r"""
                resume token contents:
                nvlist version: 0
                (\tfromguid = (?P<fromguid>0x.*)
                )?\tobject = (?P<object>0x.*)
                \toffset = (?P<offset>0x.*)
                \tbytes = (?P<bytes>0x.*)
                \ttoguid = (?P<toguid>0x.*)
                \ttoname = (?P<toname>.*)
                (?:incremental\t(?P<fromsnap>[^\t]*)|(?:full))\t(?P<tosnap>[^\t]*)\t(?P<size>.*)
                size\t(?P<size2>.*)""")[1:]
            g = re.match(token_info_format, info)
            if g is None:
                print("Cannot parse receive resume token information:", file=sys.stderr)
                print(info, file=sys.stderr)
                exit(1)
            else:
                self.receive_resume_token_info = g.groupdict()

    def suffix(self) -> str:
        """Return the dataset name relative to the state's prefix (e.g. "/home")."""
        return self.name[len(self.state.prefix):]

    def parent(self) -> str:
        """Return the name up to the last '/' (the name itself if it has no '/')."""
        return self.name.rsplit("/", 1)[0]

    def add_snapshot(self, snapname: str, guid: int) -> "Snapshot":
        """Append a new Snapshot to this filesystem's snapshot list and guid index, and return it."""
        snapshot = Snapshot(self, snapname, guid)
        self.snapshots.append(snapshot)
        self.snapshots_by_guid[snapshot.guid] = snapshot
        return snapshot

    def _record_received(self, snapshot: "Snapshot") -> None:
        """Register a copy of *snapshot* (same name and guid) as now present here.

        The copy is added both to this filesystem and to its state's guid index.
        """
        received = self.add_snapshot(snapshot.name, snapshot.guid)
        self.state.snapshots_by_guid[received.guid] = received

    @staticmethod
    def _unsynced(fs1: "Filesystem", fs2: "Filesystem") -> List[Tuple[int, "Snapshot"]]:
        """Return ``(index, snapshot)`` for each snapshot of *fs1* after the last one shared with *fs2*.

        Snapshots are compared by guid. If the two share no snapshot, all of
        *fs1*'s snapshots are returned. The index refers to ``fs1.snapshots``.
        """
        res: List[Tuple[int, "Snapshot"]] = []
        for i, snapshot in enumerate(fs1.snapshots):
            if snapshot.guid in fs2.snapshots_by_guid:
                # Forget everything up to and including a common snapshot
                res = []
            else:
                res.append((i, snapshot))
        return res

    def sync_snapshots_to(self, dest: "ZFSState") -> None:
        """Bring the counterpart of this filesystem in *dest* up to date.

        Finding the counterpart:
        - Use the destination filesystem that holds one of our snapshots (by guid).
        - If several do, use the one with our suffix; otherwise abort.
        - If none does, use an existing filesystem with our suffix, or register
          a new one.

        Reconciling names: a counterpart with a different suffix, and shared
        snapshots with different names, are renamed on the destination.

        Transfer: snapshots after the last common one are sent. With
        --bidirectional, snapshots only the destination has are sent back.
        Exits if both sides have snapshots the other lacks.
        """
        if not self.snapshots:
            return

        dest_fs: Optional[Filesystem] = None

        # Destination filesystems holding at least one of our snapshots
        candidates = set()
        for snapshot in self.snapshots:
            if snapshot.guid in dest.snapshots_by_guid:
                candidates.add(dest.snapshots_by_guid[snapshot.guid].fs)

        if len(candidates) == 1:
            dest_fs, = candidates

        elif len(candidates) > 1:
            # Select the one with the same suffix, if exists
            dest_fs = None
            for c in candidates:
                if c.suffix() == self.suffix():
                    dest_fs = c
                    break

            if dest_fs is None:
                print("Found more than one possible destination for filesystem '{0}', aborting:\n  - {1}".format(
                      self.name, "\n  - ".join(sorted(map(str, candidates)))),
                      file=sys.stderr)
                exit(1)
        else:  # A suitable snapshot does not exist at destination

            # But an empty filesystem?
            dest_fs = dest.filesystems_by_fsname.get(dest.prefix + self.suffix(), None)
            if dest_fs is None:
                # If not, register one; the zfs receive below targets this name
                dest_fs = dest.add_filesystem(dest.prefix + self.suffix())

        # Has it been renamed at destination?
        if dest_fs.suffix() != self.suffix():
            dest.rename_filesystem(dest_fs, self.suffix())

        # Apply possible renames to dest
        for dest_snapshot in dest_fs.snapshots:
            if dest_snapshot.guid in self.snapshots_by_guid:
                snapshot = self.snapshots_by_guid[dest_snapshot.guid]
                if dest_snapshot.name != snapshot.name:
                    dest_snapshot.rename(snapshot.name)

        my_unsynced = Filesystem._unsynced(self, dest_fs)
        their_unsynced = Filesystem._unsynced(dest_fs, self)

        if my_unsynced and their_unsynced:
            print("\nERROR: The snapshots have diverged on origin and destination:\n", file=sys.stderr)
            print(" Unsynchronized snapshots on", self.name, file=sys.stderr)
            for _, snapshot in my_unsynced:
                print("  - {0}@{1} (guid {2})".format(snapshot.fs.name, snapshot.name, hex(snapshot.guid)),
                      file=sys.stderr)
            print("", file=sys.stderr)
            print(" Unsynchronized snapshots on", dest_fs.name, file=sys.stderr)
            for _, snapshot in their_unsynced:
                print("  - {0}@{1} (guid {2})".format(snapshot.fs.name, snapshot.name, hex(snapshot.guid)),
                      file=sys.stderr)
            print("\nAborting", file=sys.stderr)
            exit(1)

        def send(sender: "Filesystem", pending: Iterable[Tuple[int, "Snapshot"]], receiver: "Filesystem"):
            """Send each ``(index, snapshot)`` of *pending* from *sender* into *receiver*."""
            for i, snapshot in pending:
                # `zfs send` must run where the *sending* filesystem lives. For
                # the --bidirectional reverse transfer that is the destination host.
                send_cmd = sender.state.ssh + ["zfs", "send"] + verbose
                if receiver.receive_resume_token_info is not None:
                    token_info = receiver.receive_resume_token_info
                    if i > 0:
                        # expect incremental
                        if token_info["fromsnap"] != sender.name + "@" + sender.snapshots[i - 1].name:
                            print("Invalid receive resume token, does not match expected snapshots", file=sys.stderr)
                            exit(1)
                    else:
                        if token_info["fromsnap"] is not None:
                            print("Invalid receive resume token, not expecting an incremental snapshot", file=sys.stderr)
                            exit(1)
                    if token_info["tosnap"] != sender.name + "@" + snapshot.name:
                        print("Invalid receive resume token, does not match expected snapshots", file=sys.stderr)
                        exit(1)
                    assert receiver.receive_resume_token is not None
                    send_cmd += ["-t", receiver.receive_resume_token]
                else:
                    if sender.encryption:
                        send_cmd += raw
                    if i > 0:
                        # Incremental from the snapshot right before this one
                        send_cmd += ["-i", "@" + sender.snapshots[i - 1].name]
                    send_cmd += [sender.name + "@" + snapshot.name]

                recv_cmd = receiver.state.ssh + ["zfs", "receive", "-s"] + force + verbose + ["-u", receiver.name]

                if args.verbose or args.dry_run:
                    print(" ".join(send_cmd), "|", " ".join(recv_cmd), flush=True)

                if not args.dry_run:
                    with Popen(send_cmd, stdout=PIPE) as p1:
                        with Popen(recv_cmd, stdin=p1.stdout) as p2:
                            # Close our copy so the sender sees SIGPIPE if the receiver dies
                            if p1.stdout is not None:
                                p1.stdout.close()
                            p2.wait()
                            check_returncode(p2)
                            p1.wait()
                            check_returncode(p1)

                # Clear potential resume tokens
                receiver.receive_resume_token = None
                receiver.receive_resume_token_info = None

                # Keep the in-memory view current, so a later pass (e.g. the
                # reverse pass of --bidirectional) does not send this snapshot again
                receiver._record_received(snapshot)

        if my_unsynced:
            send(self, my_unsynced, dest_fs)

        if args.bidirectional and their_unsynced:
            send(dest_fs, their_unsynced, self)

    def __repr__(self):
        return "Filesystem({0}, '{1}')".format(self.state, self.name)


class ZFSState(object):
    """In-memory view of all filesystems and snapshots below one location.

    *location* is a dataset name such as ``pool/fs``, or ``host:pool/fs`` to
    run every zfs command on *host* through ssh. If ``zfs list`` reports that
    the dataset does not exist, the state is simply left empty.
    """

    def __init__(self, location: str):
        self.location = location
        if ":" in location:
            # Split on the first colon only; ZFS dataset names may contain ':'
            self.host, self.prefix = location.split(":", 1)
            self.ssh = ["ssh", self.host] + compression
        else:
            self.prefix = location
            self.host = ""
            self.ssh = []

        self.filesystems: List[Filesystem] = []
        self.filesystems_by_fsname: Dict[str, Filesystem] = {}
        # parent dataset name -> filesystems directly below it
        self.filesystems_by_parents: Dict[str, Set[Filesystem]] = {}
        self.snapshots_by_guid: Dict[int, Snapshot] = {}

        # List all filesystems and snapshots. communicate() drains stdout and
        # stderr concurrently, so a full stderr pipe cannot stall `zfs list`.
        cmd = self.ssh + ["zfs", "list", "-H",
                          "-t", "filesystem,volume,snapshot",
                          "-o", "type,name,guid,encryption,receive_resume_token",
                          "-r", self.prefix]
        with Popen(cmd, stdout=PIPE, stderr=PIPE, universal_newlines=True) as proc:
            output, errors = proc.communicate()

        if proc.returncode == 1 and errors.endswith("dataset does not exist\n"):
            # Nothing below this location yet
            return
        if proc.returncode != 0:
            # stderr was consumed above: show it here, then let
            # check_returncode report the command and abort
            if errors:
                print(errors, end="" if errors.endswith("\n") else "\n", file=sys.stderr)
            check_returncode(CompletedProcess(proc.args, proc.returncode))

        for line in output.splitlines():
            kind, name, _guid, encryption, receive_resume_token = line.split("\t")
            guid = int(_guid)
            if not name.startswith(self.prefix):
                print("Unexpected {0}: \"{1}\", expected a string starting with \"{2}\"".format(
                      kind, name, self.prefix),
                      file=sys.stderr)
                exit(1)
            if kind in ("filesystem", "volume"):
                # New filesystem (volumes are treated like filesystems)
                if not ignore_filesystem(name):
                    self.add_filesystem(name, encryption, receive_resume_token)
            elif kind == "snapshot":
                fsname, snapname = name.split("@", 1)
                if ignore_filesystem(fsname):
                    continue

                # Assumes the owning filesystem was listed before its snapshots
                fs = self.filesystems_by_fsname[fsname]

                if use_snapshot(snapname):
                    snapshot = fs.add_snapshot(snapname, guid)
                    self.snapshots_by_guid[snapshot.guid] = snapshot
            else:
                print("Unexpected ZFS object: '{0}' for '{1}'".format(kind, name), file=sys.stderr)
                exit(1)

    def __repr__(self):
        return "ZFSState('{0}')".format(self.location)

    def add_filesystem(self, name: str, encryption: Optional[str] = None,
                       receive_resume_token: Optional[str] = None) -> "Filesystem":
        """Create a Filesystem named *name*, index it in this state, and return it.

        This only updates the in-memory bookkeeping; no zfs command is issued.
        """
        fs = Filesystem(self, name, encryption, receive_resume_token)
        self.filesystems.append(fs)
        self.filesystems_by_fsname[fs.name] = fs
        self.filesystems_by_parents.setdefault(fs.parent(), set()).add(fs)
        return fs

    def rename_filesystem(self, fs: "Filesystem", newsuffix: str) -> None:
        """Rename *fs* to ``prefix + newsuffix`` with ``zfs rename``, then update the indexes."""
        do(self.ssh + ["zfs", "rename", fs.name, self.prefix + newsuffix])
        self._rename_filesystem(fs, self.prefix + newsuffix)

    def _rename_filesystem(self, fs: "Filesystem", newname: str) -> None:
        """Update the in-memory indexes after *fs* was renamed to *newname*.

        Known children of *fs* are renamed along with it, recursively.
        """
        oldname = fs.name
        oldparent = fs.parent()

        fs.name = newname
        newparent = fs.parent()

        # Update filesystems_by_fsname
        del self.filesystems_by_fsname[oldname]
        self.filesystems_by_fsname[newname] = fs

        # Put this fs in the proper list of children for its new parent
        self.filesystems_by_parents[oldparent].remove(fs)
        self.filesystems_by_parents.setdefault(newparent, set()).add(fs)

        # Track implicit rename of all children. sorted() copies the set, which
        # the recursive calls modify while we iterate.
        if oldname in self.filesystems_by_parents:
            children = self.filesystems_by_parents[oldname]
            for child in sorted(children, key=lambda c: c.name):
                if not child.name.startswith(oldname):
                    print("This shouldn't happen", file=sys.stderr)
                    exit(1)
                newchildname = newname + child.name[len(oldname):]
                self._rename_filesystem(child, newchildname)


def sync(a: ZFSState, b: ZFSState):
    """Synchronize every filesystem of state *a*, in listing order, with its counterpart in *b*."""
    for source_fs in a.filesystems:
        source_fs.sync_snapshots_to(dest=b)


# Scan both sides first. Then sync origin -> destination; with --bidirectional,
# a second pass walks the destination's filesystems back towards the origin.
orig = ZFSState(args.origin)
dest = ZFSState(args.destination)

passes = [(orig, dest)]
if args.bidirectional:
    passes.append((dest, orig))
for source_state, target_state in passes:
    sync(source_state, target_state)
