#!/usr/bin/python2

import argparse
import atexit
import errno
import json
import os
import fmf
import shutil
import shlex
import signal
import logging
import StringIO
import random
import socket
import subprocess
import sys
import tempfile
import time
import distutils.util

import yaml

# Fixed RSA private key, embedded verbatim.  inv_host() writes it to an
# "identity" file and hands that file to Ansible as
# ansible_ssh_private_key_file.  Presumably the private half of AUTH_KEY
# (TODO confirm).  It is stored in the source, so it provides no secrecy.
IDENTITY = """
-----BEGIN RSA PRIVATE KEY-----
MIIEpQIBAAKCAQEA1DrTSXQRF8isQQfPfK3U+eFC4zBrjur+Iy15kbHUYUeSHf5S
jXPYbHYqD1lHj4GJajC9okle9rykKFYZMmJKXLI6987wZ8vfucXo9/kwS6BDAJto
ZpZSj5sWCQ1PI0Ce8CbkazlTp5NIkjRfhXGP8mkNKMEhdNjaYceO49ilnNCIxhpb
eH5dH5hybmQQNmnzf+CGCCLBFmc4g3sFbWhI1ldyJzES5ZX3ahjJZYRUfnndoUM/
TzdkHGqZhL1EeFAsv5iV65HuYbchch4vBAn8jDMmHh8G1ixUCL3uAlosfarZLLyo
3HrZ8U/llq7rXa93PXHyI/3NL/2YP3OMxE8baQIDAQABAoIBAQCxuOUwkKqzsQ9W
kdTWArfj3RhnKigYEX9qM+2m7TT9lbKtvUiiPc2R3k4QdmIvsXlCXLigyzJkCsqp
IJiPEbJV98bbuAan1Rlv92TFK36fBgC15G5D4kQXD/ce828/BSFT2C3WALamEPdn
v8Xx+Ixjokcrxrdeoy4VTcjB0q21J4C2wKP1wEPeMJnuTcySiWQBdAECCbeZ4Vsj
cmRdcvL6z8fedRPtDW7oec+IPkYoyXPktVt8WsQPYkwEVN4hZVBneJPCcuhikYkp
T3WGmPV0MxhUvCZ6hSG8D2mscZXRq3itXVlKJsUWfIHaAIgGomWrPuqC23rOYCdT
5oSZmTvFAoGBAPs1FbbxDDd1fx1hisfXHFasV/sycT6ggP/eUXpBYCqVdxPQvqcA
ktplm5j04dnaQJdHZ8TPlwtL+xlWhmhFhlCFPtVpU1HzIBkp6DkSmmu0gvA/i07Z
pzo5Z+HRZFzruTQx6NjDtvWwiXVLwmZn2oiLeM9xSqPu55OpITifEWNjAoGBANhH
XwV6IvnbUWojs7uiSGsXuJOdB1YCJ+UF6xu8CqdbimaVakemVO02+cgbE6jzpUpo
krbDKOle4fIbUYHPeyB0NMidpDxTAPCGmiJz7BCS1fCxkzRgC+TICjmk5zpaD2md
HCrtzIeHNVpTE26BAjOIbo4QqOHBXk/WPen1iC3DAoGBALsD3DSj46puCMJA2ebI
2EoWaDGUbgZny2GxiwrvHL7XIx1XbHg7zxhUSLBorrNW7nsxJ6m3ugUo/bjxV4LN
L59Gc27ByMvbqmvRbRcAKIJCkrB1Pirnkr2f+xx8nLEotGqNNYIawlzKnqr6SbGf
Y2wAGWKmPyEoPLMLWLYkhfdtAoGANsFa/Tf+wuMTqZuAVXCwhOxsfnKy+MNy9jiZ
XVwuFlDGqVIKpjkmJyhT9KVmRM/qePwgqMSgBvVOnszrxcGRmpXRBzlh6yPYiQyK
2U4f5dJG97j9W7U1TaaXcCCfqdZDMKnmB7hMn8NLbqK5uLBQrltMIgt1tjIOfofv
BNx0raECgYEApAvjwDJ75otKz/mvL3rUf/SNpieODBOLHFQqJmF+4hrSOniHC5jf
f5GS5IuYtBQ1gudBYlSs9fX6T39d2avPsZjfvvSbULXi3OlzWD8sbTtvQPuCaZGI
Df9PUWMYZ3HRwwdsYovSOkT53fG6guy+vElUEDkrpZYczROZ6GUcx70=
-----END RSA PRIVATE KEY-----
"""


