#!/usr/bin/python3
# pylint: disable=missing-module-docstring,invalid-name

import os
import shutil
import sys
import time

# Ask a login shell of the postgres user for its environment and keep what
# follows the first "=" on the line(s) mentioning PGDATA. The context manager
# closes the pipe once the output has been read.
with os.popen("runuser -l postgres -c env | grep PGDATA | cut -f2- -d=") as stream:
    pgdata = stream.read().strip()

# Path of the file to be sorted.
pg_hba_conf = pgdata + "/pg_hba.conf"
# Backup path: "<pg_hba.conf path>-YYYY-MM-DD-HH-MM-SS", using local time.
pg_hba_bak = pg_hba_conf + "-" + time.strftime("%Y-%m-%d-%H-%M-%S")

# Parsed rules, nested as: database -> connection type -> list of entries.
dbconf = {}

# Order in which databases are written out. Databases found in the file that
# are not named here get inserted in front of these.
dborder = ["replication", "all"]

# Connection types, in the order they are written out for each database:
#   local         Unix-domain socket
#   host          TCP/IP socket (encrypted or not)
#   hostssl       TCP/IP socket that is SSL-encrypted
#   hostnossl     TCP/IP socket that is not SSL-encrypted
#   hostgssenc    TCP/IP socket that is GSSAPI-encrypted
#   hostnogssenc  TCP/IP socket that is not GSSAPI-encrypted
typeorder = [
    "local",
    "host",
    "hostssl",
    "hostnossl",
    "hostgssenc",
    "hostnogssenc",
]


# Refuse to run unless there is a configuration to sort and the timestamped
# backup name is still free.
if not os.path.exists(pg_hba_conf):
    sys.stderr.write("pg_hba.conf does not exist\n")
    sys.exit(1)
if os.path.exists(pg_hba_bak):
    sys.stderr.write("pg_hba.conf backup file already exists\n")
    sys.exit(1)

# Back up by copying instead of renaming: the live file stays in place, so a
# failure later in the script never leaves the server without a pg_hba.conf,
# and rewriting it reuses the existing file rather than creating a new one
# with the invoking user's ownership and umask-derived mode.
shutil.copy2(pg_hba_conf, pg_hba_bak)


# Read the backup and group every rule as
#   dbconf[database][connection type] -> list of whitespace-split fields.
# Blank lines and comment lines are skipped. Databases not yet present in
# dborder are put at its front.
# pylint: disable-next=unspecified-encoding
with open(pg_hba_bak) as backup:
    for lineno, line in enumerate(backup, start=1):
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        vals = [tok for tok in line.replace("\t", " ").split(" ") if tok]
        if len(vals) < 2:
            # A rule needs at least a type and a database to be grouped.
            # Put the untouched file back under its real name and stop,
            # instead of dying with an IndexError.
            sys.stderr.write(
                "pg_hba.conf line " + str(lineno) + " is malformed: " + line + "\n"
            )
            os.replace(pg_hba_bak, pg_hba_conf)
            sys.exit(1)
        conn_type, database = vals[0], vals[1]
        if database not in dbconf:
            dbconf[database] = {}
            if database not in dborder:
                dborder.insert(0, database)
        dbconf[database].setdefault(conn_type, []).append(vals)

# Rewrite pg_hba.conf: databases in dborder order and, within each database,
# connection types in typeorder order followed by any other type that was
# parsed. Fields are joined with tabs, one rule per line.
# pylint: disable-next=unspecified-encoding
with open(pg_hba_conf, "w") as pghba:
    for db in dborder:
        by_type = dbconf.get(db)
        if not by_type:
            continue
        # Rules whose first field is not listed in typeorder would otherwise
        # vanish from the output; keep them, after the known types.
        extra_types = [kind for kind in by_type if kind not in typeorder]
        for kind in typeorder + extra_types:
            specific_users = []
            user_all_entries = []
            for entry in by_type.get(kind, ()):
                # Rules too short to carry a user field are written as-is
                # instead of raising IndexError half-way through the file.
                if len(entry) > 2 and entry[2] == "all":
                    # NOTE(review): inserting at the front reverses the
                    # original relative order of the user-"all" rules;
                    # kept as found - confirm this is intended.
                    user_all_entries.insert(0, entry)
                else:
                    specific_users.append(entry)
            # If user is all, it should always be the last.
            for entry in specific_users + user_all_entries:
                pghba.write("\t".join(entry))
                pghba.write("\n")
