#!/usr/bin/python3
# -*- coding: utf-8 -*-
# pylint: disable=invalid-name
"""
saltbroker: A ZeroMQ Proxy (broker) for Salt Minions

The main process spawns a process for each channel of Salt ZMQ transport:

- PubChannelProxy process provides the PUB channel for the minions
- RetChannelProxy process provides the RET channel for the minions

Also acts as a supervisor for the child processes, respawning them if they die.

:depends:   python-PyYAML
:depends:   python-pyzmq

Copyright (c) 2016--2025 SUSE LLC

All modifications and additions to the file contributed by third parties
remain the property of their copyright owners, unless otherwise agreed
upon. The license for this file, and modifications and additions to the
file, is the same license as for the pristine package itself (unless the
license for the pristine package is not an Open Source License, in which
case the license is the MIT License). An "Open Source License" is a
license that conforms to the Open Source Definition (Version 1.9)
published by the Open Source Initiative.

Please submit bugfixes or comments via http://bugs.opensuse.org/
"""

# Import python libs
import ipaddress
import logging
import logging.handlers
import multiprocessing
import os
import signal
import socket
import sys
import threading
import time
import traceback
import yaml

# Module-level logger used by the broker and by the channel proxy processes.
# The logger itself lets everything from DEBUG upwards through; the handlers
# are attached later, in the ``__main__`` section of this file.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)


# Import pyzmq lib
import zmq

from zmq.utils.monitor import recv_monitor_message

# File locations; each one can be overridden with the environment variable
# of the same name.
UYUNI_CONFIG_FILE = os.environ.get("UYUNI_CONFIG_FILE", "/etc/uyuni/config.yaml")
SALT_BROKER_CONF_FILE = os.environ.get("SALT_BROKER_CONF_FILE", "/etc/salt/broker")
SALT_BROKER_LOGFILE = os.environ.get("SALT_BROKER_LOGFILE", "/var/log/salt/broker")
# Seconds the supervisor loop sleeps between two liveness checks of the
# channel proxy processes.
SUPERVISOR_TIMEOUT = 5


def ip_bracket(addr, strip=False):
    """
    Normalise an IP address so that it can be embedded in a URI.

    Leading "[" and trailing "]" characters are removed before the value is
    parsed. An IPv6 address is returned wrapped in brackets unless ``strip``
    is true; an IPv4 address is always returned without brackets.

    Raises ``ValueError`` if the value is not a valid IP address.
    """
    bare = str(addr).lstrip("[").rstrip("]")
    parsed = ipaddress.ip_address(bare)
    if parsed.version == 6 and not strip:
        template = "[{}]"
    else:
        template = "{}"
    return template.format(parsed)


