#!/usr/bin/python
# -*- encoding: utf-8; py-indent-offset: 4 -*-
# +------------------------------------------------------------------+
# |             ____ _               _        __  __ _  __           |
# |            / ___| |__   ___  ___| | __   |  \/  | |/ /           |
# |           | |   | '_ \ / _ \/ __| |/ /   | |\/| | ' /            |
# |           | |___| | | |  __/ (__|   <    | |  | | . \            |
# |            \____|_| |_|\___|\___|_|\_\___|_|  |_|_|\_\           |
# |                                                                  |
# | Copyright Mathias Kettner 2014             mk@mathias-kettner.de |
# +------------------------------------------------------------------+
#
# This file is part of Check_MK.
# The official homepage is at http://mathias-kettner.de/check_mk.
#
# check_mk is free software;  you can redistribute it and/or modify it
# under the  terms of the  GNU General Public License  as published by
# the Free Software Foundation in version 2.  check_mk is  distributed
# in the hope that it will be useful, but WITHOUT ANY WARRANTY;  with-
# out even the implied warranty of  MERCHANTABILITY  or  FITNESS FOR A
# PARTICULAR PURPOSE. See the  GNU General Public License for more de-
# tails. You should have  received  a copy of the  GNU  General Public
# License along with GNU Make; see the file  COPYING.  If  not,  write
# to the Free Software Foundation, Inc., 51 Franklin St,  Fifth Floor,
# Boston, MA 02110-1301 USA.
"""
Special agent for monitoring azure cloud applications with Check_MK.
"""
#
# Docs about authentication:
#  https://docs.microsoft.com/en-us/python/azure/python-sdk-azure-authenticate?view=azure-python
# About monitoring:
#  https://github.com/Azure/azure-sdk-for-python/blob/master/doc/sample_azure-monitor.rst
# About the python binding
#   https://azure.microsoft.com/en-us/resources/samples/sql-database-python-manage/
#

import json
import datetime
import calendar
import sys
import re
import argparse
import logging

from multiprocessing import Process, Lock, Queue
from Queue import Empty as QueueEmpty

# We have to set a null handler for logging before importing the azure stuff.
#   Otherwise a warning will be sent to stderr - and if for some other reason
#   the agent returns a non-zero exit code this (irrelevant) warning would be
#   all the user sees.
logging.getLogger().addHandler(logging.NullHandler())
# pylint: disable=wrong-import-position
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.monitor import MonitorManagementClient
from azure.mgmt.monitor.models.error_response import ErrorResponseException
from azure.mgmt.compute import ComputeManagementClient
from azure.common.credentials import ServicePrincipalCredentials
# remove the NullHandler again, so logging.basicConfig in parse_arguments
# can set up handlers as usual
logging.getLogger().handlers.pop()

# replace password store references in the commandline arguments
# by the actual secrets before argument parsing
import cmk.password_store
cmk.password_store.replace_passwords()

# module level logger; configured in parse_arguments
LOG = logging.getLogger(__name__)

# Metrics to fetch, keyed by azure resource type.
# Each value is a list of
#   (metric name(s) (comma separated), interval ID, aggregation, filter)
# tuples. To add a new metric, just add a made up name, run the
# agent, and you'll get an error listing the available metrics!
# NOTE: the API rejects requests with more than 20 metric names
# (enforced in AzureClient._fetch_specific_metrics).
METRICS_SELECTED = {
    'Microsoft.Network/virtualNetworkGateways': [
        ('AverageBandwidth,P2SBandwidth', 'PT1M', 'average', None),
        ('P2SConnectionCount', 'PT1M', 'maximum', None),
    ],
    'Microsoft.Sql/servers/databases': [
        ('storage_percent,deadlock,cpu_percent,dtu_consumption_percent,'
         'connection_successful,connection_failed', 'PT1M', 'average', None),
    ],
    'Microsoft.Storage/storageAccounts': [('UsedCapacity,Ingress,Egress,Transactions,'
                                           'SuccessServerLatency,SuccessE2ELatency,Availability',
                                           'PT1H', 'total', None),],
    'Microsoft.Web/sites': [('CpuTime,AverageResponseTime,Http5xx', 'PT1M', 'total', None),],
}