# Base64 body of an ssh-rsa public key.  USER_DATA installs it as an
# authorized key for DEF_USER inside the guest.
AUTH_KEY = ("AAAAB3NzaC1yc2EAAAADAQABAAABAQDUOtNJdBEXyKxBB898rdT54ULjMGuO6v4jLX"
            "mRsdRhR5Id/lKNc9hsdioPWUePgYlqML2iSV72vKQoVhkyYkpcsjr3zvBny9+5xej3"
            "+TBLoEMAm2hmllKPmxYJDU8jQJ7wJuRrOVOnk0iSNF+FcY/yaQ0owSF02Nphx47j2K"
            "Wc0IjGGlt4fl0fmHJuZBA2afN/4IYIIsEWZziDewVtaEjWV3InMRLllfdqGMllhFR+"
            "ed2hQz9PN2QcapmEvUR4UCy/mJXrke5htyFyHi8ECfyMMyYeHwbWLFQIve4CWix9qt"
            "ksvKjcetnxT+WWrutdr3c9cfIj/c0v/Zg/c4zETxtp")

# Account name and password configured in the guest through USER_DATA, and
# reported to Ansible as ansible_user / ansible_ssh_pass.
DEF_USER = "root"
DEF_PASSWD = "foobar"
# Local address used for the free-port probe, reported as ansible_host and
# used for the optional VNC server.
DEF_HOST = "127.0.0.3"

# cloud-init user-data document.  Placeholders: {0} = DEF_USER,
# {1} = DEF_PASSWD, {2} = AUTH_KEY.
USER_DATA = """#cloud-config
users:
  - default
  - name: {0}
    ssh_authorized_keys:
      - ssh-rsa {2} standard-test-qcow2
ssh_pwauth: True
chpasswd:
  list: |
    {0}:{1}
  expire: False
""".format(DEF_USER, DEF_PASSWD, AUTH_KEY)

# Returned when there is nothing to put into the inventory.  This single dict
# object is shared by every caller that receives it.
EMPTY_INVENTORY = {}
# Name of the log file created inside the artifacts directory (see main()).
LOG_FILE = "default_provisioners.log"


def get_artifact_path(path=""):
    """Build a path inside the artifacts directory, creating it on demand.

    The directory comes from the TEST_ARTIFACTS environment variable and
    falls back to "artifacts" under the current working directory.  With the
    default empty `path` the artifacts directory itself is returned (as
    produced by os.path.join with an empty last component).
    """
    fallback_dir = os.path.join(os.getcwd(), "artifacts")
    artifacts_dir = os.environ.get("TEST_ARTIFACTS", fallback_dir)
    try:
        os.makedirs(artifacts_dir)
    except OSError as err:
        # An already existing directory is fine; anything else is an error.
        already_present = err.errno == errno.EEXIST and os.path.isdir(artifacts_dir)
        if not already_present:
            raise
    return os.path.join(artifacts_dir, path)


# Module-wide logger.  It accepts every level; the attached handlers decide
# what is actually emitted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# stderr output
conhandler = logging.StreamHandler()
# By default only messages with level >= WARNING are printed to stderr.  This
# can be changed by setting TEST_DEBUG=1.
try:
    diagnose = distutils.util.strtobool(os.getenv("TEST_DEBUG", "0"))
except ValueError:
    # TEST_DEBUG holds a value strtobool() does not recognise: treat as off.
    diagnose = 0
conhandler.setLevel(logging.WARNING)
if diagnose:
    # Collect all messages with any log level to stderr.
    conhandler.setLevel(logging.NOTSET)
