#!/usr/bin/python3
#
# Copyright (c) 2017-2022, SUSE LLC
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer. Redistributions
# in binary form must reproduce the above copyright notice, this list of
# conditions and the following disclaimer in the documentation and/or
# other materials provided with the distribution.
#
# Neither the name of the SUSE Linux Products GmbH nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
 Author: Bo Maryniuk <bo@suse.de>

  This tool helps to:
  1. Format patches from Git in a way that keeps the impact of
     future changes minimal

  2. Update patches to the current package source

  3. Detect content differences when the filename is still the same

  4. Generate an include message for the .changes logfile
"""

from io import TextIOWrapper
import os
import sys
import re
import argparse
import shutil
import distro
from collections import OrderedDict
from typing import Any


# Default file names used by this tool:
# - ORDERING_FILE receives the generated patch ordering snippets and is
#   the fallback ordering file when remixing a spec file.
# - CHANGES_FILE receives the summary of changed/added/removed patches.
ORDERING_FILE = "patches.orders.txt"
CHANGES_FILE = "patches.changes.txt"

# Simplest detection of whether we are on the Debian family or on an
# RPM-based system (Leap/SLES/RHEL). Add more distro IDs on demand.
# IS_DEB requires both a known Debian-like distro ID and a "dpkg" binary
# found in PATH.
IS_DEB: bool = distro.id() in ["ubuntu", "debian", "linuxmint"] and bool(
    shutil.which("dpkg")
)
# IS_RPM is only considered when IS_DEB is false and requires an "rpm"
# binary in PATH, so both flags can be False at the same time.
IS_RPM: bool = not IS_DEB and bool(shutil.which("rpm"))


def remove_order(filename):
    """
    Remove order of the patch filename.

    Git formats patches: XXXX-filename.patch
    This function removes the "XXXX-" part, if any.

    Only the base name is inspected and rewritten. The directory part of
    ``filename`` is passed through untouched, so a dash or a capital
    letter in a directory name cannot corrupt the resulting path.

    :param filename: patch file name, optionally with a directory part
    :return: tuple ``(order, filename)``. When the base name starts with a
             numeric prefix, ``order`` is that prefix as an int and the
             base name is returned lower-cased without the prefix.
             Otherwise ``order`` is None and ``filename`` is returned
             unchanged.
    """
    dirname, basename = os.path.split(filename)
    ordnum = basename.split("-")[0]

    # ASCII digits only: anything else means there is no order prefix.
    if not re.fullmatch(r"[0-9]+", ordnum):
        return None, filename

    # When the base name has no dash at all it is entirely numeric and
    # is kept as the file name.
    stripped = basename.split("-", 1)[-1].lower()
    return int(ordnum), os.path.join(dirname, stripped)


def remove_order_from_subject(src_file, dst_file, use_unique=False) -> str:
    """
    Remove subject inside the patch.

    Git format patches inside with the following subject format:
    Subject: [PATCH X/Y] .........

    This function removes [PATCH X/Y] part, if any. In Git
    format-patches one can add "-N" flag, so then subject won't have
    these numbers, but just "[PATCH]". In this case we leave it out.

    :param src_file: patch file to read
    :param dst_file: patch file to write
    :param use_unique: when the destination already exists, pick a free
                       name via ``unique()`` instead of failing
    :raises IOError: if ``dst_file`` exists and ``use_unique`` is false,
                     or if one of the files cannot be opened

    Returns final written filename.
    """
    if os.path.exists(dst_file):
        if not use_unique:
            raise IOError("the file {0} exists".format(dst_file))
        # A single bump is not enough when the bumped name is taken as
        # well: keep going until a free name is found, so no existing
        # patch gets overwritten.
        while os.path.exists(dst_file):
            dst_file = unique(dst_file)

    # The source is opened first, so an unreadable source does not leave
    # an empty destination file behind.
    with open(src_file, encoding="utf-8", errors="surrogateescape") as src, open(
        dst_file, "w", encoding="utf-8", errors="surrogateescape"
    ) as dst:
        for line in src:
            line_tk = re.split(r"\s+\[PATCH \d+/\d+\]\s+", line)
            # Only a real subject header is rewritten; body lines that
            # merely mention "[PATCH X/Y]" stay as they are.
            if len(line_tk) == 2 and line_tk[0] == "Subject:":
                line = " [PATCH] ".join(line_tk)
            dst.write(line)

    return dst_file


def git_format_patch(tag):
    """
    Formats patches from the given tag.

    Runs ``git format-patch`` with ``tag`` in the current directory and
    prints how many ``*.patch`` files Git reported.

    :param tag: tag or range of commits; several whitespace separated
                arguments are passed to Git as separate arguments
    :raises OSError: if the ``git`` executable cannot be started
    """
    import shlex
    import subprocess

    # An argument vector instead of a shell command line: the value comes
    # from the command line and must not be interpreted by a shell.
    command = ["git", "format-patch", *shlex.split(str(tag))]
    result = subprocess.run(
        command, stdout=subprocess.PIPE, universal_newlines=True, check=False
    )

    # Git prints one generated file name per line.
    patches = sum(
        1 for patch in result.stdout.splitlines() if patch.split(".")[-1] == "patch"
    )

    print("Patches fetched: {0}".format(patches))


def get_diff_contents(data):
    """
    Extract only the hunk bodies of a patch.

    Everything after the last "--" marker is dropped first. The rest is
    cut at every "@@ ... @@" hunk header, and each piece is truncated at
    the next "diff --git" line.

    :param data: complete text of a patch
    :return: list of hunk body strings; empty when the text has no "--"
             marker or no hunk header
    """
    # Yes, I know about library https://github.com/cscorley/whatthepatch
    # But for now we go ultra-primitive to keep no deps
    pieces = data.split("--")
    body = "--".join(pieces[: len(pieces) - 1])
    hunks = re.split(r"@@.*?@@.*?\n", body)

    # The first piece is whatever precedes the first hunk header.
    return [hunk.partition("diff --git")[0] for hunk in hunks[1:]]


def unique(fname):
    """
    Change name to the unique, in case it isn't.

    A trailing "-<number>" of the file name stem is incremented. If there
    is no such counter, "-1" is appended to the stem. The directory part
    and the file extension are preserved, so dots inside the directory or
    inside the stem (e.g. version numbers) do not confuse the counter.

    :param fname: file name, optionally with a directory part
    :return: the changed file name
    """
    directory, basename = os.path.split(fname)
    stem, extension = os.path.splitext(basename)

    head, dash, tail = stem.rpartition("-")
    if not dash:
        stem = "{0}-{1}".format(stem, 1)
    else:
        try:
            stem = "{0}-{1}".format(head, int(tail) + 1)
        except ValueError:
            # Filename is not in "str-int", but "str-str".
            stem = "{0}-{1}".format(stem, 1)

    return os.path.join(directory, stem + extension)


def extract_spec_source_patches(specfile):
    """
    Extracts source patches from the .spec file to match existing
    comments, according to the
    https://en.opensuse.org/openSUSE:Packaging_Patches_guidelines

    The patch section starts at the first "PatchN:" line and ends at the
    first line that is neither a comment nor a "PatchN:" line. Comment
    lines directly in front of a patch line belong to that patch. The
    comment in front of the very first patch line is taken from the lines
    preceding the section, up to the closest empty line.

    :param specfile: path to the .spec file
    :return: dict mapping a patch file name to its list of comments in
             file order. The comment of the first patch is one string
             with its lines joined by ``os.linesep``.
    """
    patch_re = re.compile(r"^[Pp]atch[0-9]+:")

    # Text mode already translates every line ending to "\n".
    with open(specfile) as spec_fh:
        spec_lines = spec_fh.read().split("\n")

    head_buff = []
    patch_section = []
    in_section = False
    for spec_line in spec_lines:
        is_patch = bool(patch_re.match(spec_line))
        if not in_section:
            if not is_patch:
                head_buff.append(spec_line)
                continue
            in_section = True
        if not is_patch and not spec_line.startswith("#"):
            # End of the patch section: nothing after it is of interest.
            break
        patch_section.append(spec_line)

    # Walk backwards from the section start to the closest empty line.
    first_comment = []
    for head_line in reversed(head_buff):
        if not head_line:
            break
        if head_line.startswith("#"):
            first_comment.append(head_line)
    first_comment.reverse()  # collected bottom-up, restore file order
    patch_section.insert(0, os.linesep.join(first_comment))

    # Comments precede their patch line, so the section is read bottom-up:
    # a patch line opens a key and the comments that follow (i.e. stand
    # above it in the file) are attached to it.
    patchset: dict[str, list[Any]] = {}
    curr_key = None
    for line in reversed(patch_section):
        if patch_re.match(line):
            curr_key = patch_re.sub("", line).strip()
            patchset[curr_key] = []
            continue
        if curr_key and line.startswith("#"):
            patchset[curr_key].append(line)

    # Restore file order of the comments gathered bottom-up.
    for comments in patchset.values():
        comments.reverse()

    return patchset


def do_remix_spec(args):
    """
    Remix spec file.

    Prints every "PatchN:" line of the ordering file, preceded by the
    comments the same patch file has in the spec file, or by a
    placeholder comment if the spec file has none for it.

    :param args: parsed command line; ``args.spec`` is the spec file and
                 ``args.ordering`` the ordering file. If the latter is not
                 accessible, it is replaced with the default ordering file
                 in the current directory.
    :raises IOError: if the spec file or the ordering file is not found
    :raises ValueError: if a patch line has no patch file name
    :return:
    """
    if not os.path.exists(args.spec or ""):
        raise IOError(
            "Specfile {0} is not accessible or is somewhere else".format(args.spec)
        )
    if not os.path.exists(args.ordering or ""):
        args.ordering = "./{0}".format(ORDERING_FILE)
        if not os.path.exists(args.ordering):
            raise IOError(
                'Ordering file is expected "./{0}" but is not visible'.format(
                    ORDERING_FILE
                )
            )

    patchset = extract_spec_source_patches(args.spec)
    with open(args.ordering) as ordering_fh:
        ordering_lines = ordering_fh.read().split("\n")

    for o_line in ordering_lines:
        # The tag and the file name are taken apart with a pattern instead
        # of splitting on spaces: this also works with tabs, with spaces
        # in the file name and when no space follows the tag.
        patch_line = re.match(r"^([Pp]atch[0-9]+:)\s*(.*)$", o_line)
        if not patch_line:
            continue
        ref, pname = patch_line.group(1), patch_line.group(2).strip()
        if not pname:
            raise ValueError('No patch file name in line "{0}"'.format(o_line))
        print(os.linesep.join(patchset.get(pname) or ["# Description N/A"]))
        print(ref.ljust(15), pname)


def do_create_patches(args):
    """
    Create and reformat patches for the package.

    Unless ``args.existing`` is set, the current directory has to be empty
    and the patches are generated with Git from ``args.format``. Every
    ``*.patch`` file is then renamed to its lower-cased name without the
    order prefix, its subject is cleaned from the "X/Y" counter, and the
    ordering file (``args.ordering`` or the default one) is written.
    Patches whose subject starts with ``args.skip_tag`` are deleted.

    :param args: parsed command line; the attributes read here are
                 ``existing``, ``format``, ``ordering``, ``index``,
                 ``increment`` and ``skip_tag``
    :raises ValueError: if a patch file has no numeric order prefix
    :raises IOError: if a target file exists and ``args.increment`` is
                     not set
    """
    current_dir = os.path.abspath(".")

    if not args.existing:
        if os.listdir(current_dir):
            print("Error: this directory has to be empty!")
            sys.exit(1)

        git_format_patch(args.format)
    elif not any(fname.endswith(".patch") for fname in os.listdir(current_dir)):
        print(
            "Error: can't find a single patch in {0} to work with!".format(current_dir)
        )
        sys.exit(1)

    skip_re = re.compile(r"^Subject: \[PATCH.*] {}".format(re.escape(args.skip_tag)))
    ord_patches_p: OrderedDict[int, str] = OrderedDict()
    patches = 0

    # The context manager closes the ordering file on errors, too.
    with open(args.ordering or ORDERING_FILE, "w") as ord_fh:
        if IS_RPM:
            ord_fh.write(
                "#\n#\n# This is pre-generated snippets of patch ordering\n#\n"
            )

        # os.listdir() returns names in arbitrary order: sort them, so the
        # zero-padded Git prefixes define the order of the written lines.
        for fname in sorted(os.listdir(current_dir)):
            if fname.split(".")[-1] != "patch":
                continue

            # Check if we should skip this patch in case subject starts with SKIP_TAG
            with open(fname, encoding="utf-8", errors="surrogateescape") as patch_file:
                skip = any(skip_re.match(line) for line in patch_file)
            if skip:
                print(f"Skipping {fname}")
                os.unlink(fname)
                continue

            print(f"Preparing {fname}")
            order, nfname = remove_order(fname)
            if order is None:
                # Without a number there is nothing to build the patch tag
                # from, and the index offset can not be applied.
                raise ValueError(f"patch {fname} has no numeric order prefix")
            if args.index is not None:
                order += args.index
            nfname = remove_order_from_subject(fname, nfname, use_unique=args.increment)
            os.unlink(fname)
            if IS_RPM:
                # At least one space always separates the tag from the name.
                ord_fh.write(
                    "{patch}{fname}\n".format(
                        patch=f"Patch{order}:".ljust(14) + " ", fname=nfname
                    )
                )
            ord_patches_p[order] = nfname
            patches += 1

        if ord_patches_p:
            if IS_RPM:
                ord_fh.write("#\n#\n# List of available patches:\n")
            else:
                ord_fh.write("#\n#\n# debian/series:\n")

            for order, fname in ord_patches_p.items():
                if IS_RPM:
                    ord_fh.write(f"%patch{order} -p1\n")
                else:
                    ord_fh.write(f"{fname} -p1\n")
        else:
            ord_fh.write("# No patches found\n")

    print(f"\nRe-formatted {patches} patch{'' if patches == 1 else 'es'}")


def do_update_patches(args):
    """
    Update patches on the target package source.

    Compares the ``*.patch`` files of the current directory with those in
    ``args.update``: new patches are copied over, patches with different
    diff contents are replaced if ``args.changed`` is set (otherwise only
    a warning is printed), and patches that exist only in the current
    directory are removed. A summary is written to the changes file.

    :param args: parsed command line; the attributes read here are
                 ``update`` and ``changed``
    """
    print(f"Updating packages from {args.update} directory")
    added = []
    removed = []
    changed = []
    target_dir = os.path.abspath(".")

    # Gather current patches. True marks a patch that is only present
    # here; it is flipped once the same name is seen in the update source.
    current_patches = {
        fname: True for fname in os.listdir(target_dir) if fname.endswith(".patch")
    }

    for n_fname in sorted(os.listdir(args.update)):
        if not n_fname.endswith(".patch"):
            continue
        fname = os.path.join(args.update, n_fname)
        if not os.path.isfile(fname):
            continue

        current_patches[n_fname] = False
        if not os.path.exists(n_fname):
            print(f"Adding {fname} patch")
            shutil.copyfile(fname, os.path.join(target_dir, n_fname))
            added.append(n_fname)
            continue

        # Both files are closed again before one of them gets replaced.
        with open(
            fname, encoding="utf-8", errors="surrogateescape"
        ) as updated_patch, open(
            n_fname, encoding="utf-8", errors="surrogateescape"
        ) as existing_patch:
            differs = get_diff_contents(updated_patch.read()) != get_diff_contents(
                existing_patch.read()
            )
        if not differs:
            continue

        if args.changed:
            print(f"Replacing {n_fname} patch")
            os.unlink(n_fname)
            shutil.copyfile(fname, os.path.join(target_dir, n_fname))
            changed.append(n_fname)
        else:
            print(f"WARNING: Patches {fname} and {n_fname} are different!")

    for fname in sorted(name for name, is_dead in current_patches.items() if is_dead):
        print(f"Removing {fname} patch")
        os.unlink(fname)
        removed.append(fname)

    # Generate an include for spec changes
    with open(CHANGES_FILE, "w") as changes:
        for title, data in [
            ("Changed", changed),
            ("Added", added),
            ("Removed", removed),
        ]:
            if not data:
                continue
            print(f"- {title}:", file=changes)
            for fname in sorted(data):
                print(f"  * {fname}", file=changes)
            print(file=changes)

    # Check the list of changed patches, not the (always true) file object.
    if not removed and not added and not changed:
        print("No files has been changed")


def main():
    """
    Main app.

    Parses the command line and dispatches to the requested action: show
    the version, remix a spec file, update patches or create patches.
    Prints the help and exits with status 1 if no action is selected.
    Any error raised by an action is reported on stderr and also ends the
    process with status 1, so calling scripts can detect the failure.
    """
    VERSION = "0.3"
    parser = argparse.ArgumentParser(description="Git patch formatter for RPM packages")
    parser.add_argument(
        "-u",
        "--update",
        action="store",
        const=None,
        help="update current patches with the destination path",
    )
    parser.add_argument(
        "-f",
        "--format",
        action="store",
        const=None,
        help="specify tag or range of commits for patches to be formatted",
    )
    parser.add_argument(
        "-o",
        "--ordering",
        action="store",
        const=None,
        help="specify ordering spec inclusion file. Default: {0}".format(ORDERING_FILE),
    )
    parser.add_argument(
        "-x",
        "--index",
        action="store",
        const=None,
        help="specify start ordering index. Default: 0",
    )
    parser.add_argument(
        "-s",
        "--spec",
        action="store",
        const=None,
        help="remix spec file and extract sources with their comments to match new patch ordering",
    )
    parser.add_argument(
        "-i",
        "--increment",
        action="store_const",
        const=True,
        help="use increments for unique names when patch commits repeated",
    )
    parser.add_argument(
        "-c",
        "--changed",
        action="store_const",
        const=True,
        help="update also changed files with the content",
    )
    parser.add_argument(
        "-e",
        "--existing",
        action="store_const",
        const=True,
        help="work with already formatted patches from Git",
    )
    parser.add_argument(
        "-k",
        "--skip-tag",
        action="store",
        const=None,
        default="[skip]",
        help="skip commits starting with this tag. Default: [skip]",
    )
    parser.add_argument(
        "-v", "--version", action="store_const", const=True, help="show version"
    )
    args = parser.parse_args()

    try:
        if args.index:
            try:
                args.index = int(args.index)
            except ValueError as err:
                raise ValueError(
                    'Value "{0}" should be a digit'.format(args.index)
                ) from err

        if args.version:
            print("Version: {0}".format(VERSION))
        elif args.spec:
            do_remix_spec(args)
        elif args.update and not args.format:
            do_update_patches(args)
        elif (args.format and not args.update) or args.existing:
            do_create_patches(args)
        else:
            parser.print_help()
            sys.exit(1)
    except Exception as ex:
        # Top-level boundary: report the problem and signal the failure
        # through the exit status instead of ending with status 0.
        print("Critical error:", ex, file=sys.stderr)
        sys.exit(1)

if __name__ == "__main__":
    main()
