#!/usr/bin/python3
"""
# SPDX-FileCopyrightText: 2026 Ralf Habacker
# SPDX-License-Identifier: GPL-2.0-or-later
#
# /usr/libexec/zypper-restart-services/restart-services
# Automatically restarts services using deleted or updated files after zypper update
"""

import logging
import subprocess
import sys
from pathlib import Path

# Persistent log file; every message of this script is appended here.
LOGFILE = Path("/var/log/zypper-restart-services.log")

# One shared record layout for all handlers attached below.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"

# Dedicated named logger, emitting INFO and above.
logger = logging.getLogger("zypper-restart-services")
logger.setLevel(logging.INFO)

# File output is always enabled; mode "a" keeps history across runs.
file_handler = logging.FileHandler(LOGFILE, mode="a")
file_handler.setFormatter(logging.Formatter(_LOG_FORMAT))
logger.addHandler(file_handler)

# Passing "--verbose" anywhere on the command line additionally mirrors
# log records to the console (StreamHandler's default stream).
VERBOSE = "--verbose" in sys.argv
if VERBOSE:
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter(_LOG_FORMAT))
    logger.addHandler(console_handler)

# Units that restart_services() is meant to skip; entries are full unit
# names including the ".service" suffix.
IGNORE_SERVICES = ["display-manager.service"]

def run(cmd, capture=False, check=False):
    """Run *cmd* (an argv list, no shell) and report its result.

    Args:
        cmd: argument vector handed to ``subprocess.run``.
        capture: when true, stdout and stderr are captured; stdout is returned
            stripped and stderr is discarded. When false, the child inherits
            this process's streams.
        check: forwarded to ``subprocess.run``. A resulting
            ``CalledProcessError`` is logged and its exit code returned
            instead of propagating.

    Returns:
        ``(stdout, returncode)`` when *capture* is true, otherwise just
        ``returncode``. The shape depends only on *capture*, never on whether
        the command failed. If the program cannot be started at all (e.g. it
        is not installed), the error is logged and 127 is reported as the
        return code (the shell's conventional "command not found" status), so
        one missing tool does not abort the whole run.
    """
    try:
        if capture:
            result = subprocess.run(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
                check=check,
            )
            return result.stdout.strip(), result.returncode
        return subprocess.run(cmd, check=check).returncode
    except subprocess.CalledProcessError as e:
        # Only reachable with check=True.
        logger.error("Command failed: %s (code %d)", cmd, e.returncode)
        returncode = e.returncode
    except OSError as e:
        # FileNotFoundError / PermissionError when launching the executable.
        logger.error("Could not execute %s: %s", cmd, e)
        returncode = 127
    # Keep the same return shape as the success path.
    return ("", returncode) if capture else returncode

def parse_services_from_zypper():
    """Collect the service names reported by ``zypper ps -s``.

    Each table row is split on "|". The fifth column is treated as the
    command and the sixth as the service name. A non-empty service name
    consisting only of alphanumerics and ``_.@:-`` is kept. Otherwise a row
    whose command is exactly "systemd" contributes the pseudo-name "systemd".
    Header rows (starting with "PID"), separator rows made of "-"/"+", rows
    without "|", and rows with fewer than six columns are ignored.

    An exit status of 102 from zypper is logged as a reboot suggestion.

    Returns:
        list[str]: sorted, de-duplicated service names.
    """
    table, status = run(["zypper", "ps", "-s"], capture=True)
    if status == 102:
        logger.info("Core libraries or services updated: a reboot is suggested.")

    found = set()
    for row in table.splitlines():
        is_separator = set(row.strip()) <= {"-", "+"}
        if "|" not in row or row.startswith("PID") or is_separator:
            continue
        columns = [cell.strip() for cell in row.split("|")]
        if len(columns) < 6:
            continue
        command, service = columns[4], columns[5]
        well_formed = bool(service) and all(
            ch.isalnum() or ch in "_.@:-" for ch in service
        )
        if well_formed:
            found.add(service)
        elif command == "systemd":
            found.add("systemd")
    return sorted(found)

def systemd_unit_exists(unit):
    """Tell whether ``systemctl list-unit-files`` mentions *unit*.

    A listing line counts as a match when it starts with the given unit name.
    The listing is fetched anew on every call.

    Returns:
        bool: True on the first matching line, False otherwise.
    """
    listing, _status = run(["systemctl", "list-unit-files"], capture=True)
    for entry in listing.splitlines():
        if entry.startswith(unit):
            return True
    return False