# Log format for stderr.
log_format = "[%(levelname)-5.5s] {}: %(message)s".format(os.path.basename(__file__))
formatter = logging.Formatter(log_format)
conhandler.setFormatter(formatter)
logger.addHandler(conhandler)

# Temporary fix for issue #233.  All strings in Python 3 are unicode, but in
# Python 2 a unicode string has to be created with `unicode()`.  fmf.Tree()
# expects a unicode string even under Python 2, so make the name `unicode`
# available on both versions: when it does not exist, alias it to `str`.
try:
    UNICODE_EXISTS = bool(type(unicode))
except NameError:
    unicode = str


def main(argv):
    """Entry point: print an Ansible dynamic inventory as JSON on stdout.

    Parameters
    ----------
    argv: list of str
        Full command line including the program name at index 0 (as in
        ``sys.argv``).  When empty or None, ``sys.argv`` is parsed instead.

    With ``--host NAME`` the variables of that single host are produced;
    otherwise an inventory is built for all subjects, taken from the
    positional arguments or, when none are given, from the TEST_SUBJECTS
    environment variable.
    """
    parser = argparse.ArgumentParser(description="Inventory for a QCow2 test image")
    # The flag is accepted but its value is never read: listing is what
    # happens whenever --host is absent.
    parser.add_argument("--list", action="store_true",
                        help="List all hosts and their variables (default action)")
    parser.add_argument('--host', help="Get host variables")
    parser.add_argument("subjects", nargs="*", default=shlex.split(os.environ.get("TEST_SUBJECTS", "")))
    # Honour the argv handed in by the caller instead of silently re-reading
    # sys.argv; argv[0] is the program name and must be skipped.
    opts = parser.parse_args(argv[1:] if argv else None)
    # Send logs to common logfile for all default provisioners.
    log_file = get_artifact_path(LOG_FILE)
    fhandler = logging.FileHandler(log_file)
    # Collect all messages with any log level to log file.
    fhandler.setLevel(logging.NOTSET)
    log_format = ("%(asctime)s [{}/%(threadName)-12.12s] [%(levelname)-5.5s]:"
                  "%(message)s").format(os.path.basename(__file__))
    logFormatter = logging.Formatter(log_format)
    fhandler.setFormatter(logFormatter)
    logger.addHandler(fhandler)
    logger.info("Start provisioner.")
    if opts.host:
        data = inv_host(opts.host)
    else:
        data = inv_list(opts.subjects)
    # Dump Ansible inventory.
    sys.stdout.write(json.dumps(data, indent=4, separators=(',', ': ')))


def inv_list(subjects):
    """Build the complete inventory for `subjects`.

    inv_host() is called once per subject, in order.  Subjects for which it
    yields nothing are left out.  Returns EMPTY_INVENTORY when no subject
    produced variables; otherwise a dict with the groups "localhost" and
    "subjects" (both listing the same hosts) and the per-host variables
    under "_meta" / "hostvars".
    """
    names = []
    hostvars = {}
    for name in subjects:
        found = inv_host(name)
        if not found:
            continue
        names.append(name)
        hostvars[name] = found

    if names:
        return {"localhost": {"hosts": names, "vars": {}},
                "subjects": {"hosts": names, "vars": {}},
                "_meta": {"hostvars": hostvars}}
    return EMPTY_INVENTORY


