#!/usr/bin/python3

from datetime import datetime
import argparse
import logging
import uvicorn
from dbus import SystemBus, Interface, DBusException
from mcp.server.fastmcp import FastMCP


# Log to a file. The path is relative, so the file is created in the current
# working directory of the server process.
# NOTE(review): presumably a file rather than stdout/stderr so that log output
# cannot interfere with the stdio transport (the default transport) - confirm.
# Keep this before FastMCP() is constructed: logging.basicConfig() does nothing
# once the root logger already has handlers, and FastMCP may configure logging
# itself (TODO confirm against the installed mcp version).
logging.basicConfig(filename = "mcp-server-snapper.log", level = logging.DEBUG)


# Server instance; the functions below are exposed as MCP tools through the
# @mcp.tool() decorator.
mcp = FastMCP("SnapperServer")


# TODO: especially with the http(s) transports, the D-Bus connection to
# snapperd stays open for the whole lifetime of the process.


def get_snapper():
    """
    Return a proxy for the 'org.opensuse.Snapper' D-Bus interface of snapperd.

    A new proxy object is created on every call. Methods invoked on it
    (ListConfigs, ListSnapshots, Create*Snapshot, ...) are forwarded to the
    daemon over the system bus.
    """
    # TODO reuse - dbus.SystemBus() presumably hands out a shared connection
    # already, so only the proxy object is rebuilt per call; confirm.
    remote_object = SystemBus().get_object('org.opensuse.Snapper', '/org/opensuse/Snapper')
    return Interface(remote_object, dbus_interface = 'org.opensuse.Snapper')


def format_row(row):
    """
    Format a row as one line of a markdown table. Row elements can be int or str.

    Each element is converted with str(). Line breaks inside an element are
    replaced by spaces because a table row must stay on a single line, and
    pipe characters are escaped so they are not read as cell separators.
    """
    cells = []
    for element in row:
        # splitlines() covers "\n", "\r\n", "\r" and the other Unicode line boundaries
        text = " ".join(str(element).splitlines())
        cells.append(text.replace("|", "\\|"))
    return " | ".join(cells)


@mcp.tool()
def list_configs() -> str:
    """
    List snapper configs.

    The result is a markdown table with the name and subvolume of each config.
    """

    try:
        configs = get_snapper().ListConfigs()
    except DBusException as e:
        logging.error(f"Snapper error: {e}")
        return f"Error listing configs: {e}"

    # Header row plus delimiter row, so the text is a well-formed markdown table
    rows = [ format_row([ "Config", "Subvolume" ]), format_row([ "---", "---" ]) ]
    # Only the first two fields of each config entry are shown, matching the header
    rows.extend(format_row([ config[0], config[1] ]) for config in configs)

    # Join into text; interpolating the list itself would emit a Python repr
    # (brackets, quotes, doubled backslashes) instead of a table
    table = "\n".join(rows)

    # Successful result: debug level, not error
    logging.debug(f"list of snapper configs:\n{table}")

    return f"List of snapper configs as table with header:\n{table}"


@mcp.tool()
def list_snapshots(config: str = "root") -> str:
    """
    List file system snapshots using snapper.

    The result is a markdown table with number, type, date, cleanup algorithm
    and description of each snapshot.
    :param config: Name of the snapper config whose snapshots are listed, 'root' by default.
    """

    try:
        snapshots = get_snapper().ListSnapshots(config)
    except DBusException as e:
        logging.error(f"Snapper error: {e}")
        return f"Error listing snapshots: {e}"

    # Numeric snapshot types returned by snapperd; unknown values give an empty cell
    type_names = { 0: "single", 1: "pre", 2: "post" }

    rows = [ format_row([ "Number", "Type", "Date", "Cleanup", "Description" ]),
             format_row([ "---", "---", "---", "---", "---" ]) ]

    for snapshot in snapshots:
        # A date of -1 is treated as "no date" and rendered as an empty cell
        date = snapshot[3]
        if date != -1:
            date_text = datetime.fromtimestamp(date).strftime("%Y-%m-%d %H:%M:%S")
        else:
            date_text = ""

        # Index 6 is shown under "Cleanup" and index 5 under "Description"
        rows.append(format_row([ snapshot[0], type_names.get(snapshot[1], ""), date_text,
                                 snapshot[6], snapshot[5] ]))

    # Join into text; interpolating the list itself would emit a Python repr
    table = "\n".join(rows)

    # Successful result: debug level, not error
    logging.debug(f"list of snapper snapshots:\n{table}")

    return f"List of snapper snapshots as table with header:\n{table}"


@mcp.tool()
def create_snapshot(type: str, pre_number: int, description: str, cleanup: str,
                    config: str = "root") -> str:
    """
    Create a file system snapshot using snapper.
    :param type: Type for the snapshot, either 'single', 'pre' or 'post'.
    :param pre_number: Number of the corresponding pre snapshot. Required if type is 'post',
           otherwise ignored.
    :param description: Description for the snapshot.
    :param cleanup: Cleanup algorithm for the snapshot like 'number' or 'timeline'.
    :param config: Name of the snapper config to create the snapshot in, 'root' by default.
    """

    # Reject an unknown type before any D-Bus proxy is created
    if type not in ("single", "pre", "post"):
        logging.error(f"Invalid snapshot type: {type}")
        return "Error creating snapshot: Invalid snapshot type"

    userdata = {}

    try:
        snapper = get_snapper()

        if type == "single":
            number = snapper.CreateSingleSnapshot(config, description, cleanup, userdata)
        elif type == "pre":
            number = snapper.CreatePreSnapshot(config, description, cleanup, userdata)
        else:
            # Only 'post' is left after the validation above
            number = snapper.CreatePostSnapshot(config, pre_number, description, cleanup, userdata)

    except DBusException as e:
        logging.error(f"Snapper error: {e}")
        return f"Error creating snapshot: {e}"

    logging.info(f"snapper number of new snapshot: {number}")

    return f"Snapper number of new snapshot: {number}"


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description = "Run the MCP snapper server.")

    parser.add_argument("--transport", choices = [ "stdio", "http", "https" ], default = "stdio",
                        help = "Transport type (default: stdio)")

    # NOTE: the default listens on all interfaces and no authentication is
    # configured here; use --host 127.0.0.1 to keep the server local.
    parser.add_argument("--host", type = str, default = "0.0.0.0",
                        help = "Address to listen on for HTTP and HTTPS (default: %(default)s)")
    parser.add_argument("--port", type = int, default = 8000,
                        help = "Port for HTTP and HTTPS (default: %(default)s)")

    parser.add_argument("--key", type = str, help = "Key for HTTPS")
    parser.add_argument("--cert", type = str, help = "Cert for HTTPS")

    args = parser.parse_args()

    if args.transport == "https" and not (args.key and args.cert):
        parser.error("--transport https requires both --key and --cert")

    # Logged only once the arguments are valid, not for --help or usage errors
    logging.info("Server started")

    if args.transport == "http":
        uvicorn.run(mcp.sse_app(), host = args.host, port = args.port)
    elif args.transport == "https":
        uvicorn.run(mcp.sse_app(), host = args.host, port = args.port, ssl_keyfile = args.key,
                    ssl_certfile = args.cert)
    else:
        mcp.run(transport = "stdio")
