#!/usr/bin/env python3
#
# Copyright(c) 2012-2021 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
#
import sys

# Oldest interpreter version this script accepts. The check runs before the
# remaining imports further down the file.
min_ver = (3, 6)
if sys.version_info < min_ver:
    # Explicit indices instead of "*min_ver, sys.version": a positional argument
    # after *-unpacking only parses on Python 3.5+, so on 3.0-3.4 the whole
    # file would fail with a SyntaxError before this message could be printed.
    print("Minimum required python version is {}.{}. Detected python version is '{}'"
          .format(min_ver[0], min_ver[1], sys.version), file=sys.stderr)
    # sys.exit instead of the exit() helper installed by the site module,
    # which is meant for interactive use and is absent under "python -S".
    sys.exit(1)

import argparse
import re

import opencas


def eprint(*args, **kwargs):
    """Print to standard error.

    Positional and keyword arguments are forwarded to the built-in ``print``.
    ``file`` must not be passed: it is fixed to ``sys.stderr``.
    """
    stream = sys.stderr
    print(*args, file=stream, **kwargs)


# Start - load all the caches and add cores


def start():
    """Load every cache listed in /etc/opencas/opencas.conf.

    The config file is read with ``allow_incomplete=True`` and each cache is
    started with ``load=True``. A cache that fails to load is reported on
    stderr and the remaining caches are still attempted; the function then
    returns normally.

    Exits the process with status 1 if the config file cannot be parsed.
    """
    try:
        config = opencas.cas_config.from_file(
            "/etc/opencas/opencas.conf", allow_incomplete=True
        )
    except Exception as e:
        eprint(e)
        eprint("Unable to parse config file.")
        # sys.exit instead of the exit() helper installed by the site module,
        # which is meant for interactive use and is absent under "python -S".
        sys.exit(1)

    for cache in config.caches.values():
        try:
            opencas.start_cache(cache, load=True)
        except opencas.casadm.CasadmError as e:
            # Report and carry on with the next cache.
            eprint(
                "Unable to load cache {0} ({1}). Reason:\n{2}".format(
                    cache.cache_id, cache.device, e.result.stderr
                )
            )

# Init - start all the caches from scratch (no load) and add cores


def add_core_recursive(core, config):
    """Add ``core`` to its cache, first adding the core it is stacked on.

    If ``core.device`` looks like ``/dev/cas<cache_id>-<core_id>``, the core
    with those ids is looked up in ``config.caches`` and added first.

    The function keeps its state on the core objects themselves:
      * ``added``  - the core was added successfully,
      * ``marked`` - processing of the core has begun,
      * ``failed`` - an earlier attempt to add the core failed (set here;
        treated as False when the attribute is absent).

    Args:
        core: core config object; must carry ``added`` and ``marked``.
        config: parsed configuration with a ``caches`` mapping.

    Returns:
        True if adding this core or the core it depends on failed,
        False otherwise.

    Exits the process with status 3 when a core is reached again while it is
    still being processed (a dependency cycle).
    """
    if core.added:
        return False
    # A core whose add already failed is also marked-but-not-added. Without
    # this check, visiting it again (from the caller's loop or from another
    # core stacked on it) would be misreported as a recursive configuration.
    if getattr(core, "failed", False):
        return True
    if core.marked:
        eprint(
            "Unable to add core {0} to cache {1}. Reason:\nRecursive core configuration!".format(
                core.device, core.cache_id
            )
        )
        # sys.exit instead of the exit() helper installed by the site module,
        # which is meant for interactive use and is absent under "python -S".
        sys.exit(3)
    core.marked = True

    with_error = False
    match = re.match(r"/dev/cas(\d{1,5})-(\d{1,4})", core.device)
    if match:
        # The backing device is itself a CAS device: add that core first.
        cache_id, core_id = match.groups()
        with_error = add_core_recursive(
            config.caches[int(cache_id)].cores[int(core_id)], config
        )
    try:
        opencas.add_core(core, False)
        core.added = True
    except opencas.casadm.CasadmError as e:
        eprint(
            "Unable to add core {0} to cache {1}. Reason:\n{2}".format(
                core.device, core.cache_id, e.result.stderr
            )
        )
        core.failed = True
        with_error = True
    return with_error


def init(force):
    """Start every configured cache from scratch and add the configured cores.

    Reads /etc/opencas/opencas.conf, starts each cache with ``load=False``,
    configures it, then adds all cores via ``add_core_recursive``.

    Args:
        force: forwarded as ``force`` to ``opencas.start_cache``. When false,
            every cache device is checked first and the whole operation is
            refused if a device reports both "Is cache" and "Cache dirty"
            as "yes".

    This function always terminates the process. Exit status:
        0 - everything succeeded,
        1 - the config file could not be parsed, or dirty data was found,
        2 - at least one cache failed to start/configure or one core failed
            to be added,
        the casadm exit code - when checking a cache device failed.
    ``add_core_recursive`` may additionally terminate the process itself.
    """
    exit_code = 0
    try:
        config = opencas.cas_config.from_file("/etc/opencas/opencas.conf")
    except Exception as e:
        eprint(e)
        eprint("Unable to parse config file.")
        # sys.exit instead of the exit() helper installed by the site module,
        # which is meant for interactive use and is absent under "python -S".
        sys.exit(1)

    if not force:
        # Refuse to initialize over a device that still holds dirty data.
        for cache in config.caches.values():
            try:
                status = opencas.check_cache_device(cache.device)
                if status["Is cache"] == "yes" and status["Cache dirty"] == "yes":
                    eprint(
                        "Unable to perform initial configuration.\n"
                        "One of cache devices contains dirty data."
                    )
                    sys.exit(1)
            except opencas.casadm.CasadmError as e:
                eprint(
                    "Unable to check status of device {0}. Reason:\n{1}".format(
                        cache.device, e.result.stderr
                    )
                )
                sys.exit(e.result.exit_code)

    for cache in config.caches.values():
        try:
            opencas.start_cache(cache, load=False, force=force)
        except opencas.casadm.CasadmError as e:
            eprint(
                "Unable to start cache {0} ({1}). Reason:\n{2}".format(
                    cache.cache_id, cache.device, e.result.stderr
                )
            )
            exit_code = 2
        # Configuration is attempted even if the start above failed.
        try:
            opencas.configure_cache(cache)
        except opencas.casadm.CasadmError as e:
            eprint(
                "Unable to configure cache {0} ({1}). Reason:\n{2}".format(
                    cache.cache_id, cache.device, e.result.stderr
                )
            )
            exit_code = 2

    # Reset the bookkeeping flags that add_core_recursive reads and updates.
    for core in config.cores:
        core.added = False
        core.marked = False
    for core in config.cores:
        with_error = add_core_recursive(core, config)
        if with_error:
            exit_code = 2

    sys.exit(exit_code)