def get_qemu_smp_arg():
    """Determine the number of CPUs that should be visible in the guest.
    See e.g. https://www.redhat.com/archives/libvirt-users/2017-June/msg00025.html
    We want to match the number of host physical cores.

    Returns
    -------
    str
        Value for the qemu ``-smp`` option, in the form
        ``N,sockets=N,cores=1,threads=1`` with N >= 1.

    Raises
    ------
    subprocess.CalledProcessError
        If ``nproc`` or ``lscpu`` exits with a non-zero status.
    """
    # We may be run in a cgroup with fewer cores available than physical.
    # universal_newlines=True yields text output on both Python 2 and 3.
    available_cpu = int(subprocess.check_output(['nproc'], universal_newlines=True).strip())
    # https://stackoverflow.com/questions/6481005/how-to-obtain-the-number-of-cpus-cores-in-linux-from-the-command-line
    lscpu_output = subprocess.check_output(['lscpu', '-b', '-p=Core,Socket'],
                                           universal_newlines=True)
    core_sockets = set()
    for line in lscpu_output.splitlines():
        entry = line.strip()
        # Skip the commented header and blank lines, so that an empty line
        # is not counted as one more core.
        if not entry or entry.startswith('#'):
            continue
        core_sockets.add(entry)
    # Never ask qemu for zero CPUs, even when lscpu reported no usable rows.
    sockets = max(1, min(available_cpu, len(core_sockets)))
    return '{},sockets={},cores=1,threads=1'.format(sockets, sockets)


def write_debug_inventory(file_, host_vars):
    """Write a YAML inventory to `file_` and return the YAML text.

    `host_vars` is placed, as the very same object, under both the
    "localhost" and the "subjects" child group of "all".
    """
    groups = {"localhost": {"hosts": host_vars},
              "subjects": {"hosts": host_vars}}
    document = {"all": {"children": groups}}
    # The target file is opened (and thereby truncated) before serialising.
    with open(file_, "w") as stream:
        text = yaml.dump(document, default_flow_style=False)
        stream.write(text)
    return text


class FmfMetadataTree(object):
    """Class-level cache holding a single fmf.Tree().

    Building fmf.Tree(path) can be very expensive for a big project with
    many directories, so the result is stored on the class and shared by
    all instances.  get() hands back the fmf.Tree() object, or False when
    building the tree for that path failed.
    """
    # Cached fmf.Tree() object: None until the first attempt, False after a
    # failed attempt.
    tree = None
    # Path the cached tree was (or was attempted to be) created for.
    path = None

    def get(self, path="."):
        """Return the cached tree for `path`, building it when necessary."""
        if self.path == path and self.tree is not None:
            return self.tree

        cache = FmfMetadataTree
        cache.path = path
        try:
            cache.tree = fmf.Tree(unicode(path))
        except Exception:
            # Fmf initialization failed.  Remember that, so no further
            # attempt is made for this path.
            cache.tree = False
        return self.tree


def fmf_get(path, default=None):
    """Return parameter from FMF at desired path or default.

    Parameters
    ----------
    path: list of str
        Keys forming the path to the parameter.  The key
        'standard-inventory-qcow2' is prepended automatically.  The list
        passed in is not modified.
    default: object
        Value returned when the FMF tree is unavailable or the path cannot
        be resolved.

    Returns
    -------
    object
        Found parameter in FMF or `default`.  When several nodes match
        ".*/provision$", the outcome for the last one visited is returned.

    """
    tree = FmfMetadataTree().get()
    if not tree:
        return default
    # Build a new list instead of inserting into the caller's one.
    full_path = ['standard-inventory-qcow2'] + list(path)
    value = default
    for provision in tree.prune(names=[".*/provision$"]):
        value = provision.data
        for node in full_path:
            try:
                value = value[node]
            except (KeyError, TypeError):
                # Missing key, or an intermediate value that is not a mapping.
                value = default
                break
    return value