class AbstractChannelProxy(multiprocessing.Process):
    """
    Base class for the ZeroMQ channel proxies.

    Each instance is a child process that connects a "backend" socket to the
    Salt master, binds a "frontend" socket on the local interface and then
    forwards the traffic between both sockets with ``zmq.proxy``.

    Subclasses have to define ``_backend_sock_type``, ``_frontend_sock_type``,
    ``backend_type``, ``frontend_type``, ``_backend_uri`` and
    ``_frontend_uri`` before the process is started.
    """

    class ChannelException(Exception):
        """
        Raised when a channel proxy cannot be configured or set up
        """

    # Pairs of (ZMQ socket option name, source of the value). The source is
    # either a key of the options dict or a 1-tuple holding a literal value.
    _BACKEND_SOCKOPTS = (
        ("TCP_KEEPALIVE", "tcp_keepalive"),
        ("TCP_KEEPALIVE_IDLE", "tcp_keepalive_idle"),
        ("TCP_KEEPALIVE_CNT", "tcp_keepalive_cnt"),
        ("TCP_KEEPALIVE_INTVL", "tcp_keepalive_intvl"),
        ("CONNECT_TIMEOUT", "connect_timeout"),
        ("RECONNECT_IVL", "reconnect_ivl"),
        ("HEARTBEAT_IVL", "heartbeat_ivl"),
        ("HEARTBEAT_TIMEOUT", "heartbeat_timeout"),
    )
    _FRONTEND_SOCKOPTS = ()

    def __init__(self, opts):
        """
        Validate the options and resolve the address of the master.

        :param opts: configuration dict. It must contain ``master``; the
            resolved address is stored back into it as ``master_ip``.
        :raises ChannelException: if ``master`` is missing or cannot be
            resolved.
        """
        self.opts = opts
        self.backend_connected = False
        # Defined up front so that the error handling in run() can always
        # read it, even if the failure happens before the backend is set up.
        self.reconnect_retries = self.opts.get("drop_after_retries", -1)
        if "master" not in self.opts:
            raise self.ChannelException(
                # pylint: disable-next=consider-using-f-string
                '[{}] No "master" opts is provided'.format(self.__class__.__name__)
            )
        try:
            self.opts["master_ip"] = socket.gethostbyname(self.opts["master"])
        except socket.gaierror as exc:
            raise self.ChannelException(
                # pylint: disable-next=consider-using-f-string
                "[{}] Error trying to resolve '{}': {}".format(
                    self.__class__.__name__, self.opts["master"], exc
                )
            ) from exc
        super().__init__()

    def run(self):
        """
        Entry point of the child process.

        Connects the backend socket, optionally waits until it is connected,
        binds the frontend socket and forwards messages until the proxy
        fails.

        :raises ChannelException: on a ZMQ error during the setup, unless the
            retry counter has reached zero (backend dropped on purpose).
        """
        try:
            context = zmq.Context()

            log.debug(
                "Setting up a %s sock on %s", self.backend_type, self._backend_uri
            )
            self.backend = context.socket(self._backend_sock_type)
            self.set_sockopts(
                self.backend,
                self._BACKEND_SOCKOPTS,
                self.backend_type,
                self.opts["master_ip"],
            )

            self.reconnect_retries = self.opts["drop_after_retries"]
            # The monitor thread is the only place where backend_connected is
            # updated, so it is also needed when we merely wait for the
            # backend; otherwise the wait loop below would never finish.
            if self.reconnect_retries != -1 or self.opts["wait_for_backend"]:
                self.monitor_socket = self.backend.get_monitor_socket()
                self.monitor_thread = threading.Thread(
                    target=self.backend_socket_monitor, args=(self.monitor_socket,)
                )
                self.monitor_thread.start()

            self.backend.connect(self._backend_uri)

            if self.opts["wait_for_backend"]:
                while not self.backend_connected:
                    # The monitor thread closes the backend once the retry
                    # attempts are exhausted.
                    if self.backend.closed:
                        log.warning(
                            "Backend %s socket was closed while waiting for it. Terminating...",
                            self.backend_type,
                        )
                        return
                    time.sleep(0.5)

            log.debug(
                "Setting up a %s sock on %s", self.frontend_type, self._frontend_uri
            )

            self.frontend = context.socket(self._frontend_sock_type)
            self.set_sockopts(
                self.frontend,
                self._FRONTEND_SOCKOPTS,
                self.frontend_type,
                self.opts["interface"],
            )

            self.frontend.bind(self._frontend_uri)

            # Forward all messages
            log.info(
                "Staring ZMQ proxy on %s and %s sockets",
                self.frontend_type,
                self.backend_type,
            )
            try:
                zmq.proxy(self.frontend, self.backend)
            # pylint: disable-next=broad-exception-caught
            except Exception:
                log.error(
                    "Error while processing proxy with %s and %s sockets. Terminating...",
                    self.frontend_type,
                    self.backend_type,
                )
                # Keep the reason available for troubleshooting
                log.debug("Traceback: %s", traceback.format_exc())
                return

        except zmq.ZMQError as zmq_error:
            if self.reconnect_retries == 0:
                # Do not raise error if drop_after_retries was used
                return
            # pylint: disable-next=consider-using-f-string
            msg = "ZMQ Error: {}".format(zmq_error)
            log.error(msg)
            raise self.ChannelException(msg) from zmq_error

        # pylint: disable-next=broad-exception-caught
        except Exception as exc:
            log.error("Exception: %s", exc)
            log.debug("Traceback: %s", traceback.format_exc())

    # pylint: disable-next=redefined-outer-name
    def set_sockopts(self, socket, sockopts, sock_type, addr=None):
        """
        Apply a list of socket options to a ZMQ socket.

        :param socket: the ZMQ socket to configure
        :param sockopts: iterable of (ZMQ option name, value source) pairs;
            the source is a key of the options dict or a 1-tuple with a
            literal value
        :param sock_type: socket type name, only used for logging
        :param addr: address the socket is going to use; if it contains ":"
            the IPv6 related options are enabled
        """

        def _apply(opt, val, opt_name):
            log.debug("Setting socket opt %s to %s on %s", opt_name, val, sock_type)
            socket.setsockopt(opt, val)

        for opt_name, opt_src in sockopts:
            opt = getattr(zmq, opt_name, None)
            if opt is None:
                # The option is not provided by the installed ZMQ bindings
                log.error("Unable to ZMQ doesn't have %s socket opt", opt_name)
                continue
            if opt_src in self.opts:
                opt_val = self.opts[opt_src]
            elif isinstance(opt_src, tuple) and len(opt_src) == 1:
                opt_val = opt_src[0]
            else:
                log.error("Unable to get the value for socket opt %s", opt_name)
                continue
            _apply(opt, opt_val, opt_name)

        addr_has_colon = addr is not None and ":" in addr
        if (self.opts["ipv6"] is True or addr_has_colon) and hasattr(zmq, "IPV4ONLY"):
            # IPv6 sockets work for both IPv6 and IPv4 addresses
            _apply(zmq.IPV4ONLY, 0, "IPV4ONLY")
        if addr_has_colon and hasattr(zmq, "IPV6"):
            _apply(zmq.IPV6, 1, "IPV6")

    def backend_socket_monitor(self, monitor_socket):
        """
        Thread target: follow the events of the backend socket.

        Keeps ``backend_connected`` up to date and, unless
        ``drop_after_retries`` is -1, closes the backend socket once the
        number of consecutive connect retries is exhausted.

        :param monitor_socket: monitor socket of the backend socket
        """
        while monitor_socket.poll():
            event = recv_monitor_message(monitor_socket)["event"]
            if event == zmq.EVENT_DISCONNECTED:
                log.warning("%s socket disconnected", self.backend_type)
                self.backend_connected = False
            elif event == zmq.EVENT_CONNECTED:
                log.info("%s socket connected", self.backend_type)
                self.backend_connected = True
                # A successful connection resets the retry counter
                self.reconnect_retries = self.opts["drop_after_retries"]
            elif event == zmq.EVENT_CONNECT_RETRIED:
                # -1 means "retry forever": nothing to count in that case
                if self.reconnect_retries != -1:
                    if self.reconnect_retries == 0:
                        log.warning(
                            "Closing %s socket due to retry attempts reached!",
                            self.backend_type,
                        )
                        self.backend.close()
                        break
                    self.reconnect_retries -= 1
            elif event == zmq.EVENT_MONITOR_STOPPED:
                break
        monitor_socket.close()

    def terminate(self):
        """
        custom terminate function for the child process
        """
        log.info("Terminate called. Exiting")
        super().terminate()


