#!/usr/bin/env python3

# Upstream: https://github.com/stbuehler/monitoring-plugins-stbuehler
# Copyright: Stefan Bühler
# License: MIT

from abc import abstractmethod
import argparse
import dataclasses
import dns
import dns.asyncquery
import dns.dnssec
import dns.exception
import dns.message
import dns.name
import dns.rcode
import dns.rdata
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.rdtypes.ANY.CNAME
import dns.rdtypes.ANY.DNSKEY
import dns.rdtypes.ANY.DS
import dns.rdtypes.ANY.NS
import dns.rdtypes.ANY.SOA
import dns.resolver
import dns.rrset
import ipaddress
import sys
import trio
import typing


# Either kind of IP address (what `ipaddress.ip_address` returns).
IPAddress = ipaddress.IPv4Address | ipaddress.IPv6Address
# Domain names are accepted as text or as parsed `dns.name.Name`;
# text is converted with `dns.name.from_text` where a `Name` is needed.
DomainName = str | dns.name.Name


# Typing-only description of a dns.rrset.RRset with rdatas of a known type.
T = typing.TypeVar('T', bound=dns.rdata.Rdata)
class RRset(typing.Iterable[T], typing.Protocol[T]):
    """
    Typed view of a `dns.rrset.RRset` whose rdatas are all of type `T`.

    Only used with `typing.cast` after the rdata types were checked; at runtime the
    objects are plain `dns.rrset.RRset` instances.  Besides iteration, callers also
    use `len()` and indexing (e.g. `soa[0]`), so those are declared here as well.
    """

    # attributes of dns.rdataset.Rdataset:
    rdclass: int  # dns.rdataclass.?
    rdtype: int  # dns.rdatatype.?
    ttl: int
    covers: int  # dns.rdatatype.?
    # attributes of dns.rrset.RRset:
    name: DomainName
    # deleting: dns.rdataclass.? | None

    @abstractmethod
    def add(self, rd: T, ttl: int | None = None) -> None:
        ...

    @abstractmethod
    def __len__(self) -> int:
        ...

    @abstractmethod
    def __getitem__(self, index: int) -> T:
        ...


class IcingaResult:
    """
    Collect messages of a check run and report them as plugin output.

    Every reported message is appended to the log.  The most important message seen
    so far becomes the title (first output line) and determines the exit code.
    Importance order: info (0) < unknown (0.5) < warn (1) < error (2), so an UNKNOWN
    never hides a WARNING or ERROR.  Exit codes: 0 info, 1 warn, 2 error, 3 unknown.
    """

    def __init__(self, debug: bool = False) -> None:
        self._debug = debug
        # importance of the current title; -1 means nothing was reported yet
        self._cmp_code = -1.0
        # exit code of the current title; -1 is turned into 0 by `finish`
        self._code = -1
        self._title = 'OK'
        self._log: list[str] = []

    def _set(self, code: int, cmp_code: float, title: str) -> None:
        self._cmp_code = cmp_code
        self._code = code
        self._title = title

    def _add(self, code: int, cmp_code: float, msg: str, title: str) -> None:
        """
        Log `msg`.  If `cmp_code` is more important than the current title, `title`
        becomes the new title with exit code `code`.
        """
        assert 0 <= code <= 3
        assert 0 <= cmp_code <= 2
        if cmp_code > self._cmp_code:
            # use the dedicated title (e.g. 'OK: ...' for info), not the bare log message
            self._set(code, cmp_code, title)
        self._log.append(msg)

    def log(self, msg: str) -> None:
        """Append `msg` to the output without affecting title or exit code."""
        self._log.append(msg)

    def debug(self, msg: str) -> None:
        """Like `log`, but only when debug output is enabled."""
        if self._debug:
            self._log.append(msg)

    def info(self, msg: str) -> None:
        self._add(0, 0.0, msg, f'OK: {msg}')

    def warn(self, msg: str) -> None:
        msg = f'WARNING: {msg}'
        self._add(1, 1.0, msg, msg)

    def error(self, msg: str) -> None:
        msg = f'ERROR: {msg}'
        self._add(2, 2.0, msg, msg)

    def unknown(self, msg: str) -> None:
        msg = f'UNKNOWN: {msg}'
        self._add(3, 0.5, msg, msg)

    def finish(self) -> typing.NoReturn:
        """Print title and log lines, then exit with the collected exit code (0 if nothing was reported)."""
        print(f'{self._title}||')
        for line in self._log:
            print(line)
        if self._code == -1:
            self._code = 0
        sys.exit(self._code)


class DnsMessageError(dns.exception.DNSException):
    """A server answered `query`, but with an rcode other than NOERROR/NXDOMAIN."""

    def __init__(self, query: dns.message.Message, response: dns.message.Message, server: IPAddress) -> None:
        self.query, self.response, self.server = query, response, server

    def __str__(self) -> str:
        question = self.query.question[0]
        rcode_text = dns.rcode.to_text(self.response.rcode())
        return "Query @{} {} failed with {}".format(self.server, question, rcode_text)


class DnsTimeoutError(dns.exception.DNSException):
    """`query` sent to `server` did not complete before the deadline."""

    def __init__(self, query: dns.message.Message, server: IPAddress) -> None:
        self.query, self.server = query, server

    def __str__(self) -> str:
        question = self.query.question[0]
        return "Query @{} {} timed out".format(self.server, question)


class DnsConnectionClosedError(dns.exception.DNSException):
    """
    The server closed the connection before answering `query`.

    Wraps the various low-level "connection closed" errors; the original one is kept in `error`.
    """

    def __init__(self, query: dns.message.Message, server: IPAddress, error: Exception) -> None:
        self.query, self.server, self.error = query, server, error

    def __str__(self) -> str:
        question = self.query.question[0]
        return "Query @{} {} failed: server closed connection too early {}".format(self.server, question, self.error)


class DnsOsError(dns.exception.DNSException):
    """Sending `query` to `server` failed with an `OSError` (kept in `error`)."""

    def __init__(self, query: dns.message.Message, server: IPAddress, error: OSError) -> None:
        self.query, self.server, self.error = query, server, error

    def __str__(self) -> str:
        question = self.query.question[0]
        return "Query @{} {} failed: {}".format(self.server, question, self.error)


@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class Delegation:
    """
    Hashable snapshot of a delegation: NS names with the glue addresses received for them.

    Stored as a tuple sorted by NS name, so equal delegations compare (and hash) equal
    regardless of the order in which the records arrived.
    """
    nameservers_tuple: tuple[tuple[dns.name.Name, typing.FrozenSet[IPAddress]], ...]

    @classmethod
    def _build(cls, nameservers: dict[dns.name.Name, set[IPAddress]]) -> 'Delegation':
        """Create a Delegation from a mapping NS name -> glue addresses."""
        return cls(
            nameservers_tuple=tuple((ns, frozenset(nameservers[ns])) for ns in sorted(nameservers.keys())),
        )

    @property
    def nameservers(self) -> dict[dns.name.Name, typing.FrozenSet[IPAddress]]:
        """The delegation as a new dict NS name -> glue addresses."""
        return dict(self.nameservers_tuple)

    def to_lines(self) -> typing.Generator[str, None, None]:
        """
        Yield one human-readable line per NS, in NS name order.

        Glue addresses are listed IPv4 first, then IPv6, each ascending, so the output
        is stable between runs (frozenset iteration order is arbitrary).  IPv4 and IPv6
        addresses can't be compared directly, so the version is part of the sort key.
        """
        for ns, addrs in self.nameservers_tuple:
            if addrs:
                ordered = sorted(addrs, key=lambda addr: (addr.version, addr))
                yield f'* NS {ns} with glue ' + ' '.join(map(str, ordered))
            else:
                yield f'* NS {ns} - no glue'