class AsyncMapper(object):  # pylint: disable=too-few-public-methods
    '''Create an async drop-in replacement for builtin 'map'

    which does not require the involved values to be pickle-able,
    nor third party modules such as 'multiprocess' or 'dill'.

    Usage:
             map_ = AsyncMapper()

             for results in map_(function, arguments_iter):
                 do_stuff()

    Note that the order of the results does not correspond
    to that of the arguments.

    Keywords for initialization:

      * timeout:  number of seconds we will wait for the next result
                  before terminating all remaining jobs (default: None)
      * debug:    raise exceptions in jobs (default: False)
      * fallback: specify a function, called in case an exception occurs in
                  the mapped function. The fallback function should return
                  a tuple (err, value). If err is falsey, value will be
                  yielded (default: (1, None)).
    '''

    def __init__(self, timeout=None, debug=False, fallback=lambda x: (1, None)):
        super(AsyncMapper, self).__init__()
        self.timeout = timeout
        self.debug = debug
        self.fallback = fallback

    def __call__(self, function, args_iter):
        # NOTE(review): the worker target 'produce' is a closure, which is
        # not picklable - this presumably relies on the unix 'fork' start
        # method of multiprocessing. Confirm before running elsewhere.
        queue = Queue()
        jobs = {}

        def produce(id_, args):
            # Runs in the child process: report (id, err, value) via the
            # queue. On exception the fallback supplies (err, value).
            try:
                queue.put((id_, 0, function(args)))
            except Exception as _e:  # pylint: disable=broad-except
                queue.put((id_,) + self.fallback(args))
                if self.debug:
                    raise

        # start: one process per argument
        for id_, args in enumerate(args_iter):
            jobs[id_] = Process(target=produce, args=(id_, args))
            jobs[id_].start()

        # consume: wait at most self.timeout seconds for each result;
        # a timeout aborts the whole consumption loop
        while jobs:
            try:
                id_, err, result = queue.get(block=True, timeout=self.timeout)
            except QueueEmpty:
                break
            if not err:
                yield result
            jobs.pop(id_)

        # kill every job that has not delivered a result in time
        for job in jobs.values():
            job.terminate()