def settle(timeout, interval):
    """Wait for device startup and report the devices that did not come up.

    Args:
        timeout: forwarded to ``opencas.wait_for_startup``
            (seconds, per the command-line help).
        interval: forwarded to ``opencas.wait_for_startup``
            (seconds, per the command-line help).

    This function always terminates the process. Exit status:
        0 - ``wait_for_startup`` raised (the error is printed), nothing was
            left uninitialized, or every uninitialized device is lazy,
        1 - at least one uninitialized device has ``is_lazy()`` false.
    """
    try:
        not_initialized = opencas.wait_for_startup(timeout, interval)
    except Exception as e:
        eprint(e)
        # Don't fail the boot if we're missing the config
        # sys.exit instead of the exit() helper installed by the site module,
        # which is meant for interactive use and is absent under "python -S".
        sys.exit(0)

    fail = False
    if not_initialized:
        for device in not_initialized:
            # Only a non-lazy device turns this into a failure.
            fail = fail or not device.is_lazy()
            eprint("Couldn't initialize device {}".format(device.device))

        eprint("Open CAS initialization failed. Couldn't set up all required devices")

    sys.exit(1 if fail else 0)


# Stop - detach cores and stop caches
def stop(flush):
    """Stop the cache configuration through ``opencas.stop``.

    Args:
        flush: forwarded unchanged to ``opencas.stop``.

    This function always terminates the process: status 1 if ``opencas.stop``
    raised (the exception is printed to stderr), status 0 otherwise.
    """
    try:
        opencas.stop(flush)
    except Exception as e:
        eprint(e)
        # sys.exit instead of the exit() helper installed by the site module,
        # which is meant for interactive use and is absent under "python -S".
        sys.exit(1)

    sys.exit(0)


# Command line arguments parsing


class cas:
    """Command-line front end of casctl.

    Instantiating the class parses ``sys.argv`` and runs the selected action
    (``init``, ``start``, ``settle`` or ``stop``). With no arguments the help
    text is printed and nothing else happens.
    """

    def __init__(self):
        cli_args = sys.argv[1:]
        parser = self._build_parser()

        # Nothing on the command line: show usage instead of failing.
        if not cli_args:
            parser.print_help()
            return

        parsed = parser.parse_args(cli_args)
        # Every sub-parser stores its own name under "command"; the matching
        # command_<name> method carries out the action.
        handler = getattr(self, "command_" + parsed.command)
        handler(parsed)

    @staticmethod
    def _build_parser():
        """Create the casctl argument parser with its four sub-commands."""
        parser = argparse.ArgumentParser(prog="casctl")
        actions = parser.add_subparsers(title="actions")

        init_cmd = actions.add_parser("init", help="Setup initial configuration")
        init_cmd.add_argument(
            "--force", action="store_true", help="Force cache start"
        )
        init_cmd.set_defaults(command="init")

        start_cmd = actions.add_parser("start", help="Start cache configuration")
        start_cmd.set_defaults(command="start")

        settle_cmd = actions.add_parser(
            "settle", help="Wait for startup of devices"
        )
        settle_cmd.set_defaults(command="settle")
        for flag, description, fallback in (
            ("--timeout", "How long should command wait [s]", 270),
            ("--interval", "Polling interval [s]", 5),
        ):
            settle_cmd.add_argument(
                flag,
                action="store",
                help=description,
                default=fallback,
                type=int,
            )

        stop_cmd = actions.add_parser("stop", help="Stop cache configuration")
        stop_cmd.add_argument(
            "--flush", action="store_true", help="Flush data before stopping"
        )
        stop_cmd.set_defaults(command="stop")

        return parser

    def command_init(self, args):
        """Run the ``init`` action with the parsed ``--force`` flag."""
        init(args.force)

    def command_start(self, args):
        """Run the ``start`` action; it takes no options."""
        start()

    def command_settle(self, args):
        """Run the ``settle`` action with the parsed timeout and interval."""
        settle(args.timeout, args.interval)

    def command_stop(self, args):
        """Run the ``stop`` action with the parsed ``--flush`` flag."""
        stop(args.flush)


def main():
    """Script entry point: wait for the CAS control interface, then dispatch."""
    # NOTE(review): presumably blocks until the CAS control device is
    # available - confirm against opencas.wait_for_cas_ctrl.
    opencas.wait_for_cas_ctrl()
    # Constructing the class parses sys.argv and runs the chosen action.
    cas()


if __name__ == "__main__":
    main()