@dataclasses.dataclass(kw_only=True, slots=True)
class SoaSets:
    """
    Accumulates the SOA records returned by several servers.

    Besides the full records (with the servers that returned them), serials and MNAMEs
    are collected separately.  A "minimized" copy of every SOA (serial set to 1, MNAME
    set to the root) is kept to count SOAs that differ in other fields.
    """
    # full SOA record -> addresses of the servers that returned it
    soa_sets: dict[dns.rdtypes.ANY.SOA.SOA, set[IPAddress]] = dataclasses.field(default_factory=dict)
    # every serial seen
    serials: set[int] = dataclasses.field(default_factory=set)
    # every MNAME seen
    mnames: set[dns.name.Name] = dataclasses.field(default_factory=set)
    # SOAs with serial and MNAME normalized
    _min_soas: set[dns.rdtypes.ANY.SOA.SOA] = dataclasses.field(default_factory=set)

    def insert(self, soa: dns.rdtypes.ANY.SOA.SOA, address: IPAddress) -> None:
        """Record `soa` as returned by the server at `address`."""
        servers = self.soa_sets.get(soa)
        if servers is None:
            servers = self.soa_sets[soa] = set()
        servers.add(address)
        self.serials.add(soa.serial)
        self.mnames.add(soa.mname)
        minimized = soa.replace(serial=1, mname=dns.name.root)
        self._min_soas.add(typing.cast(dns.rdtypes.ANY.SOA.SOA, minimized))

    def main_soa(self) -> dns.rdtypes.ANY.SOA.SOA | None:
        """
        Return a representative SOA if all SOAs seen agree except for serial and MNAME, else None.

        The result carries the highest serial and an arbitrary one of the MNAMEs seen.
        That exact combination was not necessarily returned by any server.

        NOTE(review): `max` is plain integer comparison, not RFC 1982 serial arithmetic,
        so a serial that wrapped around is not recognized as the newest.
        """
        if self.distinct_soas == 1:
            (template,) = self._min_soas
            merged = template.replace(
                serial=max(self.serials),
                mname=next(iter(self.mnames)),
            )
            return typing.cast(dns.rdtypes.ANY.SOA.SOA, merged)
        return None

    @property
    def distinct_soas(self) -> int:
        """Number of SOAs seen that differ in something other than serial/MNAME."""
        return len(self._min_soas)


ActionResult = typing.TypeVar('ActionResult')
async def retry_action(
    action: typing.Callable[[], typing.Awaitable[ActionResult]],
    *,
    tries: int = 3,
    timeout: float = 2.0,
    continue_exceptions: tuple[type[Exception], ...] = (),
) -> ActionResult:
    """
    Run `action` up to `tries` times, starting a new attempt every `timeout` seconds.

    Earlier attempts keep running when a new one starts.  The first successful
    result cancels all other attempts and is returned.

    An exception that is an instance of `continue_exceptions` is remembered (the first
    one is kept), but the other attempts go on.  Any other exception cancels all
    attempts and is raised.  If no attempt succeeds, the remembered exception is raised.

    There is no overall deadline here; callers wrap this in e.g. `trio.fail_after`.

    :raises ValueError: if `tries` is less than 1.
    """
    exc: Exception | None = None
    # empty tuple: no result yet (a 1-tuple allows `None` as a valid result)
    result: tuple[()] | tuple[ActionResult] = ()

    if tries < 1:
        raise ValueError("Can't try less than one time")

    async def single(cancel_scope: trio.CancelScope) -> None:
        nonlocal exc, result
        try:
            result = (await action(),)
            cancel_scope.cancel()
        except trio.Cancelled:
            # cancelled because the outcome is already decided -> done;
            # otherwise the cancellation came from outside and must propagate
            if not (result or exc):
                raise
        except Exception as e:
            if not isinstance(e, continue_exceptions):
                # there might be previous exceptions, but they were considered less important
                # (we continued running after all)
                exc = e
                cancel_scope.cancel()
            elif exc is None:
                exc = e

    async with trio.open_nursery() as nursery:
        try:
            for attempt in range(tries):
                if attempt:
                    # give the running attempt(s) `timeout` seconds before starting another one
                    await trio.sleep(timeout)
                nursery.start_soon(single, nursery.cancel_scope)
            # No sleep after the last attempt: leaving the nursery waits for the running
            # attempts anyway, so a trailing sleep would only delay the "all attempts
            # failed" case.
        except trio.Cancelled:
            if not (result or exc):
                raise

    if result:
        return result[0]
    elif exc:
        raise exc
    else:
        raise RuntimeError("internal error: Missing both exception and result")


