#!/usr/bin/python2

# The MIT License (MIT)
#
# Copyright (c) 2017-2018 Red Hat Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# Authors: Merlin Mathesius <merlinm@redhat.com>
#          Stef Walter <stefw@redhat.com>

import argparse
import json
import os
import shlex
import subprocess
import sys
import logging
import errno
import distutils.util

# Inventory returned when no RPM or repo subject was found.
EMPTY_INVENTORY = dict()
# Log file name, created inside the artifacts directory and shared by all
# default provisioners.
LOG_FILE = "default_provisioners.log"


def get_artifact_path(path=""):
    """Return the location of *path* inside the artifacts directory.

    The artifacts directory is taken from the TEST_ARTIFACTS environment
    variable and defaults to "artifacts" under the current working
    directory; it is created if it does not exist yet.  With the default
    empty *path* the directory itself is returned (os.path.join appends a
    trailing separator in that case).

    Raises OSError if the directory cannot be created, or if the name
    already exists but is not a directory.
    """
    fallback = os.path.join(os.getcwd(), "artifacts")
    artifacts_dir = os.environ.get("TEST_ARTIFACTS", fallback)
    try:
        os.makedirs(artifacts_dir)
    except OSError as err:
        # "Already exists" is fine -- but only when it really is a directory
        # (this also tolerates another process creating it concurrently).
        exists_as_dir = err.errno == errno.EEXIST and os.path.isdir(artifacts_dir)
        if not exists_as_dir:
            raise
    return os.path.join(artifacts_dir, path)


logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# stderr output
conhandler = logging.StreamHandler()
# Print to strerr by default messages with level >= warning, can be changed
# with setting TEST_DEBUG=1.
try:
    diagnose = distutils.util.strtobool(os.getenv("TEST_DEBUG", "0"))
except ValueError:
    diagnose = 0
conhandler.setLevel(logging.WARNING)
if diagnose:
    # Collect all messages with any log level to stderr.
    conhandler.setLevel(logging.NOTSET)
# Log format for stderr.
log_format = "[%(levelname)-5.5s] {}: %(message)s".format(os.path.basename(__file__))
formatter = logging.Formatter(log_format)
conhandler.setFormatter(formatter)
logger.addHandler(conhandler)


def main(argv=None):
    """Provision the test subjects and print the inventory as JSON on stdout.

    argv is the full argument vector including the program name, as in
    sys.argv; None falls back to sys.argv.  Subjects come from the
    positional arguments or, when none are given, from the TEST_SUBJECTS
    environment variable (a shell-style string).  With --host the value is
    handed to gethost() and its result printed instead of the full
    inventory.  Log records are also appended to the shared provisioner
    log file in the artifacts directory.
    """
    parser = argparse.ArgumentParser(description="Inventory for local RPM installed host")
    # --list is accepted but not consulted: listing is what happens whenever
    # --host is not given (presumably kept for the inventory-script calling
    # convention -- confirm).
    parser.add_argument("--list", action="store_true", help="List the inventory (default)")
    parser.add_argument('--host', help="Get host variables")
    parser.add_argument("subjects", nargs="*", default=os.environ.get("TEST_SUBJECTS", ""))
    # Honour the argv we were given; argparse wants it without argv[0].
    opts = parser.parse_args(argv[1:] if argv is not None else None)
    # Send logs to common logfile for all default provisioners.
    log_file = get_artifact_path(LOG_FILE)
    fhandler = logging.FileHandler(log_file)
    # Collect all messages with any log level to log file.
    fhandler.setLevel(logging.NOTSET)
    log_format = ("%(asctime)s [{}/%(threadName)-12.12s] [%(levelname)-5.5s]:"
                  "%(message)s").format(os.path.basename(__file__))
    logFormatter = logging.Formatter(log_format)
    fhandler.setFormatter(logFormatter)
    logger.addHandler(fhandler)
    logger.info("Start provisioner.")
    if opts.host:
        data = gethost(opts.host)
    else:
        data = getlist(opts.subjects)
    sys.stdout.write(json.dumps(data, indent=4, separators=(',', ': ')))