class AdditionalDrives(object):
    """Prepare additional drives options for qemu

    Based on FMF config creates temporary sparse files and returns
    corresponding qemu command options.
    cleanup() is registered with atexit (once) to close the files.
    """

    # One record per created drive: {'file': <NamedTemporaryFile>,
    # 'path': <directory it was requested in, or None>}.
    _tempfiles = list()
    # Set once cleanup() has been registered with atexit.
    _cleanup_registered = False

    @classmethod
    def generate(cls):
        """Generate sparse files and return drive qemu options
        Returns
        -------
        list of str
                qemu -drive options
        """
        drives = fmf_get(['qemu', 'drive'], list())
        result = []
        for drive in drives:
            # create temporary sparse file
            size = int(drive.get('size', 2 * 1024 ** 3))  # default size: 2G
            path = drive.get('path', None)
            path = str(path) if path is not None else None
            drive_file = tempfile.NamedTemporaryFile(dir=path)
            drive_file.truncate(size)
            cls._tempfiles.append({'file': drive_file, 'path': path})
            logger.info("Created temporary sparse file '%s'.", drive_file.name)

            # translate data into qemu command options
            result += ["-drive", "file=%s,media=disk,if=virtio" % drive_file.name]
        # Register the handler only once, however often generate() is called.
        if not cls._cleanup_registered:
            atexit.register(cls.cleanup)
            cls._cleanup_registered = True
        return result

    @classmethod
    def cleanup(cls):
        """Close all temporary files created by this class

        Closing a NamedTemporaryFile also removes it.  Files which have
        already been removed (for example by another process that inherited
        them) are still closed so the descriptor is released.
        """
        pending = list(cls._tempfiles)
        del cls._tempfiles[:]
        for entry in pending:
            drive_file = entry['file']
            # NamedTemporaryFile.name already is the path of the file, so it
            # must not be joined with entry['path'] (which may be None).
            fullname = drive_file.name
            logger.info("Closing and removing temporary sparse file '%s'", fullname)
            try:
                drive_file.close()
            except OSError:
                logger.debug("Temporary sparse file '%s' was already removed.", fullname)


def start_qemu(image, cloudinit, portrange=(2222, 5555)):
    """Start qemu for `image` with guest SSH forwarded to a local port.

    Parameters
    ----------
    image: str
        Path of the disk image to boot (started with -snapshot).
    cloudinit: str
        Path of the cloud-init ISO attached as CD-ROM.
    portrange: tuple of int
        Inclusive (low, high) range the local SSH port is picked from.

    Returns
    -------
    tuple
        (subprocess.Popen object of qemu, chosen local port).

    Raises
    ------
    RuntimeError
        If none of 10 randomly chosen ports could be bound on DEF_HOST.
    """
    for _ in range(10):
        port = random.randint(*portrange)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            sock.bind((DEF_HOST, port))
        except IOError:
            pass
        else:
            break
        finally:
            # The socket only probes the port; qemu binds it again later.
            sock.close()
    else:
        raise RuntimeError("unable to find free local port to map SSH to")

    # Log all traffic received from the guest to a file.
    log_file = "{0}.guest.log".format(os.path.basename(image))
    log_guest = get_artifact_path(log_file)
    # Log from qemu itself.
    log_qemu = log_guest.replace(".guest.log", ".qemu.log")
    # Parameters from FMF:
    param_m = str(fmf_get(['qemu', 'm'], "1024"))
    param_net_nic_model = str(fmf_get(['qemu', 'net_nic', 'model'], 'virtio'))
    # Use -cpu host and -smp by default.
    # virtio-rng-pci: https://wiki.qemu.org/Features/VirtIORNG

    # Forward on DEF_HOST, the same address that was probed above and that is
    # reported to Ansible, rather than on a separately hard-coded literal.
    hostfwd = "user,hostfwd=tcp:{0}:{1}-:22".format(DEF_HOST, port)
    qemu_cmd = ["/usr/bin/qemu-system-x86_64",
                "-cpu", "host", "-smp", get_qemu_smp_arg(),
                "-m", param_m, image, "-enable-kvm", "-snapshot", "-cdrom", cloudinit,
                "-net", "nic,model=%s" % param_net_nic_model, "-net", hostfwd,
                "-device", "virtio-rng-pci", "-rtc", "base=utc",
                "-device", "isa-serial,chardev=pts2", "-chardev", "file,id=pts2,path=" + log_guest,
                "-display", "none"]

    qemu_cmd += AdditionalDrives.generate()

    if diagnose:
        qemu_cmd += ["-vnc", DEF_HOST + ":1,to=4095"]

    # qemu gets its own copy of the descriptor, so ours can be closed as soon
    # as the process has been started.
    with open(log_qemu, 'a') as qemu_log:
        qemu_proc = subprocess.Popen(qemu_cmd, stdout=qemu_log, stderr=subprocess.STDOUT)
    time.sleep(5)

    if diagnose:
        logger.info("qemu-kvm is running with VNC server. PID: {}".format(qemu_proc.pid))
        logger.info("netstat -ltpn4 | grep {0} # to find VNC server port".format(qemu_proc.pid))

    return qemu_proc, port