class DnsLookup:
    """
    Asynchronous DNS client that sends all queries to a single server.

    Every query gets up to three attempts, started 2 seconds apart (`retry_action`
    defaults), within an overall deadline of 9 seconds.  Transport failures are
    translated into `DnsConnectionClosedError`, `DnsOsError` and `DnsTimeoutError`.
    Responses with an rcode other than NOERROR/NXDOMAIN raise `DnsMessageError`.
    """

    @staticmethod
    def _get_first_system_resolver() -> IPAddress:
        """Return the first nameserver of dnspython's default (system) resolver configuration."""
        resolver = dns.resolver.get_default_resolver()
        return ipaddress.ip_address(resolver.nameservers[0])

    @staticmethod
    def _extract_answer(query: dns.message.Message, response: dns.message.Message) -> dns.rrset.RRset | None:
        """
        Find the RRset answering the (first) question of `query` in `response`.

        The ANSWER section is searched first.  NS queries also accept the AUTHORITY
        section, where referrals from a parent zone carry the NS set.  If the queried
        name is an alias, the CNAME chain in the ANSWER section is followed; loops end
        the search.

        Returns None if no matching RRset is found.
        """
        q_rr: dns.rrset.RRset = query.question[0]
        rr = response.get_rrset(dns.message.ANSWER, q_rr.name, q_rr.rdclass, q_rr.rdtype)
        if rr:
            return rr
        if q_rr.rdtype == dns.rdatatype.NS:
            rr = response.get_rrset(dns.message.AUTHORITY, q_rr.name, q_rr.rdclass, q_rr.rdtype)
            if rr:
                return rr
        # CNAMEs carry the class of the question.  `get_rrset` matches the class exactly,
        # so looking them up with class ANY never finds anything.
        name: dns.name.Name = q_rr.name
        seen: set[dns.name.Name] = {name}
        while True:
            cname_rr: dns.rrset.RRset | None = response.get_rrset(dns.message.ANSWER, name, q_rr.rdclass, dns.rdatatype.CNAME)
            if not cname_rr:
                return None
            assert len(cname_rr) == 1
            cname_entry: dns.rdtypes.ANY.CNAME.CNAME = cname_rr[0]
            name = cname_entry.target
            if name in seen:
                # CNAME loop
                return None
            seen.add(name)
            rr = response.get_rrset(dns.message.ANSWER, name, q_rr.rdclass, q_rr.rdtype)
            if rr:
                return rr

    def __init__(self, *, server: IPAddress | None = None, tcp_only: bool = False) -> None:
        """
        :param server: server to query; defaults to the first system resolver.
        :param tcp_only: send all queries via TCP.  Otherwise UDP is used, with TCP
            fallback via `dns.asyncquery.udp_with_fallback`.
        """
        if server is None:
            server = DnsLookup._get_first_system_resolver()
        self._tcp_only = tcp_only
        self._server = server
        # the dns.asyncquery functions take the server address as text
        self._server_str = str(server)

    async def _single_query(self, query: dns.message.Message) -> tuple[dns.message.Message, bool]:
        """Send `query` once; return the response and whether TCP was used."""
        response: dns.message.Message
        if self._tcp_only:
            used_tcp = True
            response = await dns.asyncquery.tcp(query, self._server_str)
        else:
            (response, used_tcp) = await dns.asyncquery.udp_with_fallback(query, self._server_str)
        return (response, used_tcp)

    async def query(self, query: dns.message.Message) -> tuple[dns.rrset.RRset | None, dns.message.Message, bool]:
        """
        Send `query` (with retries) and return `(answer, response, used_tcp)`.

        `answer` is the RRset answering the question (see `_extract_answer`), or None.

        :raises DnsConnectionClosedError: the server closed the connection too early.
        :raises DnsOsError: other OS-level errors (e.g. no route to host).
        :raises DnsTimeoutError: no response within 9 seconds.
        :raises DnsMessageError: the response rcode is neither NOERROR nor NXDOMAIN.
        Other exceptions from the dnspython query functions propagate unchanged.
        """
        try:
            with trio.fail_after(9):
                (response, used_tcp) = await retry_action(
                    lambda: self._single_query(query),
                    continue_exceptions=(EOFError, trio.BrokenResourceError, OSError),
                )
        except (EOFError, trio.BrokenResourceError) as error:
            # with TCP query: server closed connection without answering / too early
            raise DnsConnectionClosedError(query=query, server=self._server, error=error) from error
        except OSError as error:
            # sadly need a quite generic catch here to handle "no route to host" (and probably similar others)
            raise DnsOsError(query=query, server=self._server, error=error) from error
        except trio.TooSlowError as error:
            raise DnsTimeoutError(query=query, server=self._server) from error
        rcode = response.rcode()
        if rcode not in (dns.rcode.NOERROR, dns.rcode.NXDOMAIN):
            raise DnsMessageError(query=query, response=response, server=self._server)
        answer = DnsLookup._extract_answer(query, response)
        return (answer, response, used_tcp)

    async def soa(self, name: DomainName) -> RRset[dns.rdtypes.ANY.SOA.SOA] | None:
        """Return the SOA RRset of `name` (asserted to hold exactly one SOA), or None if not answered."""
        soa_set = (await self.query(dns.message.make_query(name, dns.rdatatype.SOA)))[0]
        if soa_set is None:
            return None
        assert len(soa_set) == 1
        for soa in soa_set:
            assert isinstance(soa, dns.rdtypes.ANY.SOA.SOA)
        return typing.cast(RRset[dns.rdtypes.ANY.SOA.SOA], soa_set)

    async def ns(self, name: DomainName) -> RRset[dns.rdtypes.ANY.NS.NS] | None:
        """Return the NS RRset of `name` (from ANSWER or a referral), or None."""
        ns_set = (await self.query(dns.message.make_query(name, dns.rdatatype.NS)))[0]
        if ns_set is None:
            return None
        for ns in ns_set:
            assert isinstance(ns, dns.rdtypes.ANY.NS.NS)
        return typing.cast(RRset[dns.rdtypes.ANY.NS.NS], ns_set)

    async def delegation(self, name: DomainName) -> Delegation | None:
        """
        Return the NS set of `name` together with glue addresses, or None if there is no NS set.

        Glue (A/AAAA from the ADDITIONAL section) is only accepted for NS names inside
        `name`; other NS names get an empty address set.
        """
        if isinstance(name, str):
            name = dns.name.from_text(name)
        ns_set, response, _ = await self.query(dns.message.make_query(name, dns.rdatatype.NS))
        if ns_set is None:
            return None
        nameservers: dict[dns.name.Name, set[IPAddress]] = {}
        for ns in ns_set:
            assert isinstance(ns, dns.rdtypes.ANY.NS.NS)
            addrs: set[IPAddress] = set()
            # only accept glue records (addresses for name) if name lives in delegated zone:
            if ns.target.is_subdomain(name):
                a_set = response.get_rrset(dns.message.ADDITIONAL, ns.target, dns.rdataclass.IN, dns.rdatatype.A)
                if a_set:
                    for a in a_set:
                        addrs.add(ipaddress.IPv4Address(a.address))
                aaaa_set = response.get_rrset(dns.message.ADDITIONAL, ns.target, dns.rdataclass.IN, dns.rdatatype.AAAA)
                if aaaa_set:
                    for aaaa in aaaa_set:
                        addrs.add(ipaddress.IPv6Address(aaaa.address))
            nameservers[ns.target] = addrs
        return Delegation._build(nameservers)

    async def ds(self, name: DomainName) -> RRset[dns.rdtypes.ANY.DS.DS] | None:
        """Return the DS RRset of `name`, or None."""
        ds_set = (await self.query(dns.message.make_query(name, dns.rdatatype.DS)))[0]
        if ds_set is None:
            return None
        for ds in ds_set:
            assert isinstance(ds, dns.rdtypes.ANY.DS.DS)
        return typing.cast(RRset[dns.rdtypes.ANY.DS.DS], ds_set)

    async def find_parent_zone(self, name: DomainName) -> RRset[dns.rdtypes.ANY.NS.NS]:
        """
        Walk up from the parent of `name` and return the first NS RRset found.

        This is the NS set of the closest enclosing zone.  `dns.name.Name.parent()`
        raises `dns.name.NoParent` if the root is passed without finding NS records.
        """
        if isinstance(name, str):
            name = dns.name.from_text(name)
        assert name != dns.name.root, f"Can't get parent of root: {name!r}"
        parent = name
        while True:
            parent = parent.parent()
            ns = await self.ns(parent)
            if ns:
                return ns

    async def ipv4_address(self, name: DomainName) -> list[ipaddress.IPv4Address]:
        """Return the IPv4 addresses of `name` (empty list if none)."""
        a_set = (await self.query(dns.message.make_query(name, dns.rdatatype.A)))[0]
        if a_set is None:
            return []
        return [ipaddress.IPv4Address(a.address) for a in a_set]

    async def ipv6_address(self, name: DomainName) -> list[ipaddress.IPv6Address]:
        """Return the IPv6 addresses of `name` (empty list if none)."""
        aaaa_set = (await self.query(dns.message.make_query(name, dns.rdatatype.AAAA)))[0]
        if aaaa_set is None:
            return []
        return [ipaddress.IPv6Address(aaaa.address) for aaaa in aaaa_set]

    async def addresses(self, name: DomainName) -> list[IPAddress]:
        """
        Look up IPv4 and IPv6 addresses of `name` concurrently; return IPv4 followed by IPv6.

        If a lookup fails, the first exception is re-raised as-is (not wrapped in a
        multi-exception group).
        """
        a: list[IPAddress] | None = None
        aaaa: list[IPAddress] | None = None
        # avoid trio.MultiError: forward first exception we get
        exc: Exception | None = None

        async def get_v4() -> None:
            nonlocal a, exc
            try:
                a = typing.cast(list[IPAddress], await self.ipv4_address(name))
            except Exception as e:
                if exc is None:
                    exc = e

        async def get_v6() -> None:
            nonlocal aaaa, exc
            try:
                aaaa = typing.cast(list[IPAddress], await self.ipv6_address(name))
            except Exception as e:
                if exc is None:
                    exc = e

        async with trio.open_nursery() as nursery:
            nursery.start_soon(get_v4)
            nursery.start_soon(get_v6)
        if exc is not None:
            raise exc
        assert a is not None, f"get_v4 failed: {a!r}"
        assert aaaa is not None, f"get_v6 failed: {aaaa!r}"
        return a + aaaa