class PubChannelProxy(AbstractChannelProxy):
    """
    Proxy for the Salt PUB channel.

    The backend is an XSUB socket connected to the publish port of the Salt
    master; the frontend is an XPUB socket bound to the same port on the
    configured interface, where minions can subscribe to receive the
    forwarded messages.
    """

    # Prevent stopping publishing messages on XPUB socket. (bsc#1182954)
    _FRONTEND_SOCKOPTS = (("XPUB_VERBOSE", (1,)), ("XPUB_VERBOSER", (1,)))

    def __init__(self, opts):
        super().__init__(opts)
        self.name = "PubChannelProxy"

        # Socket types and the names used for them in the log messages
        self._backend_sock_type, self._frontend_sock_type = zmq.XSUB, zmq.XPUB
        self.backend_type, self.frontend_type = "XSUB", "XPUB"

        # Both ends use the publish port: the backend on the master address,
        # the frontend on the local interface.
        uri_template = "tcp://{}:{}"
        self._backend_uri = uri_template.format(
            ip_bracket(self.opts["master_ip"]), self.opts["publish_port"]
        )
        self._frontend_uri = uri_template.format(
            ip_bracket(self.opts["interface"]), self.opts["publish_port"]
        )


class RetChannelProxy(AbstractChannelProxy):
    """
    Proxy for the Salt RET channel.

    The backend is a DEALER socket connected to the return port of the Salt
    master; the frontend is a ROUTER socket bound to the same port on the
    configured interface, which receives the messages from the minions that
    are then forwarded to the Salt master.
    """

    def __init__(self, opts):
        super().__init__(opts)
        self.name = "RetChannelProxy"

        # Socket types and the names used for them in the log messages
        self._backend_sock_type, self._frontend_sock_type = zmq.DEALER, zmq.ROUTER
        self.backend_type, self.frontend_type = "DEALER", "ROUTER"

        # Both ends use the return port: the backend on the master address,
        # the frontend on the local interface.
        uri_template = "tcp://{}:{}"
        self._backend_uri = uri_template.format(
            ip_bracket(self.opts["master_ip"]), self.opts["ret_port"]
        )
        self._frontend_uri = uri_template.format(
            ip_bracket(self.opts["interface"]), self.opts["ret_port"]
        )