def handle_dbus():
    """Restart ``dbus.socket`` and gather services connected to D-Bus.

    After the restart, two sources are scanned line by line:
    ``systemctl list-dependencies --reverse dbus.service --plain`` and
    ``busctl list --no-pager``. Every stripped line ending in ".service"
    (other than "dbus.service" itself) is logged and collected.

    Returns:
        list[str]: collected names with ".service" removed, in discovery
        order. Duplicates are possible; the caller de-duplicates.
    """
    logger.info("Restarting dbus.socket...")
    run(["systemctl", "restart", "dbus.socket"])

    reverse_deps, _ = run(
        ["systemctl", "list-dependencies", "--reverse", "dbus.service", "--plain"],
        capture=True,
    )
    # NOTE(review): `busctl list` presumably prints the UNIT column mid-line
    # (the last column is DESCRIPTION), so its lines may rarely end in
    # ".service" - confirm whether this source contributes anything.
    bus_peers, _ = run(["busctl", "list", "--no-pager"], capture=True)

    affected = []
    candidates = reverse_deps.splitlines() + bus_peers.splitlines()
    for unit in (entry.strip() for entry in candidates):
        if not unit.endswith(".service") or unit == "dbus.service":
            continue
        affected.append(unit.replace(".service", ""))
        logger.info("DBus affected service: %s", unit)
    return affected

def handle_systemd():
    """Re-execute the systemd system manager via ``systemctl daemon-reexec``.

    The exit status of systemctl is not inspected.
    """
    reexec_cmd = ["systemctl", "daemon-reexec"]
    logger.info("Re-executing systemd daemon")
    run(reexec_cmd)

def get_user_systemds():
    """List the users reported by ``loginctl list-users --no-legend``.

    The second whitespace-separated field of every non-blank output line is
    taken as the user name.

    Returns:
        list[str]: sorted, de-duplicated user names.
    """
    listing, _ = run(["loginctl", "list-users", "--no-legend"], capture=True)
    names = set()
    for row in listing.splitlines():
        if not row.strip():
            continue
        names.add(row.split()[1])
    return sorted(names)

def restart_services(services):
    """Run ``systemctl try-restart`` for each affected service.

    Args:
        services: service names without the ".service" suffix (as reported by
            ``zypper ps -s``); the suffix is appended to form the unit name.

    A service is skipped when either its bare name or its full unit name is
    listed in ``IGNORE_SERVICES``. IGNORE_SERVICES holds full unit names
    such as "display-manager.service" while *services* holds bare names, so
    only comparing the bare name would never match. Services that
    ``systemd_unit_exists`` does not recognise are skipped with a warning.
    A non-zero exit status from ``systemctl try-restart`` is logged as a
    warning; processing continues with the next service.
    """
    for svc in services:
        unit = f"{svc}.service"
        if svc in IGNORE_SERVICES or unit in IGNORE_SERVICES:
            logger.info("Skipping %s (always ignored)", unit)
            continue
        if not systemd_unit_exists(unit):
            logger.warning("Skipping %s (not a valid systemd unit)", unit)
            continue
        logger.info("Trying to restart %s", unit)
        returncode = run(["systemctl", "try-restart", unit])
        if returncode != 0:
            logger.warning("Restart of %s failed (code %s)", unit, returncode)

def main():
    """Restart services that still use deleted or updated files.

    Steps:
        1. Ask ``zypper ps -s`` which services are affected; stop if none.
        2. If "dbus" is affected, let ``handle_dbus`` deal with it, merge the
           services it returns, and drop "dbus" from the list.
        3. If "systemd" is affected, re-execute it via ``handle_systemd``
           and drop it from the list.
        4. For every user reported by ``loginctl``, log a suggested manual
           action (user managers are not restarted automatically).
        5. Try-restart the remaining services.

    Returns:
        int: exit status for ``sys.exit`` (always 0).
    """
    logger.info("Checking for processes using deleted or updated files...")
    services = parse_services_from_zypper()
    if not services:
        logger.info("No affected services found. Nothing to restart.")
        return 0
    logger.info("Affected services: %s", ", ".join(services))

    if "dbus" in services:
        extra = handle_dbus()
        services = sorted(set(services + extra))
        services.remove("dbus")

    if "systemd" in services:
        # PID 1 is re-executed rather than passed to try-restart.
        handle_systemd()
        services.remove("systemd")

    for user in get_user_systemds():
        logger.info("Detected user systemd instance: %s", user)
        # loginctl has no "restart-user" command. Re-executing the user's own
        # manager is the per-user counterpart of the daemon-reexec above
        # (`-M user@` needs a reasonably recent systemd; re-login always works).
        logger.info(
            "Suggested action: systemctl --user -M %s@ daemon-reexec "
            "(or log out and back in)",
            user,
        )

    restart_services(services)

    logger.info("All affected services processed. Verify with: zypper ps -s")
    return 0

if __name__ == "__main__":
    sys.exit(main())