class DnsLookupServersCache:
    """Hand out one shared `DnsLookup` per (server address, TCP-only flag) combination."""

    def __init__(self, *, tcp_only: bool = False) -> None:
        # transport mode used when `get` is not told otherwise
        self._tcp_only = tcp_only
        self._servers: dict[tuple[IPAddress, bool], DnsLookup] = {}

    def get(self, address: IPAddress, *, tcp_only: bool | tuple[()] = ()) -> DnsLookup:
        """
        Return the cached lookup for `address`, creating it on first use.

        `tcp_only` overrides the cache-wide default; the empty tuple (default) means
        "use the cache-wide setting".
        """
        use_tcp = self._tcp_only if isinstance(tcp_only, tuple) else tcp_only
        key = (address, use_tcp)
        lookup = self._servers.get(key)
        if lookup is None:
            lookup = self._servers[key] = DnsLookup(server=address, tcp_only=use_tcp)
        return lookup


class CheckNameContext:
    def __init__(self, ir: IcingaResult, opts: argparse.Namespace) -> None:
        """
        Prepare the check of zone `opts.zone`.

        Uses `opts.zone`, `opts.resolver`, `opts.auth_tcp` and `opts.host` right away.
        All result attributes start empty and are filled by the individual check steps.
        """
        self.ir = ir
        self.opts = opts
        self.name = dns.name.from_text(opts.zone)
        self.resolver = DnsLookup(server=opts.resolver)
        self.lookup_servers = DnsLookupServersCache(tcp_only=opts.auth_tcp)
        self.host = None if not opts.host else dns.name.from_text(opts.host)

        # ---- results collected while the check runs ----

        # NS names of the parent zone that own at least one unique address
        # (filled by `retrieve_parent_addresses`)
        self.parent_ns_set: set[dns.name.Name] = set()

        # address -> NS name for the parent zone's authoritative servers; an address
        # shared by several names is attributed to just one of them
        # (filled by `retrieve_parent_addresses`)
        self.parent_addresses: dict[IPAddress, dns.name.Name] = {}

        # delegation / DS set -> addresses of the parent servers that returned it
        # (filled by `retreive_delegations`)
        self.parent_delegations: dict[Delegation, list[IPAddress]] = {}
        self.parent_ds_sets: dict[typing.FrozenSet[dns.rdtypes.ANY.DS.DS], list[IPAddress]] = {}

        # the one delegation, if all parent servers agreed
        # (set by `check_delegations`)
        self.main_delegation: Delegation | None = None

        # same idea as `parent_delegations`, but as answered by the zone's own
        # authoritative servers (filled by `retreive_auth_responses`)
        self.auth_delegations: dict[Delegation, set[IPAddress]] = {}
        self.auth_soas = SoaSets()

        # the single SOA of `auth_soas` (ignoring serial and mname), if there is one
        # (set by the main `check`)
        self.main_soa: dns.rdtypes.ANY.SOA.SOA | None = None

        # the zone's authoritative servers by name, and by address (like
        # `parent_addresses`: one name selected per address) (set by the main `check`)
        self.auths_by_name: dict[dns.name.Name, set[IPAddress]] = {}
        self.auths: dict[IPAddress, dns.name.Name] = {}

    @staticmethod
    def sources_names(sources: typing.Iterable[IPAddress], names: dict[IPAddress, typing.Any]) -> str:
        """Render addresses as 'addr (name), addr (name), ...', taking labels from `names`."""
        labelled = [f'{source} ({names[source]})' for source in sources]
        return ', '.join(labelled)

    async def retrieve_parent_addresses(self) -> None:
        """
        Find the NS set of the closest enclosing zone of `self.name` via the resolver,
        and resolve the addresses of all those nameservers concurrently.

        Fills `self.parent_addresses` (address -> NS name): when names share an address,
        the name whose lookup completes first keeps it.  Fills `self.parent_ns_set` with
        the names that own at least one address.  Failures are reported as errors.
        """
        try:
            parent_ns_rrset = await self.resolver.find_parent_zone(self.name)
        except dns.exception.DNSException as e:
            self.ir.error(f'Failed to find parent zone for {self.name}')
            self.ir.error(str(e))
            return
        parent_zone = parent_ns_rrset.name
        self.ir.debug(f'NS for parent zone {parent_zone} of {self.name}: ' + ' '.join(str(ns.target) for ns in parent_ns_rrset))

        async def collect(ns_name: dns.name.Name) -> None:
            try:
                found = await self.resolver.addresses(ns_name)
                self.ir.debug(f'Addresses of NS {ns_name}: ' + ' '.join(str(addr) for addr in found))
            except dns.exception.DNSException as e:
                self.ir.error(f'Failed to lookup addresses for NS {ns_name} of parent zone {parent_zone} of {self.name}')
                self.ir.error(str(e))
                return
            claimed = False
            for addr in found:
                owner = self.parent_addresses.get(addr)
                if owner is None:
                    self.parent_addresses[addr] = ns_name
                    claimed = True
                else:
                    self.ir.log(f'IP {addr} duplicate in NS set for {parent_zone} (known as {owner} and {ns_name})')
            if not claimed:
                # a name with only shared addresses would just duplicate another NS
                self.ir.log(f'NS {ns_name} for parent zone has no unique IPs, removing it from further consideration')
                return
            self.parent_ns_set.add(ns_name)

        async with trio.open_nursery() as nursery:
            for record in parent_ns_rrset:
                nursery.start_soon(collect, record.target)

    async def retreive_delegations(self) -> None:
        """
        Ask every parent authoritative address (`self.parent_addresses`) for the
        delegation (NS + glue) and the DS set of `self.name`, concurrently.

        Fills `self.parent_delegations` and `self.parent_ds_sets` (value -> addresses that
        returned it).  Failures are reported on `self.ir`, and a warning is issued if too
        many parent servers were unreachable.
        """
        # parent NS names for which at least one address returned a delegation and
        # answered the DS query
        reachable_auths: set[dns.name.Name] = set()
        # addresses that failed with a timeout / closed connection / OS error
        auth_ips_down: set[IPAddress] = set()

        async def get_delegation(ns_addr: IPAddress, ns_name: dns.name.Name) -> None:
            r = self.lookup_servers.get(ns_addr)
            try:
                delegation = await r.delegation(self.name)
            except (DnsTimeoutError, DnsConnectionClosedError, DnsOsError):
                # timeouts, early closed connections and OS errors (e.g. "no route to host")
                # are taken as signs that this address is unreachable
                auth_ips_down.add(ns_addr)
                self.ir.log(f'Failed to get delegation (NS) to {self.name} from {ns_addr} ({ns_name})')
                return
            except dns.exception.DNSException as e:
                self.ir.error(f'Failed to get delegation (NS) to {self.name} from {ns_addr} ({ns_name})')
                self.ir.error(str(e))
                return
            if delegation is None:
                self.ir.error(f'Missing delegation (NS) to {self.name} from {ns_addr} ({ns_name})')
                return
            # stored before the DS query, so it counts even if the DS query fails
            self.parent_delegations.setdefault(delegation, []).append(ns_addr)
            try:
                ds = await r.ds(self.name)
            except (DnsTimeoutError, DnsConnectionClosedError, DnsOsError):
                # same unreachability heuristic as for the NS query above
                auth_ips_down.add(ns_addr)
                self.ir.log(f'Failed to get (dnssec) (DS) to {self.name} from {ns_addr} ({ns_name})')
                return
            except dns.exception.DNSException as e:
                self.ir.error(f'Failed to get (dnssec) DS to {self.name} from {ns_addr} ({ns_name})')
                self.ir.error(str(e))
                return
            # NOTE(review): an answer without a DS set is not recorded at all, so parent
            # servers disagreeing on *whether* a DS set exists do not show up as different
            # entries in `parent_ds_sets`.
            if not ds is None:
                ds_set = frozenset(ds)
                self.parent_ds_sets.setdefault(ds_set, []).append(ns_addr)
            reachable_auths.add(ns_name)

        async with trio.open_nursery() as nursery:
            for ns_addr, ns_name in self.parent_addresses.items():
                nursery.start_soon(get_delegation, ns_addr, ns_name)

        # names for which no address gave a full answer; this also includes servers that
        # answered with an error or without a delegation (already reported as errors)
        parent_auths_down = self.parent_ns_set - reachable_auths

        if len(reachable_auths) == 0:
            self.ir.error(f"No parent authoritatives reachable")
            return
        # a name only gets into `reachable_auths` after its delegation was stored
        assert self.parent_delegations, "Must have at least one delegation here"

        # TODO: what are proper warning levels?
        # we have:
        # * number of unreachable IPs (vs total IPs)
        # * number of unreachable NS, i.e. unreachable on all their IPs (vs number of NS)
        #   (although if NS have overlapping IPs this is a "simplified" measurement)

        # IP address heuristic: half the IPs should be reachable
        if 2 * len(auth_ips_down) > len(self.parent_addresses):
            self.ir.warn(f"Too many parent authoritative IPs unreachable: {self.sources_names(auth_ips_down, self.parent_addresses)}")
        # NS heuristic: half of the NS should be reachable
        elif 2 * len(parent_auths_down) > len(self.parent_ns_set):
            self.ir.warn(f"Too many parent authoritatives unreachable: {' '.join(map(str, parent_auths_down))}")
        elif parent_auths_down:
            self.ir.log(f"Some parent authoritatives are unreachable: {' '.join(map(str, parent_auths_down))}")

    async def check_delegations(self) -> None:
        """
        Evaluate the delegations and DS sets collected from the parent servers.

        * error if no parent server returned a delegation;
        * error for each delegation lacking glue for an NS name inside `self.name`;
        * warning if the parent servers disagree on the delegation or the DS set.

        Sets `self.main_delegation` if all parent servers returned the same delegation.
        Differing delegations and DS sets (all of them in debug mode) are logged together
        with the servers that returned them.  Names and DS records are printed in a
        fixed order, so the output is stable between runs.
        """
        if not self.parent_delegations:
            self.ir.error(f'No Delegations to {self.name} anywhere')
            return

        # delegations already printed along with a "missing glue" error
        logged_delegation: set[Delegation] = set()

        for delegation, sources in self.parent_delegations.items():
            # `nameservers_tuple` is sorted by name (and names are unique), so a list keeps
            # the message order deterministic
            missing_glue = [
                ns
                for ns, addrs in delegation.nameservers_tuple
                if not addrs and ns.is_subdomain(self.name)
            ]
            if missing_glue:
                logged_delegation.add(delegation)
                self.ir.log(f'Delegation from {self.sources_names(sources, self.parent_addresses)}:')
                for line in delegation.to_lines():
                    self.ir.log(line)
                self.ir.error(f"Delegation missing glue for {' '.join(map(str, missing_glue))}")

        if len(self.parent_delegations) != 1:
            self.ir.warn(f'Different delegations to {self.name}')
        else:
            self.main_delegation = next(iter(self.parent_delegations))

        if len(self.parent_delegations) != 1 or self.ir._debug:
            for delegation, sources in self.parent_delegations.items():
                if delegation in logged_delegation:
                    # don't log twice
                    continue
                self.ir.log(f'Delegation from {self.sources_names(sources, self.parent_addresses)}:')
                for line in delegation.to_lines():
                    self.ir.log(line)

        if not self.parent_ds_sets:
            self.ir.log('No DS records - insecure delegation')
        elif len(self.parent_ds_sets) > 1:
            self.ir.warn(f'Different DS record sets to {self.name}')
        if len(self.parent_ds_sets) > 1 or self.ir._debug:
            for ds_set, sources in self.parent_ds_sets.items():
                self.ir.log(f'DS set from {self.sources_names(sources, self.parent_addresses)}:')
                # frozenset order is arbitrary; all entries are DS rdatas, which dnspython orders canonically
                for ds in sorted(ds_set):
                    self.ir.log(f'* DS {ds}')

    async def build_merged_delegation(self) -> dict[dns.name.Name, set[IPAddress]]:
        """
        Combine all delegations returned by the parent servers into one mapping
        NS name -> addresses.

        The glue addresses of all delegations are merged per name (delegations only carry
        glue for names inside the zone).  Names without any glue that lie outside the zone
        are resolved through the resolver.  In-zone names without glue keep an empty set.
        """
        merged: dict[dns.name.Name, set[IPAddress]] = {}
        for delegation in self.parent_delegations:
            for ns, frozen_addrs in delegation.nameservers_tuple:
                if ns not in merged:
                    merged[ns] = set()
                merged[ns] |= frozen_addrs

        async def resolve_unglued(ns: dns.name.Name) -> None:
            if ns.is_subdomain(self.name):
                # in-zone name without glue: reported by `check_delegations`
                return
            try:
                addrs = await self.resolver.addresses(ns)
            except dns.exception.DNSException as e:
                self.ir.error(f'Failed to lookup (glue) addresses for NS {ns} for {self.name}')
                self.ir.error(str(e))
                return
            if addrs:
                merged[ns] = set(addrs)
            else:
                self.ir.error(f'NS {ns} has no addresses')

        unglued = [ns for ns, addrs in merged.items() if not addrs]
        async with trio.open_nursery() as nursery:
            for ns in unglued:
                nursery.start_soon(resolve_unglued, ns)

        if self.ir._debug:
            self.ir.debug('Using merged delegation:')
            for ns, addrs in merged.items():
                self.ir.log(f'* NS {ns}: ' + ' '.join(map(str, addrs)))

        return merged

    def transpose_delegation_auths(
        self,
        auths: dict[dns.name.Name, set[IPAddress]],
    ) -> dict[IPAddress, dns.name.Name]:
        """
        Invert a mapping NS name -> addresses into address -> NS name.

        If several names share an address, the first name (in `auths` order) keeps it
        and the duplicate is logged.
        """
        by_addr: dict[IPAddress, dns.name.Name] = {}
        for ns_name, addrs in auths.items():
            for addr in addrs:
                known = by_addr.get(addr)
                if known is not None:
                    self.ir.log(f'IP {addr} duplicate in NS set for {self.name} (known as {known} and {ns_name})')
                    continue
                by_addr[addr] = ns_name
        return by_addr

    def is_dnskey_in_dssets(self, key: dns.rdtypes.ANY.DNSKEY.DNSKEY) -> bool:
        """
        Return True if `key` matches a DS record of any DS set received from the parent.

        A DS matches if key tag, algorithm and digest (computed with the DS' digest type)
        agree.  DS records whose digest type dnspython can't or won't compute
        (`make_ds` raising a DNSException, e.g. unsupported or denied by policy) are
        skipped instead of aborting the whole check.
        """
        import inspect

        # dnspython >= 2.6 checks digest types against a policy.  In "create" mode it
        # denies e.g. SHA-1; we only *verify* existing DS records, so ask in validating
        # mode when that option exists.
        make_ds_kwargs: dict[str, bool] = {}
        if 'validating' in inspect.signature(dns.dnssec.make_ds).parameters:
            make_ds_kwargs['validating'] = True

        key_id = dns.dnssec.key_id(key)
        for ds_set in self.parent_ds_sets:
            for ds in ds_set:
                if ds.key_tag != key_id or ds.algorithm != key.algorithm:
                    continue
                try:
                    cmp_ds = dns.dnssec.make_ds(self.name, key, ds.digest_type, **make_ds_kwargs)
                except dns.exception.DNSException:
                    # can't verify this DS record; another DS may still match
                    continue
                if cmp_ds.digest == ds.digest:
                    return True
        return False

    def filter_dnskeys(self, dnskey_set: RRset[dns.rdtypes.ANY.DNSKEY.DNSKEY]) -> dict[dns.name.Name, dns.rdataset.Rdataset]:
        """
        Select the keys allowed to validate the DNSKEY RRset: zone keys (ZONE flag,
        protocol 3) that match a DS record from the parent.

        Returns the key mapping `{zone name: rdataset}` for `dns.dnssec.validate`.  If
        no key matches, it returns an empty dict, so validation fails with
        `ValidationFailure`.  (`dns.rdataset.from_rdata_list` would raise ValueError on
        an empty list, crashing the check exactly when keys and DS records don't match.)
        """
        result: list[dns.rdtypes.ANY.DNSKEY.DNSKEY] = [
            key
            for key in dnskey_set
            if key.flags & dns.rdtypes.ANY.DNSKEY.ZONE and key.protocol == 3 and self.is_dnskey_in_dssets(key)
        ]
        if not result:
            return {}
        return { self.name: dns.rdataset.from_rdata_list(dnskey_set.ttl, result) }

    async def check_dnskey(self, ns_addr: IPAddress, ns_name: dns.name.Name) -> bool:
        """
        Fetch the DNSKEY RRset of `self.name` from one authoritative server and verify
        its RRSIG using the keys that match the parent's DS records.

        Returns True on success.  Otherwise it reports the problem (transport failures
        as log lines, everything else as errors) and returns False.
        """
        lookup = self.lookup_servers.get(ns_addr)
        try:
            query = dns.message.make_query(self.name, dns.rdatatype.DNSKEY, want_dnssec=True)
            dnskey_set, response, _ = await lookup.query(query)
        except (DnsTimeoutError, DnsConnectionClosedError, DnsOsError) as e:
            self.ir.log(f'Failed to get DNSKEY for {self.name} from {ns_addr} ({ns_name})')
            self.ir.log(str(e))
            return False
        except dns.exception.DNSException as e:
            self.ir.error(f'Failed to get DNSKEY for {self.name} from {ns_addr} ({ns_name})')
            self.ir.error(str(e))
            return False

        if not dnskey_set:
            self.ir.error(f'Missing DNSKEY for {self.name} from {ns_addr} ({ns_name}) - but delegation is secured by DS')
            return False

        trusted_keys = self.filter_dnskeys(dnskey_set)
        signatures = response.get_rrset(dns.message.ANSWER, self.name, dnskey_set.rdclass, dns.rdatatype.RRSIG, dns.rdatatype.DNSKEY)
        if not signatures:
            self.ir.error(f'DNSKEY for {self.name} not signed from {ns_addr} ({ns_name})')
            return False

        try:
            dns.dnssec.validate(dnskey_set, signatures, trusted_keys)
        except dns.dnssec.ValidationFailure:
            self.ir.error(f'Invalid signatures on DNSKEY for {self.name} from {ns_addr} ({ns_name})')
            return False
        else:
            self.ir.debug(f'DNSKEY good for {self.name} from {ns_addr} ({ns_name})')
        return True

    async def retrieve_auth_delegation(self, ns_addr: IPAddress, ns_name: dns.name.Name) -> bool:
        """
        Ask one authoritative server of the zone for its NS set (with glue) and record it
        in `self.auth_delegations`.

        Returns True if an NS set was received.  Failures are reported (transport
        problems as log lines, others as errors) and return False.
        """
        lookup = self.lookup_servers.get(ns_addr)
        try:
            delegation = await lookup.delegation(self.name)
        except (DnsTimeoutError, DnsConnectionClosedError, DnsOsError) as e:
            self.ir.log(f'Failed to get delegation (NS) to {self.name} from authoritative {ns_addr} ({ns_name})')
            self.ir.log(str(e))
            return False
        except dns.exception.DNSException as e:
            self.ir.error(f'Failed to get delegation (NS) to {self.name} from authoritative {ns_addr} ({ns_name})')
            self.ir.error(str(e))
            return False

        if not delegation:
            self.ir.error(f'Got no delegation (NS) to {self.name} from authoritative {ns_addr} ({ns_name})')
            return False
        self.auth_delegations.setdefault(delegation, set()).add(ns_addr)
        return True

    async def get_soa_into(self, soa_set: SoaSets, ns_addr: IPAddress, ns_name: dns.name.Name) -> bool:
        """
        Query the SOA of `self.name` from one authoritative server and add it to `soa_set`.

        The SOA is then queried again over the other transport and must be identical:
        via TCP if authoritative queries normally use UDP (`opts.auth_tcp` off), or via
        UDP if they use TCP, `opts.host_udp` is set and the server is `self.host`.

        Returns True if everything succeeded.  Failures are reported (transport problems
        as log lines, others as errors) and return False.
        """
        primary = self.lookup_servers.get(ns_addr)
        try:
            soa = await primary.soa(self.name)
        except (DnsTimeoutError, DnsConnectionClosedError, DnsOsError) as e:
            self.ir.log(f'Failed to get SOA for {self.name} from {ns_addr} ({ns_name})')
            self.ir.log(str(e))
            return False
        except dns.exception.DNSException as e:
            self.ir.error(f'Failed to get SOA for {self.name} from {ns_addr} ({ns_name})')
            self.ir.error(str(e))
            return False

        if not soa:
            self.ir.error(f'Got no SOA for {self.name} from {ns_addr} ({ns_name})')
            return False
        soa_set.insert(soa[0], ns_addr)

        # choose the transport for the cross-check (if any)
        if not self.opts.auth_tcp:
            other_l = self.lookup_servers.get(ns_addr, tcp_only=True)
            other_msg = "via TCP (after successful UDP)"
        elif self.opts.host_udp and self.host == ns_name:
            other_l = self.lookup_servers.get(ns_addr, tcp_only=False)
            other_msg = "via UDP (after successful TCP)"
        else:
            return True

        try:
            other_soa = await other_l.soa(self.name)
        except (DnsTimeoutError, DnsConnectionClosedError, DnsOsError) as e:
            self.ir.log(f'Failed to get SOA {other_msg} for {self.name} from {ns_addr} ({ns_name})')
            self.ir.log(str(e))
            return False
        except dns.exception.DNSException as e:
            self.ir.error(f'Failed to get SOA {other_msg} for {self.name} from {ns_addr} ({ns_name})')
            self.ir.error(str(e))
            return False

        if other_soa != soa:
            self.ir.error(f"Got different SOA {other_msg}")
            self.ir.error(f"{other_soa} != {soa}")
            return False
        self.ir.debug(f"Successful SOA {other_msg} from {ns_addr} ({ns_name})")
        return True


    async def retreive_auth_response(self, soa_set: SoaSets, ns_addr: IPAddress, ns_name: dns.name.Name) -> bool:
        """Run all per-server checks against one authoritative address.

        The SOA is collected into ``soa_set``, then the delegation is fetched.
        When DS records from the parent are known, the DNSKEY check follows.
        The checks run one after another and every one runs even if an
        earlier one failed.

        Returns True only if every performed check succeeded.
        """
        outcomes = [
            await self.get_soa_into(soa_set, ns_addr, ns_name),
            await self.retrieve_auth_delegation(ns_addr, ns_name),
        ]
        if self.parent_ds_sets:
            outcomes.append(await self.check_dnskey(ns_addr, ns_name))
        return all(outcomes)

    async def retreive_auth_responses(self) -> None:
        """Query every address in ``self.auths`` concurrently.

        SOAs are collected into ``self.auth_soas``. An error is reported when
        the server named ``self.host`` fails. An error is also reported when
        fewer than half of all queried addresses succeed.
        """
        succeeded = 0

        async def query_one(addr: IPAddress, name: dns.name.Name) -> None:
            nonlocal succeeded
            ok = await self.retreive_auth_response(self.auth_soas, addr, name)
            if ok:
                succeeded += 1
                return
            if name == self.host:
                self.ir.error(f'Required authoritative nameserver {self.host} failed to respond (see log)')

        async with trio.open_nursery() as nursery:
            for addr, name in self.auths.items():
                nursery.start_soon(query_one, addr, name)

        # require a majority of successful addresses
        if succeeded * 2 < len(self.auths):
            self.ir.error('Too many authoritative nameserver down (see log)')

    async def check_hidden_primary(self, hm: dns.name.Name) -> None:
        """Query the hidden primary ``hm`` and compare its SOA with ``self.main_soa``.

        Every address of ``hm`` that is not already part of ``self.auths`` gets
        the same per-server checks as the other authoritatives (SOA,
        delegation, DNSKEY). It is also added to ``self.auths`` so that later
        source listings can name it. Warnings are raised if the hidden primary
        is not reachable at all, or only on some of its addresses.

        SOA comparison against ``self.main_soa``:

        * identical: debug message;
        * differs only in the serial: warning;
        * any other difference: error.
        """
        assert self.main_soa
        hm_soas: SoaSets = SoaSets()

        count_success = 0

        async def handle(ns_addr: IPAddress) -> None:
            nonlocal count_success
            if await self.retreive_auth_response(hm_soas, ns_addr, hm):
                count_success += 1

        try:
            hm_addrs = await self.resolver.addresses(hm)
        except dns.exception.DNSException as e:
            self.ir.error(f'Failed to lookup addresses for hidden primary {hm}')
            self.ir.error(str(e))
            return

        if not hm_addrs:
            # Otherwise this would be logged as "all addresses overlap with
            # the NS set" below, which is misleading.
            self.ir.log(f'No addresses found for hidden primary {hm}')
            return

        count_hm_queries = 0
        async with trio.open_nursery() as nursery:
            for ns_addr in hm_addrs:
                if ns_addr in self.auths:
                    self.ir.log(f'IP {ns_addr} of hidden primary {hm} duplicate of NS set for {self.name} (known as {self.auths[ns_addr]})')
                    continue
                count_hm_queries += 1
                self.auths[ns_addr] = hm
                nursery.start_soon(handle, ns_addr)

        if count_hm_queries == 0:
            self.ir.log(f"All addresses of hidden primary {hm} overlap with addresses of NS set")
            return

        if count_success == 0:
            self.ir.warn(f"Hidden primary {hm} not reachable")
            return

        if count_success < count_hm_queries:
            # we also want to know if dual-stack isn't working properly, so
            # allowing 50% to fail is not helpful.
            self.ir.warn(f"Not all addresses of hidden primary {hm} are reachable")

        for soa, sources in hm_soas.soa_sets.items():
            # no need to print name for sources - there is only a single hidden primary
            sources_str = ' '.join(map(str, sources))
            if self.main_soa == soa:
                self.ir.debug(f"Hidden primary SOA from {sources_str} matches main SOA")
                continue
            if self.main_soa == soa.replace(serial=self.main_soa.serial):
                self.ir.warn(f"Hidden primary SOA from {sources_str} serial mismatch for {self.name}")
            else:
                self.ir.error(f"Hidden primary SOA from {sources_str} mismatch for {self.name}")
            self.ir.log(f"Main authoritative SOA: {self.main_soa}")
            self.ir.log(f"Hidden primary SOA from {sources_str}: {soa}")

    # Reverse zones that are, as a whole, entries into private address space.
    # With --private-reverse, a zone whose name equals one of these is treated
    # as private (see is_private_zone). Ranges that don't end on a label
    # boundary (172.16.0.0/12, 100.64.0.0/10) are handled in is_private_zone.
    _DIRECT_PRIVATE_ADDRESS_ZONES = {
        '10.in-addr.arpa',  # 10.0.0.0/8
        '168.192.in-addr.arpa',  # 192.168.0.0/16
        # fc00::/7 not included for now.
        # ULA: only fd00::/8 as entry; entries below can be specified through --private-entry ...
        # NOTE(review): '0.0.d.f.ip6.arpa' is the reverse zone of fd00::/16;
        # fd00::/8 would be 'd.f.ip6.arpa' - confirm which one is intended.
        '0.0.d.f.ip6.arpa',
    }

    def is_private_zone(self) -> bool:
        """Return whether the delegation check should be skipped for this zone.

        The zone counts as private if any of these holds:

        * ``--private`` was given;
        * the zone name is listed via ``--private-entry`` (comma-separated,
          option may be repeated);
        * ``--private-reverse`` was given and the zone is one of
          ``_DIRECT_PRIVATE_ADDRESS_ZONES``, one of ``16..31.172.in-addr.arpa``
          (172.16.0.0/12) or one of ``64..127.100.in-addr.arpa``
          (100.64.0.0/10).
        """
        if self.opts.private:
            return True
        private_entries = {
            dns.name.from_text(part)
            for entry in (self.opts.private_entry or [])
            for part in map(str.strip, entry.split(','))
            # Skip empty parts (e.g. a trailing comma): dns.name.from_text('')
            # yields the root name, which would mark "." as private.
            if part
        }
        if self.name in private_entries:
            return True
        if not self.opts.private_reverse or self.name == dns.name.root:
            return False

        # plain comparison against each entry; the set is tiny
        if any(self.name == dns.name.from_text(z) for z in self._DIRECT_PRIVATE_ADDRESS_ZONES):
            return True

        parent = self.name.parent()
        first_label = self.name.labels[0]
        if parent == dns.name.from_text('172.in-addr.arpa'):
            # 172.16.0.0/12: 172.16 - 172.31
            return first_label in {str(i).encode('ascii') for i in range(16, 32)}
        if parent == dns.name.from_text('100.in-addr.arpa'):
            # 100.64.0.0/10, 100.64 - 100.127
            return first_label in {str(i).encode('ascii') for i in range(64, 128)}
        return False

    async def get_private_zone_auths(self) -> dict[dns.name.Name, set[IPAddress]]:
        """Build the NS-name -> addresses mapping of a private zone via the resolver.

        This is used instead of the parent delegation when the delegation is
        not checked. Errors are reported via ``self.ir``. NS names whose
        address lookup fails are left out of the result. An empty dict is
        returned when the NS lookup itself fails or returns nothing.
        """
        auths: dict[dns.name.Name, set[IPAddress]] = {}

        async def resolve_ns(target: dns.name.Name) -> None:
            try:
                found = await self.resolver.addresses(target)
            except dns.exception.DNSException as exc:
                self.ir.error(f'Failed to lookup addresses for NS {target} of zone {self.name}')
                self.ir.error(str(exc))
            else:
                auths[target] = set(found)

        try:
            ns_rrset = await self.resolver.ns(self.name)
        except dns.exception.DNSException as exc:
            self.ir.error(f'Failed to lookup NS for zone {self.name}')
            self.ir.error(str(exc))
            return auths
        if not ns_rrset:
            self.ir.error(f'Got no NS for zone {self.name} from resolver')
            return auths

        async with trio.open_nursery() as nursery:
            for record in ns_rrset:
                nursery.start_soon(resolve_ns, record.target)

        if self.ir._debug:
            self.ir.debug(f'Using delegation received from resolver for private zone {self.name}:')
            for target, addrs in auths.items():
                joined = ' '.join(str(a) for a in addrs)
                self.ir.log(f'* NS {target}: {joined}')

        return auths

    async def check(self) -> None:
        """Run all checks for ``self.name``, reporting results via ``self.ir``.

        Steps:

        1. Determine the authoritative nameservers. These come from the parent
           delegation(s), or, for private zones, from the resolver's NS set.
        2. Add the explicit host ``self.host`` if it isn't part of the NS set.
        3. Query SOA / delegation (and DNSKEY if parent DS records are known)
           from all authoritatives, then compare the SOAs.
        4. Check the hidden primary (SOA mname not in the NS set), unless
           disabled.
        5. Compare the delegation data returned by the authoritatives.
        6. Verify that the ``--expect-auth`` names are part of the NS set.
        """
        private_zone = self.is_private_zone()
        if not private_zone:
            await self.retrieve_parent_addresses()
            await self.retreive_delegations()
            await self.check_delegations()

            self.auths_by_name = await self.build_merged_delegation()
        else:
            self.ir.log("Not checking delegation, was marked as private zone entry")
            self.auths_by_name = await self.get_private_zone_auths()

        if not self.auths_by_name:
            # there should have been an error message about this already
            self.ir.error("Can't continue without authoritative nameservers to query")
            return

        # NS names before the explicit host is added below; used to detect a
        # hidden primary and to verify --expect-auth.
        delegated_auths: set[dns.name.Name] = set(self.auths_by_name)

        if self.host and self.host not in delegated_auths:
            # add explicit host to NS
            try:
                self.auths_by_name[self.host] = set(await self.resolver.addresses(self.host))
            except dns.exception.DNSException as e:
                self.ir.error(f'Failed to lookup addresses for explicit NS-host {self.host} of zone {self.name}')
                self.ir.error(str(e))
                return

        self.auths = self.transpose_delegation_auths(self.auths_by_name)

        await self.retreive_auth_responses()

        if self.auth_soas.distinct_soas == 0:
            self.ir.error(f'Found no SOA for {self.name}')
            return

        log_soas = False
        if self.auth_soas.distinct_soas == 1:
            self.main_soa = self.auth_soas.main_soa()
            if len(self.auth_soas.serials) > 1:
                self.ir.warn(f'Found multiple SOA serials for {self.name}')
                log_soas = True
            if len(self.auth_soas.mnames) > 1:
                self.ir.log(f'Found multiple SOA mnames for {self.name}')
                log_soas = True
            if not log_soas:
                self.ir.log(f'Found single SOA: {self.name} SOA {self.main_soa}')
        else:
            # The authoritatives disagree on the SOA beyond what the branch
            # above tolerates. NOTE(review): presumably distinct_soas ignores
            # serial/mname differences, as the checks above suggest - confirm
            # in SoaSets. Report this and list every SOA variant.
            # self.main_soa is not assigned here.
            self.ir.error(f'Found multiple different SOAs for {self.name}')
            log_soas = True

        if log_soas or self.ir._debug:
            if not log_soas:
                self.ir.debug('SOAs from authoritatives:')
            for soa, sources in self.auth_soas.soa_sets.items():
                self.ir.log(f'From {self.sources_names(sources, self.auths)}:')
                self.ir.log(f'  {self.name} SOA {soa}')

        if self.main_soa:
            if self.main_soa.mname not in delegated_auths:
                self.ir.log(f'Found hidden primary {self.main_soa.mname}')
                if not self.opts.skip_hidden_primary:
                    if self.main_soa.mname == self.host:
                        # this would only show "duplicate IP in NS"
                        self.ir.log(f'Checked hidden primary as explicit host already')
                    else:
                        await self.check_hidden_primary(self.main_soa.mname)
                else:
                    self.ir.log(f'Not checking hidden primary as requested')
            else:
                self.ir.debug(f'SOA mname {self.main_soa.mname} is a public auth, no hidden primary detected')

        # This includes the delegation received from the hidden primary, so
        # `check_hidden_primary` must come before this.
        logged_expected_delegation = False
        if self.main_delegation:
            for delegation, sources in self.auth_delegations.items():
                if self.main_delegation != delegation:
                    if not logged_expected_delegation:
                        logged_expected_delegation = True
                        self.ir.log('Delegation:')
                        for line in self.main_delegation.to_lines():
                            self.ir.log(line)
                    self.ir.warn(f'Different delegation data from authoritatives {self.sources_names(sources, self.auths)}:')
                    for line in delegation.to_lines():
                        self.ir.log(line)
                else:
                    self.ir.debug(f'Delegation matched data from authoritive {self.sources_names(sources, self.auths)}')
        elif private_zone:
            # Can't compare with delegation from parent(s), so check between auths
            if len(self.auth_delegations) > 1:
                self.ir.error("Multiple delegations from authoritatives")
            if len(self.auth_delegations) > 1 or self.ir._debug:
                for delegation, sources in self.auth_delegations.items():
                    self.ir.log(f'Delegation data from authoritatives {self.sources_names(sources, self.auths)}:')
                    for line in delegation.to_lines():
                        self.ir.log(line)

        expect_auths = [
            dns.name.from_text(auth)
            for opt in (self.opts.expect_auth or [])
            for auth in (part.strip().lower() for part in opt.split(','))
            # Skip empty parts (e.g. a trailing comma): they would parse as
            # the root name and trigger a bogus "Missing expected auth .".
            if auth
        ]
        for auth in expect_auths:
            if auth not in delegated_auths:
                self.ir.error(f'Missing expected auth {auth} in NS set for {self.name}')
            else:
                self.ir.debug(f'Found expected auth {auth} in NS set for {self.name}')

        if self.main_delegation:
            self.ir.log('+ Checked delegation with authoritative')
        if self.parent_ds_sets:
            self.ir.log('+ Checked DNSSEC delegation with authoritative')
        if self.main_soa:
            if self.opts.auth_tcp:
                if self.opts.host and self.opts.host_udp:
                    self.ir.log(f'+ Checked SOA records via TCP (plus UDP for {self.host})')
                else:
                    self.ir.log('+ Checked SOA records via TCP')
            else:
                self.ir.log('+ Checked SOA records via UDP + TCP')