class SaltBroker(object):
    """
    Creates a SaltBroker that forwards messages and responses from
    minions to the Salt Master by creating ZeroMQ proxies that manage
    the PUB/RET channels of the Salt ZMQ transport.
    """

    def __init__(self, opts):
        """
        :param opts: configuration dict handed over to the channel proxies
        """
        log.debug("Readed config: %s", opts)
        self.opts = opts
        # Set by the SIGTERM handler to stop the supervisor loop
        self.exit = False
        # Handler that was active before the broker installed its own one
        self.default_sigterm = signal.getsignal(signal.SIGTERM)
        self.pub_proxy_proc = None
        self.ret_proxy_proc = None
        super().__init__()

    def _spawn_proxy(self, proxy_cls, channel):
        """
        Create and start a channel proxy process.

        :param proxy_cls: channel proxy class to instantiate
        :param channel: channel name, only used for logging
        :return: the started process
        :raises AbstractChannelProxy.ChannelException: if the proxy cannot
            be created
        """
        # setting up the default SIGTERM handler for the new process
        signal.signal(signal.SIGTERM, self.default_sigterm)
        try:
            proxy_proc = proxy_cls(opts=self.opts)
            proxy_proc.start()
        finally:
            # setting up again the custom SIGTERM handler, also when the
            # creation of the proxy failed
            signal.signal(signal.SIGTERM, self.sigterm_clean)

        log.info(
            "Spawning %s channel proxy process [PID: %s]", channel, proxy_proc.pid
        )
        return proxy_proc

    def _start_pub_proxy(self):
        """
        Spawn a new PubChannelProxy process
        """
        return self._spawn_proxy(PubChannelProxy, "PUB")

    def _start_ret_proxy(self):
        """
        Spawn a new RetChannelProxy process
        """
        return self._spawn_proxy(RetChannelProxy, "RET")

    def _terminate_children(self):
        """
        Terminate the channel proxy processes that have been started
        """
        for proxy_proc in (self.pub_proxy_proc, self.ret_proxy_proc):
            if proxy_proc is not None:
                proxy_proc.terminate()

    def _respawn_if_dead(self, proxy_proc, starter, channel):
        """
        Return a running process for the channel.

        :param proxy_proc: process currently assigned to the channel
        :param starter: callable spawning a new process for the channel
        :param channel: channel name, only used for logging
        :return: ``proxy_proc`` if it is alive, if the broker is exiting or
            if the respawn failed; the new process otherwise
        """
        # Never respawn once SIGTERM has been handled: the children were
        # terminated on purpose.
        if self.exit or proxy_proc.is_alive():
            return proxy_proc
        log.error("%s channel proxy has died. Respawning", channel)
        try:
            return starter()
        except AbstractChannelProxy.ChannelException as exc:
            # Keep supervising; the respawn is retried on the next cycle
            log.error("Exception: %s", exc)
            return proxy_proc

    # pylint: disable-next=unused-argument
    def sigterm_clean(self, signum, frame):
        """
        Custom SIGTERM handler
        """
        log.info("Caught signal %s, stopping all channels", signum)

        # Flag the exit first so the supervisor does not respawn the
        # processes that are terminated right below.
        self.exit = True
        self._terminate_children()

        log.info("Terminating main process")

    def start(self):
        """
        Starts a SaltBroker. It spawns the PubChannelProxy and
        RetChannelProxy processes and also acts like a supervisor
        of these child processes, respawning them if they died.

        Exits the program if the proxies cannot be created at startup.
        """
        log.info("Starting Salt ZeroMQ Proxy [PID: %s]", os.getpid())

        # Attach a handler for SIGTERM signal
        signal.signal(signal.SIGTERM, self.sigterm_clean)

        try:
            self.pub_proxy_proc = self._start_pub_proxy()
            self.ret_proxy_proc = self._start_ret_proxy()
        except AbstractChannelProxy.ChannelException as exc:
            log.error("Exception: %s", exc)
            log.error("Exiting")
            # Do not leave an already started proxy behind
            self._terminate_children()
            sys.exit(exc)

        # Supervisor. Restart a channel if died
        while not self.exit:
            self.pub_proxy_proc = self._respawn_if_dead(
                self.pub_proxy_proc, self._start_pub_proxy, "PUB"
            )
            self.ret_proxy_proc = self._respawn_if_dead(
                self.ret_proxy_proc, self._start_ret_proxy, "RET"
            )
            time.sleep(SUPERVISOR_TIMEOUT)