def parse_arguments(argv):
    """Parse the commandline arguments and configure logging.

    As a side effect, the root logger is configured according to the
    number of '-v' flags given. The secret is masked in the debug output.

    Returns the argparse.Namespace of parsed arguments.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--debug", action="store_true", help='''Debug mode: raise Python exceptions''')
    parser.add_argument(
        "-v",
        "--verbose",
        action="count",
        default=0,
        help='''Verbose mode (for even more output use -vvv)''')
    parser.add_argument(
        "--sequential", action="store_true", help='''Sequential mode: do not use multiprocessing''')
    parser.add_argument(
        "--dump-config", action="store_true", help='''Dump parsed configuration and exit''')
    parser.add_argument(
        "--timeout",
        default=10,
        type=int,
        help='''Timeout for individual processes in seconds (default 10)''')
    parser.add_argument(
        "--piggyback_vms",
        default="grouphost",
        choices=["grouphost", "self"],
        # fixed typo: "iteself" -> "itself"
        help='''Send VM piggyback data to group host (default) or the VM itself''')

    # REQUIRED
    parser.add_argument("--subscription", required=True, help="Azure subscription ID")
    parser.add_argument("--client", required=True, help="Azure client ID")
    parser.add_argument("--tenant", required=True, help="Azure tenant ID")
    parser.add_argument("--secret", required=True, help="Azure authentication secret")
    # CONSTRAIN DATA TO REQUEST
    parser.add_argument(
        "--explicit-config",
        default=[],
        nargs='*',
        help='''list of arguments providing the configuration in <key>=<value> format.
             If omitted, all groups and resources are fetched.
             If specified, every 'group=<name>' argument starts a new group configuration,
             and every 'resource=<name>' arguments specifies a resource.''')
    args = parser.parse_args(argv)

    # LOGGING
    if args.verbose >= 3:
        # this will show third party log messages as well
        fmt = "%(levelname)s: %(name)s: %(filename)s: %(lineno)s: %(message)s"
        lvl = logging.DEBUG
    elif args.verbose == 2:
        # be verbose, but silence msrest, urllib3 and requests_oauthlib
        fmt = "%(levelname)s: %(funcName)s: %(lineno)s: %(message)s"
        lvl = logging.DEBUG
        logging.getLogger('msrest').setLevel(logging.WARNING)
        logging.getLogger('urllib3').setLevel(logging.WARNING)
        logging.getLogger('requests_oauthlib').setLevel(logging.WARNING)
    elif args.verbose == 1:
        fmt = "%(levelname)s: %(funcName)s: %(message)s"
        lvl = logging.INFO
    else:
        fmt = "%(levelname)s: %(message)s"
        lvl = logging.WARNING
    logging.basicConfig(level=lvl, format=fmt)

    # V-VERBOSE INFO: dump all parsed arguments, never leaking the secret
    for key, value in vars(args).items():
        if key == "secret":
            value = '****'
        LOG.debug('argparse: %s = %r', key, value)

    return args


# The following *Config objects provide a Configuration instance as described in
# CMK-513 (comment-12620).
# For now the passed commandline arguments are used to create it.


class GroupConfig(object):
    """Configuration of a single monitored resource group.

    An empty resource list means that every resource of the group
    should be fetched.
    """

    def __init__(self, name):
        super(GroupConfig, self).__init__()
        if not name:
            raise ValueError("falsey group name: %r" % name)
        self.name = name
        self.resources = []

    @property
    def fetchall(self):
        """True if no explicit resources are configured for this group."""
        return len(self.resources) == 0

    def add_key(self, key, value):
        """Process one <key>=<value> configuration item for this group."""
        if key != "resources":
            raise ValueError("unknown config key: %s" % key)
        self.resources = value.split(",")

    def __str__(self):
        header = "[%s]\n" % self.name
        if self.fetchall:
            return header + "  <fetchall>"
        return header + "\n".join("resource: %s" % res for res in self.resources)


class ExplicitConfig(object):
    """Parsed form of the --explicit-config commandline arguments.

    Every 'group=<name>' item starts a new group configuration; all
    following items are applied to that group. An empty configuration
    means that everything should be fetched.
    """

    def __init__(self, raw_list=()):
        super(ExplicitConfig, self).__init__()
        self.groups = {}
        # the group the next non-"group" key will be applied to
        self.current_group = None
        for item in raw_list:
            if '=' not in item:
                raise ValueError("must be in <key>=<value> format: %r" % item)
            key, value = item.split('=', 1)
            self.add_key(key, value)

    @property
    def fetchall(self):
        """True if no group was configured at all."""
        return len(self.groups) == 0

    def add_key(self, key, value):
        """Process one <key>=<value> configuration item."""
        if key == "group":
            self.current_group = self.groups.setdefault(value, GroupConfig(value))
        elif self.current_group is None:
            raise RuntimeError("missing arg: group=<name>")
        else:
            self.current_group.add_key(key, value)

    def is_configured(self, resource):
        """Return whether the given AzureResource should be monitored."""
        if self.fetchall:
            return True
        group_config = self.groups.get(resource.info['group'])
        if group_config is None:
            return False
        return group_config.fetchall or resource.info['name'] in group_config.resources

    def __str__(self):
        if self.fetchall:
            return "[<fetchall>]"
        return "\n".join(str(group) for group in self.groups.values())


class Section(object):
    """Buffered output section in Check_MK agent format.

    Lines are collected via add() and written to stdout atomically
    (guarded by a multiprocessing lock) via write(). Sections with a
    piggytarget are wrapped in '<<<<target>>>>' markers.
    """
    SEP = 124  # ord('|'), the column separator announced in the header
    LOCK = Lock()  # serializes stdout access across worker processes

    def __init__(self, name, piggytarget=''):
        super(Section, self).__init__()
        self._heading = []
        self._lines = []
        self._closing = []
        if piggytarget != '':
            self._heading.append('<<<<%s>>>>\n' % piggytarget)
            self._closing.append('<<<<>>>>\n')
        self._heading.append('<<<azure_%s:sep(%s)>>>\n' % (name, Section.SEP))

    @staticmethod
    def formatline(tokens):
        """Join the tokens with the separator into one output line."""
        separator = chr(Section.SEP)
        return separator.join(str(token) for token in tokens) + '\n'

    def add(self, info):
        """Add one line (sequence of tokens) or a list of such lines."""
        if not info:
            return
        rows = info if isinstance(info[0], (list, tuple)) else [info]
        for row in rows:
            self._lines.append(self.formatline(row))

    def write(self):
        """Write the section to stdout; empty sections are suppressed."""
        if not self._lines:
            return
        with self.LOCK:
            sys.stdout.writelines(self._heading + self._lines + self._closing)
            sys.stdout.flush()


class IssueCollecter(object):
    """Collect issues and render them as rows for the agent output."""

    def __init__(self):
        super(IssueCollecter, self).__init__()
        self._issues = []

    def add(self, issue_type, issued_by, issue_msg):
        """Record one issue of the given type, origin and message."""
        payload = json.dumps({'type': issue_type, 'issued_by': issued_by, 'msg': issue_msg})
        self._issues.append(('issue', payload))

    def dumpinfo(self):
        """Return the collected issues as a list of output rows."""
        return self._issues

    def __len__(self):
        return len(self._issues)


class AzureMetricParseError(ValueError):
    """Raised when no usable value can be extracted from a raw metric.

    Instantiated with two args: a severity keyword ('info' or 'warning')
    and a human readable message.
    """
    pass


class AzureMetric(object):  # pylint: disable=too-many-instance-attributes
    """The most recent measured value of one azure metric."""

    # order of the fields in the agent output (see the 'tuple' property)
    HEADER = ("name", "aggregation", "value", "unit", "timestamp", "filter", "interval_id",
              "interval")

    def __init__(self, metric, aggregation, interval_id, filter_):
        super(AzureMetric, self).__init__()

        self.name = metric.name.value
        self.aggregation = aggregation
        self.value = None
        self.unit = metric.unit.name
        self.timestamp = None
        self.filter = filter_
        self.interval_id = interval_id
        self.interval = None

        if not metric.timeseries:
            raise AzureMetricParseError('info', "no timeseries found for metric %r" % self.name)

        # scan the timeseries back to front for the newest usable datapoint
        for series in reversed(metric.timeseries):
            if not series.data:
                continue

            try:
                # actual sampling interval, taken from the two newest points
                self.interval = (series.data[-1].time_stamp - series.data[-2].time_stamp)
            except (IndexError, TypeError):
                pass

            for point in reversed(series.data):
                LOG.debug("data: %s", point)
                self.value = getattr(point, aggregation, None)
                if self.value is not None:
                    self.timestamp = calendar.timegm(point.time_stamp.utctimetuple())
                    return

        raise AzureMetricParseError('warning', "not found: %s (%s)" % (self.name, aggregation))

    @property
    def tuple(self):
        """The metric as a tuple of fields, in HEADER order."""
        return tuple(getattr(self, field) for field in AzureMetric.HEADER)


class AzureResource(object):
    """A resource as returned by the resource management API,

    enriched with the subscription/group/provider parsed from its id,
    the agent section it belongs to, its piggyback targets and the
    metrics collected for it.
    """

    ID_PATTERN = re.compile("/subscriptions/([^/]*)/resourceGroups/([^/]*)/providers/([^/]*)")

    def __init__(self, resource):
        super(AzureResource, self).__init__()
        self.info = resource.as_dict()
        # enrich the info dict with the components encoded in the id
        id_match = AzureResource.ID_PATTERN.match(resource.id)
        for index, key in enumerate(("subscription", "group", "provider"), 1):
            self.info[key] = id_match.group(index)

        # section name is the last component of the type, lowercased
        self.section = resource.type.split('/')[-1].lower()
        self.piggytargets = [self.info["group"]]
        self.metrics = []

    def dumpinfo(self):
        """Return the resource (and its metrics) as agent output rows."""
        lines = [("Resource",), (json.dumps(self.info),)]
        if self.metrics:
            lines.append(("metrics following", len(self.metrics)))
            lines.append(AzureMetric.HEADER)
            lines.extend(metric.tuple for metric in self.metrics)
        return lines


class AzureComputeClient(object):
    """Type specific client for virtual machines (compute API)."""

    def __init__(self, creds, subs):
        super(AzureComputeClient, self).__init__()
        self._client = ComputeManagementClient(creds, subs)

    def process(self, vmach, args):
        """Attach the instance view to the VM and route its piggyback data."""
        # fetch the instance view, keeping only the keys we care about
        view = self._client.virtual_machines.get(
            vmach.info["group"], vmach.info["name"], expand='instanceView').instance_view.as_dict()
        specific = {}
        for key in ('statuses',):
            value = view.get(key)
            if value is not None:
                specific[key] = value
        vmach.info["specific_info"] = specific

        # honor the --piggyback_vms option: group host (default) or the VM
        if args.piggyback_vms != "grouphost":
            vmach.piggytargets.remove(vmach.info["group"])
        if args.piggyback_vms == "self":
            vmach.piggytargets.append(vmach.info["name"])


class AzureClient(object):  # pylint: disable=too-many-instance-attributes
    """Facade for the azure API clients used by this agent.

    Bundles the monitor-, resource- and (lazily created) resource type
    specific management clients for one subscription, and keeps the
    rate limit and resource caches.
    """

    # resource types that need an additional, type specific client
    SPECIFIC_CLIENTS = {
        'Microsoft.Compute/virtualMachines': AzureComputeClient,
    }

    def __init__(self, args):
        super(AzureClient, self).__init__()
        self.args = args
        # updated from the response headers of every metrics request
        self.remaining_reads = "unknown (no metrics fetched)"

        self._creds = ServicePrincipalCredentials(
            client_id=args.client, secret=args.secret, tenant=args.tenant)
        # azure-api-call
        self._monitor_client = MonitorManagementClient(self._creds, self.args.subscription)
        # azure-api-call
        self._resource_client = ResourceManagementClient(self._creds, self.args.subscription)

        # type specific clients, filled by init_specific
        self.specific_clients = {}

        self.timespans = self._get_timespans()
        self.resources_cache = None
        self.metrics_cache = {}

    @staticmethod
    def _get_timespans():
        """compute timespans dict

        Maps every supported interval ID to a "<start>/<end>" timespan
        string ending now (UTC), each covering three sampling intervals.
        """
        t_end = datetime.datetime.utcnow()

        def interval(**kwargs):
            start = t_end - datetime.timedelta(**kwargs)
            return "%s/%s" % (start.strftime("%Y-%m-%dT%H:%M:%SZ"),
                              t_end.strftime("%Y-%m-%dT%H:%M:%SZ"))

        timespans = {}
        timespans["PT1M"] = interval(minutes=3)
        timespans["PT5M"] = interval(minutes=15)
        timespans["PT1H"] = interval(hours=3)
        return timespans

    def init_specific(self, resources):
        """Create the type specific clients needed for the given resources."""
        r_types = set(r.info["type"] for r in resources)
        for rtp in r_types:
            client_class = AzureClient.SPECIFIC_CLIENTS.get(rtp)
            if client_class is not None:
                self.specific_clients[rtp] = client_class(self._creds, self.args.subscription)

    def process_specific(self, resource):
        """Apply type specific processing to a resource (no-op for most)."""
        client = self.specific_clients.get(resource.info["type"])
        if client is None:
            return
        client.process(resource, self.args)

    def discover_resources(self):
        """Return all resources of the subscription as AzureResources (cached)."""
        if self.resources_cache is None:
            # azure-api-call
            raw_resources = self._resource_client.resources.list()
            # materialize explicitly: a lazy map object (python3 semantics)
            # would be exhausted after the first use of the cache
            self.resources_cache = list(map(AzureResource, raw_resources))
        return self.resources_cache

    def _metric_api_call(self, rid, timespan, interval, metric, aggregation, filter_):
        """Issue one metrics.list API call, return (response, raw metrics)."""
        LOG.debug(
            "metrics.list(%r, timespan=%r, interval=%r, metric=%r, aggregation=%r,"
            " filter=%r, raw=True)", rid, timespan, interval, metric, aggregation, filter_)
        raw = self._monitor_client.metrics.list(
            rid,
            timespan=timespan,
            interval=interval,
            metric=metric,
            aggregation=aggregation,
            filter=filter_,
            raw=True,
        )
        # raw.output is what we'd gotten had we set raw=False.
        # It is a paged object, make sure to actually retrieve
        # all pages (as this may raise exceptions)
        raw_metrics = list(raw.output.value)
        return raw.response, raw_metrics

    def _fetch_specific_metrics(self, resource, metricnames, interval, aggregation, filter_, err):
        """Fetch and parse the given metrics for one resource.

        API errors and unparsable metrics are recorded in the issue
        collector 'err' (and re-raised in debug mode). Returns a list
        of AzureMetric objects.
        """
        if metricnames.count(',') >= 20:
            raise ValueError("Azure API won't have requests with more than 20 metrics!")

        rid = resource.info["id"]
        timespan = self.timespans[interval]

        try:
            response, raw_metrics = self._metric_api_call(rid, timespan, interval, metricnames,
                                                          aggregation, filter_)
        except ErrorResponseException as exc:
            if self.args.debug:
                raise  # bare raise preserves the original traceback
            err.add("exception", rid, exc.message)
            LOG.exception(exc)
            return []

        LOG.debug("response: %s", response)
        self.remaining_reads = response.headers['x-ms-ratelimit-remaining-subscription-reads']

        metrics = []
        for raw_metric in raw_metrics:
            try:
                metrics.append(AzureMetric(raw_metric, aggregation, interval, filter_))
            except AzureMetricParseError as exc:
                # use exc.args instead of the python2-only exc[...] indexing,
                # consistent with the logging call below
                err.add(exc.args[0], rid, exc.args[1])
                LOG.warning(exc.args[1])

        return metrics

    def get_metrics(self, resource, err):
        """Fetch all metrics selected in METRICS_SELECTED for one resource."""
        metric_params = METRICS_SELECTED.get(resource.info["type"], [])

        metrics = []
        for metricnames, interval, aggregation, filter_ in metric_params:
            metrics += self._fetch_specific_metrics(resource, metricnames, interval, aggregation,
                                                    filter_, err)

        return metrics


def process_resource(args):
    """Fetch all data for one resource and return the resulting Sections.

    'args' is a (resource, client) tuple (packed for use with map).
    The returned list holds one agent_info section plus one resource
    section per piggyback target.
    """
    resource, client = args

    client.process_specific(resource)

    issues = IssueCollecter()  # passed along to collect issues on the way
    resource.metrics.extend(client.get_metrics(resource, issues))

    agent_info_section = Section('agent_info')
    agent_info_section.add(('remaining-reads', client.remaining_reads))
    agent_info_section.add(issues.dumpinfo())

    sections = [agent_info_section]
    for target in resource.piggytargets:
        resource_section = Section(resource.section, target)
        resource_section.add(resource.dumpinfo())
        sections.append(resource_section)

    return sections


def write_groups(resources):
    """Write the sorted list of monitored groups to the agent_info section."""
    groups = sorted({res.info['group'] for res in resources})
    section = Section('agent_info')
    section.add(('monitored-groups', json.dumps(groups)))
    section.write()
    # an empty agent_info section per group makes sure the service is
    # discovered even if nothing else is reported for that group
    for group in groups:
        Section('agent_info', group).write()


def write_exception_to_agent_info_section(exception):
    """Report a fatal exception via the agent_info section.

    The message is encoded as (state, text) so the server side check
    can present it to the user with a CRIT state.
    """
    # those exceptions are quite noisy. try to make them more concise:
    msg = str(exception).split('Trace ID')[0]
    msg = msg.split(':', 2)[-1].strip(' ,')

    if "does not have authorization to perform action" in msg:
        # fixed: separate the hint from the message (there was no space
        # in between) and spell 'assigned' correctly
        msg += " HINT: Make sure you have a proper role assigned to your client!"

    value = json.dumps((2, msg))
    section = Section('agent_info')
    section.add(('agent-bailout', value))
    section.write()


def main(argv=None):
    """Entry point: query the azure API and write agent output to stdout.

    Always returns 0; fatal errors are reported via the agent_info
    section unless --debug is given (then exceptions escape).
    """
    args = parse_arguments(argv if argv else sys.argv[1:])
    config = ExplicitConfig(raw_list=args.explicit_config)

    config_dump = "Configuration:\n%s\n" % config
    if args.dump_config:
        sys.stdout.write(config_dump)
        return 0
    LOG.debug(config_dump)

    # in debug mode no exception is caught here; () matches nothing
    catch = () if args.debug else Exception
    try:
        client = AzureClient(args)
        resources = [res for res in client.discover_resources() if config.is_configured(res)]
        client.init_specific(resources)

        write_groups(resources)

        job_args = ((res, client) for res in resources)
        mapper = map if args.sequential else AsyncMapper(args.timeout, args.debug)
        for sections in mapper(process_resource, job_args):
            for section in sections:
                section.write()

    except catch as exc:
        write_exception_to_agent_info_section(exc)
    return 0


if __name__ == "__main__":
    sys.exit(main())