def inv_host(image):
    """Boot `image` in qemu and return the Ansible variables to reach it.

    Parameters
    ----------
    image: str
        Path of the subject.  Only names ending in ".qcow2" or ".qcow2c"
        are handled.

    Returns
    -------
    dict or None
        EMPTY_INVENTORY for an unsupported subject; the host variables once
        the guest answers over SSH; None when no known Python interpreter
        was found in the guest.

    Raises
    ------
    RuntimeError
        If qemu could not be started, exited early, or the guest never
        became reachable.

    The function forks once the guest is reachable.  The process for which
    os.fork() returns a non-zero value returns the variables; the other one
    detaches, watches the parent/qemu (or a lock file), finally terminates
    qemu and leaves through sys.exit(0) without ever returning.
    """
    if not image.endswith((".qcow2", ".qcow2c")):
        logger.info("Return empty inventory for image: %s.", image)
        return EMPTY_INVENTORY

    # Sink for the output of helper commands; also used as stdin later on.
    null = open(os.devnull, 'w')

    # Try to send stderr to the terminal; remember when there is none.
    try:
        tty = os.open("/dev/tty", os.O_WRONLY)
        os.dup2(tty, 2)
    except OSError:
        tty = None

    # A directory for temporary stuff
    directory = tempfile.mkdtemp(prefix="inventory-cloud")
    identity = os.path.join(directory, "identity")
    with open(identity, 'w') as f:
        f.write(IDENTITY)
    os.chmod(identity, 0o600)
    metadata = os.path.join(directory, "meta-data")
    with open(metadata, 'w') as f:
        f.write("")
    userdata = os.path.join(directory, "user-data")
    with open(userdata, 'w') as f:
        f.write(USER_DATA)

    # Create our cloud init so we can log in
    cloudinit = os.path.join(directory, "cloud-init.iso")
    subprocess.check_call(["/usr/bin/genisoimage", "-input-charset", "utf-8",
                           "-volid", "cidata", "-joliet", "-rock", "-quiet",
                           "-output", cloudinit, userdata, metadata], stdout=null)

    logger.info("Launching virtual machine for {0}".format(image))

    # And launch the actual VM

    proc = None  # for failure detection
    cpe = None  # for exception scoping
    # Up to five attempts, one second apart.
    # NOTE(review): the use of `cpe` after the loop relies on Python 2
    # keeping the `as` name bound after the except block.
    for _ in range(0, 5):
        try:
            proc, port = start_qemu(image, cloudinit)
            break
        except subprocess.CalledProcessError as cpe:
            time.sleep(1)
            continue
    if proc is None:
        raise RuntimeError("Could not launch VM for qcow2 image"
                           " '{0}':{1}".format(image, cpe.output))

    # Up to 30 attempts, three seconds apart, to reach the guest.
    for _ in range(0, 30):
        try:
            # The variables
            variables = {
                "ansible_port": "{0}".format(port),
                "ansible_host": DEF_HOST,
                "ansible_user": DEF_USER,
                "ansible_ssh_pass": DEF_PASSWD,
                "ansible_ssh_private_key_file": identity,
                "ansible_ssh_common_args": "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no"
            }

            # Write out a handy inventory file, for our use and for debugging
            inventory = os.path.join(directory, "inventory")
            write_debug_inventory(inventory, {image: variables})

            # Wait for ssh to come up
            ping = [
                "/usr/bin/ansible",
                "--inventory",
                inventory,
                "localhost",
                "--module-name",
                "raw",
                "--args",
                "/bin/true"
            ]

            # Non-blocking check: a non-zero pid means qemu has already exited.
            (pid, _) = os.waitpid(proc.pid, os.WNOHANG)
            if pid != 0:
                raise RuntimeError("qemu failed to launch qcow2 image: {0}".format(image))
            subprocess.check_call(ping, stdout=null, stderr=null)
            break
        except subprocess.CalledProcessError:
            time.sleep(3)
    else:
        # Kill the qemu process
        try:
            os.kill(proc.pid, signal.SIGTERM)
        except OSError:
            pass
        raise RuntimeError("could not access launched qcow2 image: {0}".format(image))

    # Process of our parent
    ppid = os.getppid()

    child = os.fork()
    if child:
        # This branch runs in the process that called fork(); it hands the
        # result back to the caller.
        # Need to figure out what python interpreter to use
        interpreters = ["/usr/bin/python2", "/usr/bin/python3", "/usr/libexec/platform-python"]
        for interpreter in interpreters:
            check_file = [
                "/usr/bin/ansible",
                "--inventory",
                inventory,
                "localhost",
                "--module-name",
                "raw",
                "--args",
                "ls %s" % interpreter
            ]
            try:
                subprocess.check_call(check_file, stdout=null, stderr=null)
                ansible_python_interpreter = interpreter
                break
            except subprocess.CalledProcessError:
                pass
        else:
            logger.error("Could not set ansible_python_interpreter.")
            return None
        variables["ansible_python_interpreter"] = ansible_python_interpreter
        # Update inventory file
        write_debug_inventory(inventory, {image: variables})
        return variables

    # Everything below runs only in the forked process.
    # Daemonize and watch the processes
    os.chdir("/")
    os.setsid()
    os.umask(0)

    if tty is None:
        tty = null.fileno()

    # Point stdin at /dev/null and stdout/stderr at the tty (or at /dev/null
    # when no tty could be opened).
    os.dup2(null.fileno(), 0)
    os.dup2(tty, 1)
    os.dup2(tty, 2)

    # alternatively, lock on a file
    lock_file = os.environ.get("LOCK_ON_FILE", None)
    ssh_cmd = ("ssh -p {port} -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i {identity} {user}@{host}"
               ).format(port=port, identity=identity, user=DEF_USER, host=DEF_HOST)
    logger.info(ssh_cmd)
    logger.info("Password is: {}".format(DEF_PASSWD))
    logger.info("Cloudinit init location: {}".format(directory))
    logger.info("export ANSIBLE_INVENTORY={0}".format(inventory))
    logger.info("Wait until parent for provision-script (ansible-playbook) dies or qemu.")
    # Poll every three seconds until the stop condition is met.
    while True:
        time.sleep(3)

        if lock_file:
            # LOCK_ON_FILE mode: keep the VM until that file disappears.
            if not os.path.exists(lock_file):
                logger.error("Lock file is gone.")
                break
        else:
            # Now wait for the parent process to go away, then kill the VM
            try:
                # Signal 0 sends nothing; it only checks the process exists.
                os.kill(ppid, 0)
                os.kill(proc.pid, 0)
            except OSError:
                logger.info("Either parent process or VM process is gone.")
                break
    if diagnose:
        # Debug mode: keep the VM until this process receives SIGTERM.
        def _signal_handler(*args):
            logger.info("Diagnose ending.")
        signal.signal(signal.SIGTERM, _signal_handler)
        logger.info("kill {0} # when finished to debug VM.".format(os.getpid()))
        signal.pause()
    # Kill the qemu process
    try:
        os.kill(proc.pid, signal.SIGTERM)
    except OSError:
        pass
    # Remove the identity file, cloud-init data and debug inventory.
    shutil.rmtree(directory)
    sys.exit(0)


if __name__ == '__main__':
    # Exit status: 0 when main() completes, -1 when it raises an Exception.
    try:
        main(sys.argv)
    except Exception:
        # Backtrace stack goes to log file. If TEST_DEBUG == 1, it goes to stderr too.
        logger.info("Fatal error in provision script.", exc_info=True)
        sys.exit(-1)
    sys.exit(0)