if __name__ == "__main__":
    # Try to get config from /etc/uyuni/config.yaml
    rhn_parent = None
    if os.path.isfile(UYUNI_CONFIG_FILE):
        try:
            # pylint: disable-next=unspecified-encoding
            with open(UYUNI_CONFIG_FILE) as uyuni_config_file:
                uyuni_config = yaml.load(uyuni_config_file, Loader=yaml.SafeLoader)
            # An empty file, a document that is not a mapping or a missing
            # "server" key all mean that no parent is configured.
            if isinstance(uyuni_config, dict):
                rhn_parent = uyuni_config.get("server")
        except yaml.YAMLError as exc:
            log.error("Error reading YAML config file: %s", exc)
    if not rhn_parent:
        log.debug(
            "Missing or invalid config file %s, running as standalone",
            UYUNI_CONFIG_FILE,
        )

    # Check for the config file
    if not os.path.isfile(SALT_BROKER_CONF_FILE):
        # pylint: disable-next=consider-using-f-string
        sys.exit("Config file not found: {0}".format(SALT_BROKER_CONF_FILE))

    # default config
    _DEFAULT_OPTS = {
        "publish_port": "4505",
        "ret_port": "4506",
        "interface": "0.0.0.0",
        "ipv6": False,
        "tcp_keepalive": True,
        "tcp_keepalive_idle": 300,
        "tcp_keepalive_cnt": -1,
        "tcp_keepalive_intvl": -1,
        "log_to_file": 1,
        "connect_timeout": 0,
        "reconnect_ivl": 100,
        "heartbeat_ivl": 0,
        "heartbeat_timeout": 0,
        "drop_after_retries": -1,
        "wait_for_backend": False,
    }

    try:
        # pylint: disable-next=unspecified-encoding
        with open(SALT_BROKER_CONF_FILE) as broker_conf_file:
            config = yaml.load(broker_conf_file, Loader=yaml.SafeLoader)
    except yaml.YAMLError as exc:
        # pylint: disable-next=consider-using-f-string
        sys.exit("Error reading YAML config file: {0}".format(exc))

    # An empty config file is valid: only the defaults are used then
    if not config:
        config = {}
    if not isinstance(config, dict):
        # pylint: disable-next=consider-using-f-string
        sys.exit("Bad format in config file: {0}".format(SALT_BROKER_CONF_FILE))

    # Precedence, lowest to highest: defaults, Uyuni config, broker config
    saltbroker_opts = _DEFAULT_OPTS.copy()

    if rhn_parent:
        saltbroker_opts.update({"master": rhn_parent})

    saltbroker_opts.update(config)

    formatter = logging.Formatter(
        "%(asctime)s [%(levelname)-8s][%(processName)-16s][%(process)s] %(message)s",
    )
    # log to file or to standard output and error depending on the configuration
    if saltbroker_opts.get("log_to_file"):
        fileloghandler = logging.handlers.RotatingFileHandler(
            SALT_BROKER_LOGFILE, maxBytes=200000, backupCount=5
        )
        fileloghandler.setFormatter(formatter)
        log.addHandler(fileloghandler)
    else:
        # prepare two log handlers, 1 for stdout and 1 for stderr
        stdout_handler = logging.StreamHandler(sys.stdout)
        stderr_handler = logging.StreamHandler(sys.stderr)
        # stdout handler filters out everything above the ERROR level included
        stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
        # stderror handler looks only for everything above the ERROR level included
        stderr_handler.setLevel(logging.ERROR)
        # same format for both handlers
        stdout_handler.setFormatter(formatter)
        stderr_handler.setFormatter(formatter)
        # add handlers to log Object
        log.addHandler(stdout_handler)
        log.addHandler(stderr_handler)

    proxy = SaltBroker(opts=saltbroker_opts)
    proxy.start()