def main():
    """Parse the command line, run the zone check and emit the Icinga result.

    An unexpected exception during the check is reported via
    ``ir.unknown`` together with its traceback. Argument parsing and the
    result object are set up before the ``try``, so ``ir`` is always bound
    when the handler runs. argparse errors raise ``SystemExit``, which was
    never caught here anyway.
    """
    p = argparse.ArgumentParser(description="Check DNS zone (delegation, SOA, DNSSEC) on all authoritatives and hidden primary")
    p.add_argument('zone', help='Name of DNS zone to check')
    p.add_argument('--skip-hidden-primary', '--skip-hidden-master', action='store_true', help="Don't verify hidden primary data (might be blocked by a firewall)")
    p.add_argument('--auth-tcp', action='store_true', help="Use TCP for DNS queries to authoritative servers")
    p.add_argument('--debug', action='store_true', help="Show debug output")
    p.add_argument('--expect-auth', action='append', help="Make sure listed servers are part of the NS set of the zone")
    p.add_argument('--private', action='store_true', help="Don't check the delegation (i.e. only SOA)")
    p.add_argument('--private-reverse', action='store_true', help="Don't check the delegation (i.e. only SOA), if the zone is an entry to the reverse zone for private IP addresses")
    p.add_argument('--private-entry', action='append', help="Mark given zone name(s) as entry to private zones where delegation isn't checked")
    p.add_argument('--resolver', action='store', help="Use given address as DNS resolver")
    p.add_argument('--host', action='store', help="Make sure zone is hosted on given host")
    p.add_argument('--host-udp', action='store_true', help="Check SOA via UDP too (when --auth-tcp is used) for nameserver given via --host")
    opts = p.parse_args()
    ir = IcingaResult(debug=opts.debug)
    try:
        trio.run(CheckNameContext(ir, opts).check)
    except Exception as e:
        import traceback
        ir.unknown("Unknown exception")
        # although traceback returns "lines", they include \n - undo that
        for line in ''.join(traceback.format_exception(e)).splitlines():
            ir.log(line)
    ir.finish()


if __name__ == "__main__":
    # run the check only when executed as a script, not on import
    main()