def getlist(subjects):
    """Provision *subjects* and build the complete inventory dictionary.

    gethost() does the installation work.  When it yields host variables,
    a single host called "rpms" is listed in both the "subjects" and
    "localhost" groups, and its variables are published under
    "_meta" -> "hostvars".  Otherwise EMPTY_INVENTORY is returned.
    """
    host_vars = gethost(subjects)
    if not host_vars:
        return EMPTY_INVENTORY

    # Both groups intentionally share the very same host list object.
    members = ["rpms"]
    return {
        "subjects": {"hosts": members, "vars": {}},
        "localhost": {"hosts": members, "vars": {}},
        "_meta": {"hostvars": {"rpms": host_vars}},
    }


def _redirect_stderr_to_tty():
    """Best effort: make file descriptor 2 refer to /dev/tty if it can be opened."""
    try:
        tty = os.open("/dev/tty", os.O_WRONLY)
    except OSError:
        return  # e.g. no controlling terminal: keep the current stderr
    try:
        os.dup2(tty, 2)
    except OSError:
        pass  # best effort, same as a failed open()
    finally:
        # After dup2() fd 2 holds its own reference to the terminal, so the
        # temporary descriptor must be closed to avoid leaking it -- unless
        # open() itself handed back fd 2 (stderr was closed beforehand).
        if tty != 2:
            os.close(tty)


def gethost(subjects):
    """Install the given test subjects on the local host.

    *subjects* is either a string of shell-style, whitespace-separated
    subjects (e.g. the TEST_SUBJECTS value) or a list/tuple of individual
    subject strings (what argparse produces for positional arguments).
    Entries ending in ".rpm" are installed with yum; directories containing
    repodata/repomd.xml are added as yum repositories first, so the RPMs can
    pull dependencies from them.  Any other entry is ignored.

    Returns EMPTY_INVENTORY when no RPM or repo subject was found, otherwise
    the host variables ({"ansible_connection": "local"}).

    Raises RuntimeError if adding a repo or installing the RPMs fails.
    """
    if isinstance(subjects, (list, tuple)):
        # Already split (positional CLI arguments); shlex.split() only
        # accepts a string and fails on a list.
        subjects = list(subjects)
    else:
        subjects = shlex.split(subjects)
    repos = []
    rpms = []

    for subject in subjects:
        if subject.endswith(".rpm"):
            rpms.append(subject)
        elif isrepo(subject):
            repos.append(subject)

    if not repos and not rpms:
        return EMPTY_INVENTORY

    # The variables
    variables = {
        "ansible_connection": "local"
    }

    # yum's output is sent to fd 2 below; presumably the terminal redirect
    # makes it visible even when the caller captures stderr -- confirm.
    _redirect_stderr_to_tty()

    # enable any provided repos first so RPMs can pull dependencies from them if needed
    for repo in repos:
        addrepo = ["/usr/bin/yum", "config-manager", "--add-repo", repo]
        try:
            # yum's stdout goes to stderr so stdout carries only main()'s JSON.
            subprocess.check_call(addrepo, stdout=sys.stderr.fileno())
        except subprocess.CalledProcessError:
            raise RuntimeError("could not add repo: {0}".format(repo))

    if rpms:
        install = ["/usr/bin/yum", "-y", "install"] + rpms
        try:
            subprocess.check_call(install, stdout=sys.stderr.fileno())
        except subprocess.CalledProcessError:
            raise RuntimeError("could not install rpms: {0}".format(rpms))

    return variables


def isrepo(subject):
    """Tell whether *subject* is a yum repository directory.

    A directory counts as a repository when it contains the metadata index
    file repodata/repomd.xml.
    """
    metadata_index = os.path.join(subject, "repodata", "repomd.xml")
    return os.path.isfile(metadata_index)


if __name__ == '__main__':
    exit_code = -1
    try:
        main(sys.argv)
    except Exception:
        # The traceback is logged at INFO level: it always lands in the log
        # file, and reaches stderr only when TEST_DEBUG is enabled.
        logger.info("Fatal error in provision script.", exc_info=True)
    else:
        exit_code = 0
    sys.exit(exit_code)
