#!/usr/bin/python
# -*- encoding: utf-8; py-indent-offset: 4 -*-
# +------------------------------------------------------------------+
# |             ____ _               _        __  __ _  __           |
# |            / ___| |__   ___  ___| | __   |  \/  | |/ /           |
# |           | |   | '_ \ / _ \/ __| |/ /   | |\/| | ' /            |
# |           | |___| | | |  __/ (__|   <    | |  | | . \            |
# |            \____|_| |_|\___|\___|_|\_\___|_|  |_|_|\_\           |
# |                                                                  |
# | Copyright Mathias Kettner 2018             mk@mathias-kettner.de |
# +------------------------------------------------------------------+
#
# This file is part of Check_MK.
# The official homepage is at http://mathias-kettner.de/check_mk.
#
# check_mk is free software;  you can redistribute it and/or modify it
# under the  terms of the  GNU General Public License  as published by
# the Free Software Foundation in version 2.  check_mk is  distributed
# in the hope that it will be useful, but WITHOUT ANY WARRANTY;  with-
# out even the implied warranty of  MERCHANTABILITY  or  FITNESS FOR A
# PARTICULAR PURPOSE. See the  GNU General Public License for more de-
# tails. You should have  received  a copy of the  GNU  General Public
# License along with GNU Make; see the file  COPYING.  If  not,  write
# to the Free Software Foundation, Inc., 51 Franklin St,  Fifth Floor,
# Boston, MA 02110-1301 USA.
"""
Special agent for monitoring Amazon web services (AWS) with Check_MK.
"""

import abc
import argparse
import datetime
import json
import logging
import sys
import time
import errno
from typing import NamedTuple, Any, List, AnyStr
from pathlib2 import Path
import boto3  # type: ignore
import botocore  # type: ignore
from cmk.paths import tmp_dir
import cmk.store as store
import cmk.password_store

#   .--helpers-------------------------------------------------------------.
#   |                  _          _                                        |
#   |                 | |__   ___| |_ __   ___ _ __ ___                    |
#   |                 | '_ \ / _ \ | '_ \ / _ \ '__/ __|                   |
#   |                 | | | |  __/ | |_) |  __/ |  \__ \                   |
#   |                 |_| |_|\___|_| .__/ \___|_|  |___/                   |
#   |                              |_|                                     |
#   '----------------------------------------------------------------------'


def _datetime_converter(o):
    if isinstance(o, datetime.datetime):
        return o.__str__()


def _chunks(list_, length=100):
    return [list_[i:i + length] for i in xrange(0, len(list_), length)]


#.
#   .--section API---------------------------------------------------------.
#   |                       _   _                  _    ____ ___           |
#   |         ___  ___  ___| |_(_) ___  _ __      / \  |  _ \_ _|          |
#   |        / __|/ _ \/ __| __| |/ _ \| '_ \    / _ \ | |_) | |           |
#   |        \__ \  __/ (__| |_| | (_) | | | |  / ___ \|  __/| |           |
#   |        |___/\___|\___|\__|_|\___/|_| |_| /_/   \_\_|  |___|          |
#   |                                                                      |
#   '----------------------------------------------------------------------'

#   ---result distributor---------------------------------------------------


class ResultDistributor(object):
    """
    Mediator between sections: results of one section are handed over to
    the other registered sections so that data which was already fetched
    does not have to be queried from the AWS account again.
    """

    def __init__(self):
        # Registered receivers in the order they were added
        self._colleagues = []

    def add(self, colleague):
        """Register 'colleague' as a receiver of future distributions."""
        self._colleagues.append(colleague)

    def distribute(self, sender, result):
        """
        Call 'receive(sender, result)' on every registered colleague whose
        name differs from the name of the sender.
        """
        # Lazy generator: names are compared right before each delivery
        receivers = (
            candidate for candidate in self._colleagues if candidate.name != sender.name)
        for receiver in receivers:
            receiver.receive(sender, result)


#   ---sections/colleagues--------------------------------------------------

# Return value of AWSSection.run(): the accepted results of one section
# together with the cache timestamp of the computed content they were
# created from.
AWSSectionResults = NamedTuple("AWSSectionResults", [
    ("results", List),
    ("cache_timestamp", float),
])

# One element of the list returned by AWSSection._create_results().
# AWSSection.run() skips results with empty content and requires the content
# of all other results to be a list.
# NOTE(review): the sections pass "" as piggyback hostname for data that is
# not assigned to a piggyback host - presumably this addresses the monitoring
# host itself; confirm against the code which writes the agent output.
AWSSectionResult = NamedTuple("AWSSectionResult", [
    ("piggyback_hostname", AnyStr),
    ("content", Any),
])

# Return value of AWSSection._get_colleague_contents(): the content received
# from other sections. Sections without colleagues return (None, 0.0).
AWSColleagueContents = NamedTuple("AWSColleagueContents", [
    ("content", Any),
    ("cache_timestamp", float),
])

# Return value of AWSSection._get_raw_content(): freshly fetched or cached
# content plus the time stamp it belongs to.
AWSRawContent = NamedTuple("AWSRawContent", [
    ("content", Any),
    ("cache_timestamp", float),
])

# Return value of AWSSection._compute_content(). This is also the object
# which AWSSection.run() distributes to the colleague sections.
AWSComputedContent = NamedTuple("AWSComputedContent", [
    ("content", Any),
    ("cache_timestamp", float),
])

# Base directory of all cache files. AWSSection appends
# <region>/<hostname>/<section name> to it.
AWSCacheFilePath = Path(tmp_dir) / "agents" / "agent_aws"


class AWSSection(object):
    """
    Abstract base class of all sections.

    A section is processed by 'run' in these steps:
    1. collect the contents received from colleague sections,
    2. get the raw content, either from the cache file or from AWS,
    3. compute the final content and distribute it to the colleagues,
    4. create and validate the results.

    Subclasses have to provide 'name', 'interval' and the abstract methods
    at the end of this class.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self, client, region, config, distributor=None):
        self._client = client
        self._region = region
        self._config = config
        self._distributor = ResultDistributor() if distributor is None else distributor
        # {<name of sending section>: <content received from it>}
        self._received_results = {}
        self._cache_file_dir = AWSCacheFilePath / self._region / self._config.hostname
        self._cache_file = AWSCacheFilePath / self._region / self._config.hostname / self.name

    @abc.abstractproperty
    def name(self):
        pass

    @abc.abstractproperty
    def interval(self):
        """
        In general the default resolution of AWS metrics is 5 min (300 sec)
        The default resolution of AWS S3 metrics is 1 day (86400 sec)
        We use interval property for cached section.
        """
        pass

    @property
    def period(self):
        return 2 * self.interval

    def _send(self, content):
        self._distributor.distribute(self, content)

    def receive(self, sender, content):
        # The first content received from a sender is kept, later ones are
        # ignored.
        self._received_results.setdefault(sender.name, content)

    def run(self, use_cache=False):
        """
        Process this section and return an 'AWSSectionResults' which consists
        of all non-empty results and the cache timestamp of the computed
        content. The intermediate values delivered by the subclass are
        checked by assertions.
        """
        colleague_contents = self._get_colleague_contents()
        assert isinstance(
            colleague_contents, AWSColleagueContents
        ), "%s: Colleague contents must be of type 'AWSColleagueContents'" % self.name
        assert isinstance(
            colleague_contents.cache_timestamp,
            float), "%s: Cache timestamp of colleague contents must be of type 'float'" % self.name

        raw_content = self._get_raw_content(colleague_contents, use_cache=use_cache)
        assert isinstance(
            raw_content,
            AWSRawContent), "%s: Raw content must be of type 'AWSRawContent'" % self.name
        assert isinstance(
            raw_content.cache_timestamp,
            float), "%s: Cache timestamp of raw content must be of type 'float'" % self.name

        computed_content = self._compute_content(raw_content, colleague_contents)
        assert isinstance(computed_content, AWSComputedContent
                         ), "%s: Computed content must be of type 'AWSComputedContent'" % self.name
        assert isinstance(
            computed_content.cache_timestamp,
            float), "%s: Cache timestamp of computed content must be of type 'float'" % self.name

        # Colleagues get the computed content before the results are created
        self._send(computed_content)
        created_results = self._create_results(computed_content)
        assert isinstance(created_results,
                          list), "%s: Created results must be of type 'list'" % self.name

        final_results = []
        for result in created_results:
            assert isinstance(
                result,
                AWSSectionResult), "%s: Result must be of type 'AWSSectionResult'" % self.name

            if not result.content:
                logging.info("%s: Result is empty or None", self.name)
                continue

            assert isinstance(
                result.piggyback_hostname, (unicode, str)
            ), "%s: Piggyback hostname of created result must be of type 'unicode' or 'str'" % self.name
            # In the related check plugin aws.include we parse these results and
            # extend list of json-loaded results.
            assert isinstance(result.content,
                              list), "%s: Result content must be of type 'list'" % self.name

            final_results.append(result)
        return AWSSectionResults(final_results, computed_content.cache_timestamp)

    def _get_raw_content(self, colleague_contents, use_cache=False):
        """
        Return the raw content of this section as 'AWSRawContent'.

        The cache is only used if 'use_cache' is set, the cache file is
        younger than the section interval and the data received from the
        colleagues is not newer than the cache file. In all other cases, and
        if the cache file turns out to be unreadable, the data is fetched
        from AWS (and written to the cache if 'use_cache' is set).
        """
        self._cache_file_dir.mkdir(parents=True, exist_ok=True)
        if use_cache and self._cache_is_recent_enough(colleague_contents):
            raw_content, cache_timestamp = self._read_from_cache()
            if raw_content is not None:
                return AWSRawContent(raw_content, cache_timestamp)
            # The cache file vanished or does not contain valid JSON. Handing
            # out None would make the subclasses fail while computing their
            # content, so fall back to live data.
            logging.info("Cache file %s is not usable, fetching live data", self._cache_file)

        raw_content = self._fetch_raw_content(colleague_contents)
        # TODO: Write cache only when _compute_section_content succeeded?
        if use_cache:
            self._write_to_cache(raw_content)
        return AWSRawContent(raw_content, time.time())

    def _cache_is_recent_enough(self, colleague_contents):
        """
        Return True if the cache file exists, is younger than the section
        interval and is not older than the data received from the colleagues.
        """
        if not self._cache_file.exists():
            logging.info("New cache file %s", self._cache_file)
            return False

        now = time.time()
        try:
            mtime = self._cache_file.stat().st_mtime
        except OSError as e:
            # The file may have been removed after the check above
            if e.errno == errno.ENOENT:  # No such file or directory
                logging.info("Cannot calculate cache file age of %s", self._cache_file)
                return False
            raise

        if now - mtime >= self.interval:
            logging.info("Cache file %s is outdated", self._cache_file)
            return False

        # Compare with the modification time of the cache file, not with the
        # current time: colleague data fetched after the cache file was
        # written invalidates the cache.
        if colleague_contents.cache_timestamp > mtime:
            logging.info("Colleague data is newer than cache file %s", self._cache_file)
            return False
        return True

    def _read_from_cache(self):
        """
        Return the tuple (content, modification time of the cache file).

        A missing cache file results in (None, 0.0). If the file does not
        contain valid JSON the content is None.
        """
        try:
            with self._cache_file.open(encoding="utf-8") as f:
                raw_content = f.read().strip()
        except IOError as e:
            if e.errno == errno.ENOENT:  # No such file or directory
                return None, 0.0
            else:
                raise
        try:
            content = json.loads(raw_content)
        except ValueError as e:
            logging.info(e)
            content = None
        return content, self._cache_file.stat().st_mtime

    def _write_to_cache(self, raw_content):
        # datetime objects are written as strings, see _datetime_converter
        json_dump = json.dumps(raw_content, default=_datetime_converter)
        store.save_file(str(self._cache_file), json_dump)

    @abc.abstractmethod
    def _get_colleague_contents(self):
        # type: () -> AWSColleagueContents
        """
        Receive section contents from colleagues. The results are stored in
        self._received_results: {<KEY>: AWSComputedContent}.
        The relation between two sections must be declared in the related
        distributor in advance to make this work.
        Use max. cache_timestamp of all received results for
        AWSColleagueContents.cache_timestamp
        """
        pass

    @abc.abstractmethod
    def _fetch_raw_content(self, colleague_contents):
        """
        Call API methods, eg. 'response = ec2_client.describe_instances()' and
        extract content from raw content.  Raw contents basically consist of
        two sub results:
        - 'ResponseMetadata'
        - '<KEY>'
        Return raw_result['<KEY>'].
        """
        pass

    @abc.abstractmethod
    def _compute_content(self, raw_content, colleague_contents):
        # type: (AWSRawContent, Any) -> AWSComputedContent
        """
        Compute the final content of this section based on the raw content of
        this section and the content received from the optional colleague
        sections.
        """
        pass

    @abc.abstractmethod
    def _create_results(self, computed_content):
        # type: (Any) -> List[AWSSectionResult]
        """
        Create the list of 'AWSSectionResult's of this section from the
        computed content.
        """
        pass


class AWSSectionGeneric(AWSSection):
    """
    Base class of the sections which implement '_fetch_raw_content' on their
    own (in contrast to the Cloudwatch sections). It adds no behaviour to
    AWSSection.
    """
    __metaclass__ = abc.ABCMeta


class AWSSectionCloudwatch(AWSSection):
    """
    Base class of the sections which get their data via 'get_metric_data'.
    Subclasses only declare the metric data queries in '_get_metrics'.
    """
    __metaclass__ = abc.ABCMeta

    def _fetch_raw_content(self, colleague_contents):
        """
        Request the metrics declared by '_get_metrics' for the time range
        [now - period, now] and return the collected 'MetricDataResults'.
        """
        end_time = time.time()
        start_time = end_time - self.period
        metrics = self._get_metrics(colleague_contents)
        if not metrics:
            return []

        # A single GetMetricData call can include up to 100 MetricDataQuery structures
        # There's no pagination for this operation:
        # self._client.can_paginate('get_metric_data') = False
        collected = []
        for queries in _chunks(metrics):
            if not queries:
                continue
            response = self._client.get_metric_data(
                MetricDataQueries=queries,
                StartTime=start_time,
                EndTime=end_time,
            )
            try:
                results = response['MetricDataResults']
            except KeyError as e:
                logging.info("%s: KeyError %s; Available are %s", self.name, e, response.keys())
            else:
                collected.extend(results)
        return collected

    @abc.abstractmethod
    def _get_metrics(self, colleague_contents):
        """
        Return the list of metric data queries of this section.
        """
        pass

    def _create_id_for_metric_data_query(self, index, metric_name, *args):
        """
        ID field must be unique in a single call.
        The valid characters are letters, numbers, and underscore.
        The first character must be a lowercase letter.
        Regex: ^[a-z][a-zA-Z0-9_]*$

        The ID is built as id_<index>[_<arg>...]_<metric_name>.
        """
        parts = ["id", str(index)]
        parts.extend(args)
        parts.append(metric_name)
        return "_".join(parts)


#.
#   .--costs/usage---------------------------------------------------------.
#   |                      _          __                                   |
#   |         ___ ___  ___| |_ ___   / /   _ ___  __ _  __ _  ___          |
#   |        / __/ _ \/ __| __/ __| / / | | / __|/ _` |/ _` |/ _ \         |
#   |       | (_| (_) \__ \ |_\__ \/ /| |_| \__ \ (_| | (_| |  __/         |
#   |        \___\___/|___/\__|___/_/  \__,_|___/\__,_|\__, |\___|         |
#   |                                                  |___/               |
#   '----------------------------------------------------------------------'

# Interval between 'Start' and 'End' must be a DateInterval. 'End' is exclusive.
# Example:
# 2017-01-01 - 2017-05-01; cost and usage data is retrieved from 2017-01-01 up
# to and including 2017-04-30 but not including 2017-05-01.
# The GetCostAndUsageRequest operation supports only DAILY and MONTHLY granularities.


class CostsAndUsage(AWSSectionGeneric):
    """
    Section with the unblended costs of the last day, grouped by linked
    account and service.
    """

    @property
    def name(self):
        return "costs_and_usage"

    @property
    def interval(self):
        return 86400

    def _get_colleague_contents(self):
        # This section does not use data of other sections
        return AWSColleagueContents(None, 0.0)

    def _fetch_raw_content(self, colleague_contents):
        """
        Return 'ResultsByTime' of the daily costs between (now - interval)
        and now; an empty list if the response does not contain this key.
        """
        day_format = "%Y-%m-%d"
        now = time.time()
        time_period = {
            'Start': time.strftime(day_format, time.gmtime(now - self.interval)),
            'End': time.strftime(day_format, time.gmtime(now)),
        }
        group_by = [
            {
                'Type': 'DIMENSION',
                'Key': 'LINKED_ACCOUNT'
            },
            {
                'Type': 'DIMENSION',
                'Key': 'SERVICE'
            },
        ]
        response = self._client.get_cost_and_usage(
            TimePeriod=time_period,
            Granularity='DAILY',
            Metrics=['UnblendedCost'],
            GroupBy=group_by,
        )
        try:
            results_by_time = response['ResultsByTime']
        except KeyError as e:
            logging.info("%s: KeyError %s; Available are %s", self.name, e, response.keys())
            return []
        return results_by_time

    def _compute_content(self, raw_content, colleague_contents):
        # The raw content is passed on unchanged
        return AWSComputedContent(
            content=raw_content.content,
            cache_timestamp=raw_content.cache_timestamp,
        )

    def _create_results(self, computed_content):
        return [AWSSectionResult(piggyback_hostname="", content=computed_content.content)]


#.
#   .--EC2-----------------------------------------------------------------.
#   |                          _____ ____ ____                             |
#   |                         | ____/ ___|___ \                            |
#   |                         |  _|| |     __) |                           |
#   |                         | |__| |___ / __/                            |
#   |                         |_____\____|_____|                           |
#   |                                                                      |
#   '----------------------------------------------------------------------'


class EC2Summary(AWSSectionGeneric):
    """
    Section with the description of the EC2 instances, restricted to the
    configured instance IDs or tags if there are any. The computed content
    maps '<private IP>-<region>-<instance ID>' to the instance.
    """

    def __init__(self, client, region, config, distributor=None):
        super(EC2Summary, self).__init__(client, region, config, distributor=distributor)
        ec2_config = self._config.service_config['ec2']
        self._names = ec2_config['names']
        self._tags = ec2_config['tags']

    @property
    def name(self):
        return "ec2_summary"

    @property
    def interval(self):
        return 300

    def _get_colleague_contents(self):
        # This section does not use data of other sections
        return AWSColleagueContents(None, 0.0)

    def _fetch_raw_content(self, colleague_contents):
        return self._describe_instances()

    def _get_reservations(self, response):
        # Little hack: This was refactored in later agent versions
        try:
            reservations = response['Reservations']
        except KeyError as e:
            logging.info("%s: KeyError %s; Available are %s", self.name, e, response.keys())
            return []
        return reservations

    def _describe_instances(self):
        """
        Return the reservations of the instances. Configured names have
        priority over configured tags; without both all instances are
        requested.
        """
        if self._names is not None:
            return self._get_reservations(
                self._client.describe_instances(InstanceIds=self._names))

        if self._tags is None:
            return self._get_reservations(self._client.describe_instances())

        reservations = []
        # EC2 FilterLimitExceeded: The maximum number of filter values
        # specified on a single call is 200
        for filters in _chunks(self._tags, length=200):
            filtered = self._client.describe_instances(Filters=filters)
            reservations.extend(self._get_reservations(filtered))
        return reservations

    def _compute_content(self, raw_content, colleague_contents):
        # PrivateIpAddress and InstanceId is available although the instance is stopped
        instances_by_name = {}
        for reservation in raw_content.content:
            for instance in reservation.get('Instances', []):
                try:
                    private_ip = instance['PrivateIpAddress']
                    instance_id = instance['InstanceId']
                except KeyError:
                    # Instances without one of these keys are left out
                    continue
                instance_name = "%s-%s-%s" % (private_ip, self._region, instance_id)
                # The first instance found for a name is kept
                if instance_name not in instances_by_name:
                    instances_by_name[instance_name] = instance
        return AWSComputedContent(instances_by_name, raw_content.cache_timestamp)

    def _create_results(self, computed_content):
        all_instances = computed_content.content.values()
        return [AWSSectionResult("", all_instances)]


class EC2SecurityGroups(AWSSectionGeneric):
    """
    Section with the security groups of the EC2 instances.

    The instances are received from the colleague section 'ec2_summary'.
    For every instance one result is created which contains the
    descriptions of the security groups attached to this instance. The name
    of the instance is used as piggyback hostname.
    """

    def __init__(self, client, region, config, distributor=None):
        super(EC2SecurityGroups, self).__init__(client, region, config, distributor=distributor)
        self._names = self._config.service_config['ec2']['names']
        self._tags = self._config.service_config['ec2']['tags']

    @property
    def name(self):
        return "ec2_security_groups"

    @property
    def interval(self):
        return 300

    def _get_colleague_contents(self):
        colleague = self._received_results.get('ec2_summary')
        if colleague and colleague.content:
            return AWSColleagueContents(colleague.content, colleague.cache_timestamp)
        return AWSColleagueContents({}, 0.0)

    def _fetch_raw_content(self, colleague_contents):
        """
        Return the security groups as dict {<GroupId>: <security group>}.
        """
        return {group['GroupId']: group for group in self._describe_security_groups()}

    def _get_security_groups(self, response):
        # Little hack: This was refactored in later agent versions
        try:
            return response['SecurityGroups']
        except KeyError as e:
            logging.info("%s: KeyError %s; Available are %s", self.name, e, response.keys())
            return []

    def _describe_security_groups(self):
        """
        Return the list of security group descriptions.

        The configured names are EC2 instance IDs. The DescribeSecurityGroups
        operation has no 'InstanceIds' parameter and botocore refuses unknown
        parameters, so with configured names all security groups are
        requested. The restriction to the groups of the requested instances
        is done in '_compute_content' based on the instances delivered by
        'ec2_summary'.
        """
        if self._names is None and self._tags is not None:
            # NOTE(review): these filters are the tags configured for the EC2
            # instances and are applied to the security groups themselves
            # here - confirm that this is intended.
            sec_groups = []
            for chunk in _chunks(self._tags, length=200):
                # EC2 FilterLimitExceeded: The maximum number of filter values
                # specified on a single call is 200
                response = self._client.describe_security_groups(Filters=chunk)
                sec_groups.extend(self._get_security_groups(response))
            return sec_groups

        response = self._client.describe_security_groups()
        return self._get_security_groups(response)

    def _compute_content(self, raw_content, colleague_contents):
        """
        Assign the fetched security groups to the instances received from
        'ec2_summary'. Groups of an instance which were not fetched are left
        out, instances without any known group do not appear at all.
        """
        content_by_piggyback_hosts = {}
        for instance_name, instance in colleague_contents.content.iteritems():
            for security_group_from_instance in instance.get('SecurityGroups', []):
                security_group = raw_content.content.get(security_group_from_instance['GroupId'])
                if security_group is None:
                    continue
                content_by_piggyback_hosts.setdefault(instance_name, []).append(security_group)
        return AWSComputedContent(content_by_piggyback_hosts, raw_content.cache_timestamp)

    def _create_results(self, computed_content):
        return [
            AWSSectionResult(piggyback_hostname, rows)
            for piggyback_hostname, rows in computed_content.content.iteritems()
        ]


class EC2(AWSSectionCloudwatch):
    """
    Cloudwatch section with the metrics of the EC2 instances received from
    the colleague section 'ec2_summary'. One result per instance is created,
    the name of the instance is used as piggyback hostname.
    """

    @property
    def name(self):
        return "ec2"

    @property
    def interval(self):
        return 300

    def _get_colleague_contents(self):
        summary = self._received_results.get('ec2_summary')
        if not (summary and summary.content):
            return AWSColleagueContents({}, 0.0)
        return AWSColleagueContents(summary.content, summary.cache_timestamp)

    def _get_metrics(self, colleague_contents):
        """
        Return one metric data query per instance and metric. The label of a
        query is the name of the instance.
        """
        # (metric name, unit)
        metric_specs = [
            ("CPUCreditUsage", "Count"),
            ("CPUCreditBalance", "Count"),
            ("CPUUtilization", "Percent"),
            ("DiskReadOps", "Count"),
            ("DiskWriteOps", "Count"),
            ("DiskReadBytes", "Bytes"),
            ("DiskWriteBytes", "Bytes"),
            ("NetworkIn", "Bytes"),
            ("NetworkOut", "Bytes"),
            ("StatusCheckFailed_Instance", "Count"),
            ("StatusCheckFailed_System", "Count"),
        ]
        queries = []
        for idx, (instance_name, instance) in enumerate(colleague_contents.content.iteritems()):
            instance_id = instance['InstanceId']
            for metric_name, unit in metric_specs:
                metric = {
                    'Namespace': 'AWS/EC2',
                    'MetricName': metric_name,
                    'Dimensions': [{
                        'Name': "InstanceId",
                        'Value': instance_id,
                    }]
                }
                queries.append({
                    'Id': self._create_id_for_metric_data_query(idx, metric_name),
                    'Label': instance_name,
                    'MetricStat': {
                        'Metric': metric,
                        'Period': self.period,
                        'Stat': 'Average',
                        'Unit': unit,
                    },
                })
        return queries

    def _compute_content(self, raw_content, colleague_contents):
        # Group the metric results by their label, ie. by instance name
        rows_by_piggyback_host = {}
        for row in raw_content.content:
            label = row['Label']
            if label not in rows_by_piggyback_host:
                rows_by_piggyback_host[label] = []
            rows_by_piggyback_host[label].append(row)
        return AWSComputedContent(rows_by_piggyback_host, raw_content.cache_timestamp)

    def _create_results(self, computed_content):
        results = []
        for piggyback_hostname, rows in computed_content.content.iteritems():
            results.append(AWSSectionResult(piggyback_hostname, rows))
        return results


#.
#   .--S3------------------------------------------------------------------.
#   |                             ____ _____                               |
#   |                            / ___|___ /                               |
#   |                            \___ \ |_ \                               |
#   |                             ___) |__) |                              |
#   |                            |____/____/                               |
#   |                                                                      |
#   '----------------------------------------------------------------------'


class S3Summary(AWSSectionGeneric):
    """
    Section which collects the S3 buckets located in the monitored region,
    restricted to the configured bucket names or tags if there are any.

    The section does not create output on its own ('_create_results' returns
    a result without content). Its computed content {<bucket name>: <bucket>}
    is distributed to the colleague sections.
    """

    # GetBucketLocation reports buckets in us-east-1 with a LocationConstraint
    # of None and may report the legacy value 'EU' for eu-west-1.
    _LOCATION_CONSTRAINT_ALIASES = {
        None: "us-east-1",
        "EU": "eu-west-1",
    }

    def __init__(self, client, region, config, distributor=None):
        super(S3Summary, self).__init__(client, region, config, distributor=distributor)
        self._names = self._config.service_config['s3']['names']
        self._tags = self._prepare_s3_tags(self._config.service_config['s3']['tags'])

    def _prepare_s3_tags(self, tags):
        """
        S3 tags have a different format:
        [{'Key': KEY, 'Value': VALUE}, ...]

        Returns None if no tags are configured.
        """
        if not tags:
            return None
        prepared_tags = []
        for tag in tags:
            tag_name = tag['Name']
            if tag_name.startswith('tag:'):
                tag_key = tag_name[4:]
            else:
                tag_key = tag_name
            prepared_tags.extend([{'Key': tag_key, 'Value': v} for v in tag['Values']])
        return prepared_tags

    @property
    def name(self):
        return "s3_summary"

    @property
    def interval(self):
        return 86400

    def _get_colleague_contents(self):
        return AWSColleagueContents(None, 0.0)

    def _region_of_location_constraint(self, location):
        """
        Translate the LocationConstraint of a bucket to a region name.
        """
        return self._LOCATION_CONSTRAINT_ALIASES.get(location, location)

    def _fetch_raw_content(self, colleague_contents):
        """
        Return the list of buckets which are located in the monitored region
        and match the configured names or tags. Every bucket is extended by
        the keys 'LocationConstraint' (region name) and 'Tagging'.
        """
        found_buckets = []
        for bucket in self._list_buckets():
            bucket_name = bucket['Name']
            try:
                response = self._client.get_bucket_location(Bucket=bucket_name)
            except botocore.exceptions.ClientError as e:
                # An error occurred (AccessDenied) when calling the GetBucketLocation operation: Access Denied
                logging.info("%s/%s: Access denied, %s", self.name, bucket_name, e)
                continue

            try:
                location = response['LocationConstraint']
            except KeyError as e:
                logging.info("%s/%s: KeyError %s; Available are %s", self.name, bucket_name, e,
                             response.keys())
                continue

            # We can request buckets globally but if a bucket is located in
            # another region we do not get any results
            location = self._region_of_location_constraint(location)
            if location != self._region:
                continue
            bucket['LocationConstraint'] = location

            #TODO
            # Why do we get the following error while calling these methods:
            #_response = self._client.get_public_access_block(Bucket=bucket_name)
            #_response = self._client.get_bucket_policy_status(Bucket=bucket_name)
            # 'S3' object has no attribute 'get_bucket_policy_status'

            tagging = []
            try:
                response = self._client.get_bucket_tagging(Bucket=bucket_name)
                tagging = response['TagSet']
            except botocore.exceptions.ClientError as e:
                # If there are no tags attached to a bucket we receive a 'ClientError'
                logging.info("%s/%s: No tags set, %s", self.name, bucket_name, e)
            except KeyError as e:
                logging.info("%s/%s: KeyError %s; Available are %s", self.name, bucket_name, e,
                             response.keys())
            if self._matches_conditions(tagging, bucket_name):
                bucket['Tagging'] = tagging
                found_buckets.append(bucket)
        return found_buckets

    def _list_buckets(self):
        """
        Return the candidate buckets: the configured names if only names are
        configured, all buckets of the account otherwise.
        """
        if self._tags is None and self._names is not None:
            return [{'Name': n} for n in self._names]
        response = self._client.list_buckets()
        try:
            return response['Buckets']
        except KeyError as e:
            logging.info("%s: KeyError %s; Available are %s", self.name, e, response.keys())
            return []

    def _matches_conditions(self, tagging, bucket_name=None):
        """
        Return True if a bucket with the tags 'tagging' has to be monitored.

        Configured names have priority over configured tags. If names are
        configured and 'bucket_name' is given, the bucket must be one of
        them: the candidates are all buckets of the account if names and
        tags are configured at the same time. Without 'bucket_name' the
        bucket is assumed to be one of the configured names.
        """
        if self._names is not None:
            return bucket_name is None or bucket_name in self._names
        if self._tags is None:
            return True
        # One matching tag is sufficient
        return any(tag in self._tags for tag in tagging)

    def _compute_content(self, raw_content, colleague_contents):
        return AWSComputedContent({bucket['Name']: bucket for bucket in raw_content.content},
                                  raw_content.cache_timestamp)

    def _create_results(self, computed_content):
        # No own output: a result without content is skipped by 'run'
        return [AWSSectionResult("", None)]


class S3(AWSSectionCloudwatch):
    """
    Cloudwatch section with the size and the number of objects of the
    buckets received from the colleague section 's3_summary'.
    """

    @property
    def name(self):
        return "s3"

    @property
    def interval(self):
        return 86400

    def _get_colleague_contents(self):
        summary = self._received_results.get('s3_summary')
        if not (summary and summary.content):
            return AWSColleagueContents({}, 0.0)
        return AWSColleagueContents(summary.content, summary.cache_timestamp)

    def _get_metrics(self, colleague_contents):
        """
        Return one metric data query per bucket, metric and storage class.
        The label of a query is the name of the bucket.
        """
        # (metric name, unit, storage classes)
        metric_specs = [
            ("BucketSizeBytes", "Bytes", [
                "StandardStorage",
                "StandardIAStorage",
                "ReducedRedundancyStorage",
            ]),
            ("NumberOfObjects", "Count", ["AllStorageTypes"]),
        ]
        queries = []
        for idx, bucket_name in enumerate(colleague_contents.content.iterkeys()):
            for metric_name, unit, storage_classes in metric_specs:
                for storage_class in storage_classes:
                    query_id = self._create_id_for_metric_data_query(
                        idx, metric_name, storage_class)
                    dimensions = [{
                        'Name': "BucketName",
                        'Value': bucket_name,
                    }, {
                        'Name': 'StorageType',
                        'Value': storage_class,
                    }]
                    queries.append({
                        'Id': query_id,
                        'Label': bucket_name,
                        'MetricStat': {
                            'Metric': {
                                'Namespace': 'AWS/S3',
                                'MetricName': metric_name,
                                'Dimensions': dimensions,
                            },
                            'Period': self.period,
                            'Stat': 'Average',
                            'Unit': unit,
                        },
                    })
        return queries

    def _compute_content(self, raw_content, colleague_contents):
        # Extend every metric result (in place) by the data of its bucket
        buckets = colleague_contents.content
        rows = raw_content.content
        for row in rows:
            bucket = buckets.get(row['Label'])
            if not bucket:
                continue
            row.update(bucket)
        return AWSComputedContent(raw_content.content, raw_content.cache_timestamp)

    def _create_results(self, computed_content):
        return [AWSSectionResult(piggyback_hostname="", content=computed_content.content)]


#.
#   .--ELB-----------------------------------------------------------------.
#   |                          _____ _     ____                            |
#   |                         | ____| |   | __ )                           |
#   |                         |  _| | |   |  _ \                           |
#   |                         | |___| |___| |_) |                          |
#   |                         |_____|_____|____/                           |
#   |                                                                      |
#   '----------------------------------------------------------------------'


class ELBSummary(AWSSectionGeneric):
    """
    Section 'elb_summary': descriptions of the load balancers known to the
    ELB client, optionally restricted by configured names or tags. The
    tags of every load balancer are attached as 'TagDescriptions'.
    """

    def __init__(self, client, region, config, distributor=None):
        super(ELBSummary, self).__init__(client, region, config, distributor=distributor)
        self._names = self._config.service_config['elb']['names']
        self._tags = self._prepare_elb_tags(self._config.service_config['elb']['tags'])

    def _prepare_elb_tags(self, tags):
        """
        ELB tags have a different format:
        [{'Key': KEY, 'Value': VALUE}, ...]

        Every configured entry {'Name': NAME, 'Values': [VALUE, ...]} is
        expanded to one {'Key': ..., 'Value': ...} pair per value; a leading
        'tag:' is stripped from NAME. Returns None if no tags are configured.
        """
        if not tags:
            return None
        prepared_tags = []
        for tag in tags:
            tag_name = tag['Name']
            if tag_name.startswith('tag:'):
                tag_key = tag_name[4:]
            else:
                tag_key = tag_name
            prepared_tags.extend([{'Key': tag_key, 'Value': v} for v in tag['Values']])
        return prepared_tags

    @property
    def name(self):
        return "elb_summary"

    @property
    def interval(self):
        return 300

    def _get_colleague_contents(self):
        # This section does not depend on results of other sections.
        return AWSColleagueContents(None, 0.0)

    def _fetch_raw_content(self, colleague_contents):
        """
        Return the described load balancers which match the configured
        conditions, each extended by its list of tags ('TagDescriptions').
        """
        found_load_balancers = []
        for load_balancer in self._describe_load_balancers():
            load_balancer_name = load_balancer['LoadBalancerName']
            tagging = []
            try:
                response = self._client.describe_tags(LoadBalancerNames=[load_balancer_name])
                tagging = [
                    tag for tag_descr in response['TagDescriptions'] for tag in tag_descr['Tags']
                ]
            except botocore.exceptions.ClientError as e:
                # If there are no tags attached to a load balancer we receive
                # a 'ClientError'
                logging.info("%s/%s: No tags set, %s", self.name, load_balancer_name, e)
            except KeyError as e:
                logging.info("%s/%s: KeyError %s; Available are %s", self.name, load_balancer_name,
                             e, response.keys())
            if self._matches_conditions(tagging):
                load_balancer['TagDescriptions'] = tagging
                found_load_balancers.append(load_balancer)
        return found_load_balancers

    def _describe_load_balancers(self):
        """
        Return the 'LoadBalancerDescriptions' of the client response; the
        request is restricted to the configured names only if no tags are
        configured. A response without that key yields an empty list.
        """
        if self._tags is None and self._names is not None:
            response = self._client.describe_load_balancers(LoadBalancerNames=self._names)
        else:
            response = self._client.describe_load_balancers()
        try:
            return response['LoadBalancerDescriptions']
        except KeyError as e:
            logging.info("%s: KeyError %s; Available are %s", self.name, e, response.keys())
            return []

    def _matches_conditions(self, tagging):
        """
        Tell whether a load balancer with the given tags is to be monitored:
        always if names are configured or no tags are configured, otherwise
        only if at least one of its tags equals a configured tag.
        """
        if self._names is not None:
            return True
        if self._tags is None:
            return True
        for tag in tagging:
            if tag in self._tags:
                return True
        return False

    def _compute_content(self, raw_content, colleague_contents):
        # Key the load balancers by DNS name; the first entry per name wins.
        content_by_piggyback_hosts = {}
        for load_balancer in raw_content.content:
            content_by_piggyback_hosts.setdefault(load_balancer['DNSName'], load_balancer)
        return AWSComputedContent(content_by_piggyback_hosts, raw_content.cache_timestamp)

    def _create_results(self, computed_content):
        return [AWSSectionResult("", computed_content.content.values())]


class ELBHealth(AWSSectionGeneric):
    """
    Section 'elb_health': instance states of every load balancer received
    from 'elb_summary', one result per load balancer DNS name.
    """

    @property
    def name(self):
        return "elb_health"

    @property
    def interval(self):
        return 300

    def _get_colleague_contents(self):
        """Hand over the 'elb_summary' content, or an empty mapping."""
        summary = self._received_results.get('elb_summary')
        if not summary or not summary.content:
            return AWSColleagueContents({}, 0.0)
        return AWSColleagueContents(summary.content, summary.cache_timestamp)

    def _fetch_raw_content(self, colleague_contents):
        """
        Query the instance health for each load balancer and collect the
        'InstanceStates' by DNS name. Responses lacking that key are logged
        and skipped.
        """
        states_by_dns_name = {}
        for dns_name, description in colleague_contents.content.iteritems():
            balancer_name = description['LoadBalancerName']
            response = self._client.describe_instance_health(LoadBalancerName=balancer_name)
            try:
                instance_states = response['InstanceStates']
            except KeyError as e:
                logging.info("%s/%s: KeyError %s; Available are %s", self.name, balancer_name, e,
                             response.keys())
                continue
            states_by_dns_name.setdefault(dns_name, instance_states)
        return states_by_dns_name

    def _compute_content(self, raw_content, colleague_contents):
        # Nothing to compute: the raw content is passed through unchanged.
        return AWSComputedContent(raw_content.content, raw_content.cache_timestamp)

    def _create_results(self, computed_content):
        """Create one result per DNS name, used as piggyback hostname."""
        results = []
        for hostname, states in computed_content.content.iteritems():
            results.append(AWSSectionResult(hostname, states))
        return results


class ELB(AWSSectionCloudwatch):
    """
    Section 'elb': metric queries in the namespace 'AWS/ELB' for every load
    balancer received from 'elb_summary'; resulting rows are grouped by
    their label, i.e. the DNS name of the load balancer.
    """

    @property
    def name(self):
        return "elb"

    @property
    def interval(self):
        return 300

    def _get_colleague_contents(self):
        """Hand over the 'elb_summary' content, or an empty mapping."""
        summary = self._received_results.get('elb_summary')
        if not summary or not summary.content:
            return AWSColleagueContents({}, 0.0)
        return AWSColleagueContents(summary.content, summary.cache_timestamp)

    def _get_metrics(self, colleague_contents):
        """
        Build one metric query per load balancer and metric. The query is
        labelled with the DNS name while the dimension uses the load
        balancer name.
        """
        # (metric name, statistic)
        metric_stats = (
            ("RequestCount", "Sum"),
            ("SurgeQueueLength", "Maximum"),
            ("SpilloverCount", "Sum"),
            ("Latency", "Average"),
            ("HTTPCode_ELB_4XX", "Sum"),
            ("HTTPCode_ELB_5XX", "Sum"),
            ("HTTPCode_Backend_2XX", "Sum"),
            ("HTTPCode_Backend_3XX", "Sum"),
            ("HTTPCode_Backend_4XX", "Sum"),
            ("HTTPCode_Backend_5XX", "Sum"),
            ("HealthyHostCount", "Average"),
            ("UnHealthyHostCount", "Average"),
            ("BackendConnectionErrors", "Sum"),
        )
        queries = []
        for idx, (dns_name, description) in enumerate(colleague_contents.content.iteritems()):
            balancer_name = description['LoadBalancerName']
            for metric_name, stat in metric_stats:
                metric_spec = {
                    'Namespace': 'AWS/ELB',
                    'MetricName': metric_name,
                    'Dimensions': [{
                        'Name': "LoadBalancerName",
                        'Value': balancer_name,
                    }],
                }
                queries.append({
                    'Id': self._create_id_for_metric_data_query(idx, metric_name),
                    'Label': dns_name,
                    'MetricStat': {
                        'Metric': metric_spec,
                        'Period': self.period,
                        'Stat': stat,
                    },
                })
        return queries

    def _compute_content(self, raw_content, colleague_contents):
        """Group the raw rows by their 'Label'."""
        rows_by_label = {}
        for row in raw_content.content:
            label = row['Label']
            if label not in rows_by_label:
                rows_by_label[label] = []
            rows_by_label[label].append(row)
        return AWSComputedContent(rows_by_label, raw_content.cache_timestamp)

    def _create_results(self, computed_content):
        """Create one result per label, used as piggyback hostname."""
        results = []
        for hostname, rows in computed_content.content.iteritems():
            results.append(AWSSectionResult(hostname, rows))
        return results


#.
#   .--EBS-----------------------------------------------------------------.
#   |                          _____ ____ ____                             |
#   |                         | ____| __ ) ___|                            |
#   |                         |  _| |  _ \___ \                            |
#   |                         | |___| |_) |__) |                           |
#   |                         |_____|____/____/                            |
#   |                                                                      |
#   '----------------------------------------------------------------------'

# EBS volumes are attached to EC2 instances. Thus we assign their content to
# the related EC2 instance as piggyback host.


class EBSSummary(AWSSectionGeneric):
    """
    Section 'ebs_summary': described volumes merged with their volume
    status, assigned to the EC2 instances (received from 'ec2_summary')
    they are attached to.
    """

    def __init__(self, client, region, config, distributor=None):
        super(EBSSummary, self).__init__(client, region, config, distributor=distributor)
        self._names = self._config.service_config['ebs']['names']
        self._tags = self._config.service_config['ebs']['tags']

    @property
    def name(self):
        return "ebs_summary"

    @property
    def interval(self):
        return 300

    def _get_colleague_contents(self):
        colleague = self._received_results.get('ec2_summary')
        if colleague and colleague.content:
            return AWSColleagueContents(colleague.content, colleague.cache_timestamp)
        return AWSColleagueContents({}, 0.0)

    def _fetch_raw_content(self, colleague_contents):
        """
        Return a pair (volumes, volume_states), both keyed by 'VolumeId'.
        The states are empty if no volume was found or if the status
        response lacks an expected key.
        """
        volumes = {r['VolumeId']: r for r in self._describe_volumes()}
        if not volumes:
            # Nothing to look up. This also avoids a status request with an
            # empty 'VolumeIds' list (presumably not treated as a restriction
            # by the API - TODO confirm).
            return volumes, {}

        # Request outside of the 'try' block: the handler below relies on
        # the response being available.
        response_volume_stats = self._client.describe_volume_status(VolumeIds=list(volumes))
        try:
            volume_states = {r['VolumeId']: r for r in response_volume_stats['VolumeStatuses']}
        except KeyError as e:
            logging.info("%s: KeyError %s; Available are %s", self.name, e,
                         response_volume_stats.keys())
            volume_states = {}
        return volumes, volume_states

    def _get_volumes(self, response):
        # Little hack: This was refactored in later agent versions
        try:
            return response['Volumes']
        except KeyError as e:
            logging.info("%s: KeyError %s; Available are %s", self.name, e, response.keys())
            return []

    def _describe_volumes(self):
        """
        Return the volumes restricted by the configured names, else by the
        configured tags, else all volumes.
        """
        if self._names is not None:
            response = self._client.describe_volumes(VolumeIds=self._names)
            return self._get_volumes(response)

        elif self._tags is not None:
            volumes = []
            for chunk in _chunks(self._tags, length=200):
                # EC2 FilterLimitExceeded: The maximum number of filter values
                # specified on a single call is 200
                response = self._client.describe_volumes(Filters=chunk)
                volumes.extend(self._get_volumes(response))
            return volumes

        response = self._client.describe_volumes()
        return self._get_volumes(response)

    def _compute_content(self, raw_content, colleague_contents):
        """
        Merge the status into each volume and group the volumes by the name
        of the instance they are attached to. Volumes attached to an
        instance which is not part of the colleague contents are grouped
        under the empty name; volumes without attachments are left out.
        """
        volumes, volume_states = raw_content.content
        content = []
        for volume_id in set(volumes.keys()).union(set(volume_states.keys())):
            volume = volumes.get(volume_id, {})
            volume.update(volume_states.get(volume_id, {}))
            content.append(volume)

        instance_name_mapping = {
            v['InstanceId']: k for k, v in colleague_contents.content.iteritems()
        }
        content_by_piggyback_hosts = {}
        for row in content:
            # Rows which stem from the status response only have no
            # 'Attachments'; skip them instead of raising a KeyError.
            for attachment in row.get('Attachments', []):
                attachment_id = attachment['InstanceId']
                instance_name = instance_name_mapping.get(attachment_id)
                if instance_name is None:
                    instance_name = ""
                content_by_piggyback_hosts.setdefault(instance_name, []).append(row)
        return AWSComputedContent(content_by_piggyback_hosts, raw_content.cache_timestamp)

    def _create_results(self, computed_content):
        return [
            AWSSectionResult(piggyback_hostname, rows)
            for piggyback_hostname, rows in computed_content.content.iteritems()
        ]


class EBS(AWSSectionCloudwatch):
    """
    Section 'ebs': metric queries in the namespace 'AWS/EBS' for every
    volume received from 'ebs_summary'; resulting rows are grouped by their
    label, i.e. the name of the instance the volume belongs to.
    """

    @property
    def name(self):
        return "ebs"

    @property
    def interval(self):
        return 300

    def _get_colleague_contents(self):
        """
        Flatten the 'ebs_summary' content to a list of
        (instance name, volume ID, volume type) tuples.
        """
        summary = self._received_results.get('ebs_summary')
        if not summary or not summary.content:
            return AWSColleagueContents([], 0.0)

        volume_infos = []
        for instance_name, rows in summary.content.iteritems():
            for row in rows:
                volume_infos.append((instance_name, row['VolumeId'], row['VolumeType']))
        return AWSColleagueContents(volume_infos, summary.cache_timestamp)

    def _get_metrics(self, colleague_contents):
        """
        Build one metric query per volume and metric. Metrics which are
        bound to certain volume types are skipped for other types.
        """
        # (metric name, unit, volume types the metric is restricted to;
        # an empty list means: no restriction)
        metric_specs = (
            ("VolumeReadOps", "Count", []),
            ("VolumeWriteOps", "Count", []),
            ("VolumeReadBytes", "Bytes", []),
            ("VolumeWriteBytes", "Bytes", []),
            ("VolumeQueueLength", "Count", []),
            ("BurstBalance", "Percent", ["gp2", "st1", "sc1"]),
            #("VolumeThroughputPercentage", "Percent", ["io1"]),
            #("VolumeConsumedReadWriteOps", "Count", ["io1"]),
            #("VolumeTotalReadTime", "Seconds", []),
            #("VolumeTotalWriteTime", "Seconds", []),
            #("VolumeIdleTime", "Seconds", []),
            #("VolumeStatus", None, []),
            #("IOPerformance", None, ["io1"]),
        )
        queries = []
        for idx, (instance_name, volume_id, volume_type) in enumerate(colleague_contents.content):
            for metric_name, unit, restricted_to in metric_specs:
                if restricted_to and volume_type not in restricted_to:
                    continue
                metric_stat = {
                    'Metric': {
                        'Namespace': 'AWS/EBS',
                        'MetricName': metric_name,
                        # NOTE(review): the AWS documentation seems to spell
                        # this dimension 'VolumeId' - confirm the spelling.
                        'Dimensions': [{
                            'Name': "VolumeID",
                            'Value': volume_id,
                        }]
                    },
                    'Period': self.period,
                    'Stat': 'Average',
                }
                if unit:
                    metric_stat['Unit'] = unit
                queries.append({
                    'Id': self._create_id_for_metric_data_query(idx, metric_name),
                    'Label': instance_name,
                    'MetricStat': metric_stat,
                })
        return queries

    def _compute_content(self, raw_content, colleague_contents):
        """Group the raw rows by their 'Label'."""
        rows_by_label = {}
        for row in raw_content.content:
            label = row['Label']
            if label not in rows_by_label:
                rows_by_label[label] = []
            rows_by_label[label].append(row)
        return AWSComputedContent(rows_by_label, raw_content.cache_timestamp)

    def _create_results(self, computed_content):
        """Create one result per label, used as piggyback hostname."""
        results = []
        for hostname, rows in computed_content.content.iteritems():
            results.append(AWSSectionResult(hostname, rows))
        return results


#.
#   .--RDS-----------------------------------------------------------------.
#   |                          ____  ____  ____                            |
#   |                         |  _ \|  _ \/ ___|                           |
#   |                         | |_) | | | \___ \                           |
#   |                         |  _ <| |_| |___) |                          |
#   |                         |_| \_\____/|____/                           |
#   |                                                                      |
#   '----------------------------------------------------------------------'


class RDSSummary(AWSSectionGeneric):
    """
    Section 'rds_summary': descriptions of the DB instances known to the
    RDS client, optionally restricted by configured names or tags.
    """

    def __init__(self, client, region, config, distributor=None):
        super(RDSSummary, self).__init__(client, region, config, distributor=distributor)
        self._names = self._config.service_config['rds']['names']
        self._tags = self._config.service_config['rds']['tags']

    @property
    def name(self):
        return "rds_summary"

    @property
    def interval(self):
        return 300

    def _get_colleague_contents(self):
        # This section does not depend on results of other sections.
        return AWSColleagueContents(None, 0.0)

    def _fetch_raw_content(self, colleague_contents):
        """
        Return the DB instances of all responses as one flat list.
        Responses without the key 'DBInstances' are logged and skipped.
        """
        db_instances = []
        for response in self._describe_db_instances():
            try:
                db_instances.extend(response['DBInstances'])
            except KeyError as e:
                logging.info("%s: KeyError %s; Available are %s", self.name, e, response.keys())
        return db_instances

    def _describe_db_instances(self):
        """
        Return a list of client responses: one per configured name, else
        one for the configured tags, else one unrestricted response.
        """
        if self._names is not None:
            return [
                self._client.describe_db_instances(DBInstanceIdentifier=name)
                for name in self._names
            ]
        if self._tags is not None:
            # A single request is sufficient here; no names are configured
            # in this branch, so there is nothing to iterate over.
            # NOTE(review): verify that the API accepts the configured tags
            # as 'Filters'.
            return [self._client.describe_db_instances(Filters=self._tags)]
        return [self._client.describe_db_instances()]

    def _compute_content(self, raw_content, colleague_contents):
        # Key the DB instances by their identifier.
        return AWSComputedContent(
            {instance['DBInstanceIdentifier']: instance for instance in raw_content.content},
            raw_content.cache_timestamp)

    def _create_results(self, computed_content):
        return [AWSSectionResult("", computed_content.content.values())]


class RDS(AWSSectionCloudwatch):
    """
    Section 'rds': metric queries in the namespace 'AWS/RDS' for every DB
    instance received from 'rds_summary'; resulting rows are enriched with
    the description of the related DB instance.
    """

    @property
    def name(self):
        return "rds"

    @property
    def interval(self):
        return 300

    def _get_colleague_contents(self):
        """Hand over the 'rds_summary' content, or an empty mapping."""
        summary = self._received_results.get('rds_summary')
        if not summary or not summary.content:
            return AWSColleagueContents({}, 0.0)
        return AWSColleagueContents(summary.content, summary.cache_timestamp)

    def _get_metrics(self, colleague_contents):
        """
        Build one metric query per DB instance and metric, labelled with
        the key of the DB instance in the colleague contents.
        """
        # (metric name, unit)
        metric_units = (
            ("BinLogDiskUsage", "Bytes"),
            ("BurstBalance", "Percent"),
            ("CPUUtilization", "Percent"),
            ("CPUCreditUsage", "Count"),
            ("CPUCreditBalance", "Count"),
            ("DatabaseConnections", "Count"),
            ("DiskQueueDepth", "Count"),
            ("FailedSQLServerAgentJobsCount", "Count/Second"),
            ("NetworkReceiveThroughput", "Bytes/Second"),
            ("NetworkTransmitThroughput", "Bytes/Second"),
            ("OldestReplicationSlotLag", "Megabytes"),
            ("ReadIOPS", "Count/Second"),
            ("ReadLatency", "Seconds"),
            ("ReadThroughput", "Bytes/Second"),
            ("ReplicaLag", "Seconds"),
            ("ReplicationSlotDiskUsage", "Megabytes"),
            ("TransactionLogsDiskUsage", "Megabytes"),
            ("TransactionLogsGeneration", "Megabytes/Second"),
            ("WriteIOPS", "Count/Second"),
            ("WriteLatency", "Seconds"),
            ("WriteThroughput", "Bytes/Second"),
            #("FreeableMemory", "Bytes"),
            #("SwapUsage", "Bytes"),
            #("FreeStorageSpace", "Bytes"),
            #("MaximumUsedTransactionIDs", "Count"),
        )
        queries = []
        for idx, instance_id in enumerate(colleague_contents.content.iterkeys()):
            for metric_name, unit in metric_units:
                metric_stat = {
                    'Metric': {
                        'Namespace': 'AWS/RDS',
                        'MetricName': metric_name,
                        'Dimensions': [{
                            'Name': "DBInstanceIdentifier",
                            'Value': instance_id,
                        }]
                    },
                    'Period': self.period,
                    'Stat': 'Average',
                }
                if unit:
                    metric_stat['Unit'] = unit
                queries.append({
                    'Id': self._create_id_for_metric_data_query(idx, metric_name),
                    'Label': instance_id,
                    'MetricStat': metric_stat,
                })
        return queries

    def _compute_content(self, raw_content, colleague_contents):
        """
        Update every raw row in place with the colleague entry stored under
        the row's 'Label'; rows without such an entry stay untouched.
        """
        rows = raw_content.content
        for row in rows:
            db_instance = colleague_contents.content.get(row['Label'], {})
            row.update(db_instance)
        return AWSComputedContent(rows, raw_content.cache_timestamp)

    def _create_results(self, computed_content):
        """Return all rows as one result with an empty piggyback hostname."""
        result = AWSSectionResult("", computed_content.content)
        return [result]


#.
#   .--sections------------------------------------------------------------.
#   |                               _   _                                  |
#   |                 ___  ___  ___| |_(_) ___  _ __  ___                  |
#   |                / __|/ _ \/ __| __| |/ _ \| '_ \/ __|                 |
#   |                \__ \  __/ (__| |_| | (_) | | | \__ \                 |
#   |                |___/\___|\___|\__|_|\___/|_| |_|___/                 |
#   |                                                                      |
#   '----------------------------------------------------------------------'


class AWSSections(object):
    """
    Base class of section collections: subclasses register their sections
    in 'init_sections'; 'run' executes them in order of registration and
    writes occurred exceptions and section results to stdout.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self, hostname, session, debug=False):
        self._hostname = hostname
        self._session = session
        self._debug = debug
        self._sections = []

    @abc.abstractmethod
    def init_sections(self, services, region, config):
        pass

    def _init_client(self, client_key):
        """Return the client of the session for the given key."""
        try:
            return self._session.client(client_key)
        except (ValueError, botocore.exceptions.ClientError,
                botocore.exceptions.UnknownServiceError):
            # If region name is not valid we get a ValueError
            # but not in all cases, eg.:
            # 1. 'eu-central-' raises a ValueError
            # 2. 'foobar' does not raise a ValueError
            # In the second case we get an exception raised by botocore
            # during we execute an operation, eg. cloudwatch.get_metrics(**kwargs):
            # - botocore.exceptions.EndpointConnectionError
            raise

    def run(self, use_cache=True):
        """
        Run all registered sections and write the output. In debug mode
        any exception is raised; otherwise exceptions are logged and,
        except for assertion errors, reported in the exceptions section.
        """
        exceptions = []
        results = {}
        for section in self._sections:
            try:
                section_result = section.run(use_cache=use_cache)
            except AssertionError as e:
                logging.info(e)
                if self._debug:
                    raise
            except Exception as e:
                logging.info(e)
                if self._debug:
                    raise
                exceptions.append(e)
            else:
                results.setdefault((section.name, section_result.cache_timestamp, section.interval),
                                   section_result.results)

        self._write_exceptions(exceptions)
        self._write_section_results(results)

    def _write_exceptions(self, exceptions):
        """
        Write the section 'aws_exceptions' which contains the texts of the
        given exceptions or 'No exceptions', prefixed by the class name.
        """
        sys.stdout.write("<<<aws_exceptions>>>\n")
        if exceptions:
            out = "\n".join([self._format_exception(e) for e in exceptions])
        else:
            out = "No exceptions"
        sys.stdout.write("%s: %s\n" % (self.__class__.__name__, out))

    @staticmethod
    def _format_exception(exception):
        """
        Return the text of an exception: its 'message' attribute if that
        is set and not empty, otherwise the string form of the exception.
        """
        # 'message' is deprecated and empty for exceptions created with
        # several arguments (eg. IOError(errno, text)), thus fall back to
        # the exception itself.
        message = getattr(exception, 'message', None)
        if not message:
            message = exception
        # The tuple makes sure that any kind of object gets formatted.
        return "%s" % (message,)

    def _write_section_results(self, results):
        """
        Write all non-empty section results. Sections with an interval
        greater than 60 get a 'cached' suffix in their header.
        """
        if not results:
            logging.info("%s: No results or cached data", self.__class__.__name__)
            return

        for (section_name, cache_timestamp, section_interval), result in results.iteritems():
            if not result:
                logging.info("%s: No results", section_name)
                continue

            if not isinstance(result, list):
                logging.info(
                    "%s: Section result must be of type 'list' containing 'AWSSectionResults'",
                    section_name)
                continue

            cached_suffix = ""
            if section_interval > 60:
                cached_suffix = ":cached(%s,%s)" % (int(cache_timestamp), section_interval + 60)

            if any([r.content for r in result]):
                self._write_section_result(section_name, cached_suffix, result)

    def _write_section_result(self, section_name, cached_suffix, result):
        """
        Write every row as JSON below the section header. Rows with a
        piggyback hostname other than the own hostname are enclosed in
        piggyback headers.
        """
        section_header = "<<<aws_%s%s>>>\n" % (section_name, cached_suffix)
        for row in result:
            write_piggyback_header = row.piggyback_hostname\
                                     and row.piggyback_hostname != self._hostname
            if write_piggyback_header:
                sys.stdout.write("<<<<%s>>>>\n" % row.piggyback_hostname)
            sys.stdout.write(section_header)
            sys.stdout.write("%s\n" % json.dumps(row.content, default=_datetime_converter))
            if write_piggyback_header:
                sys.stdout.write("<<<<>>>>\n")


class AWSSectionsUSEast(AWSSections):
    """
    Some clients like CostExplorer only work with US East region:
    https://docs.aws.amazon.com/awsaccountbilling/latest/aboutv2/ce-api.html
    """

    def init_sections(self, services, region, config):
        """
        Create the costs and usage section and register it for execution
        if the service 'ce' is requested. No distributors are involved.
        """
        #---clients---------------------------------------------------------
        cost_explorer_client = self._init_client('ce')

        #---sections--------------------------------------------------------
        costs_and_usage = CostsAndUsage(cost_explorer_client, region, config)

        #---register sections for execution---------------------------------
        if 'ce' not in services:
            return
        self._sections.append(costs_and_usage)


class AWSSectionsGeneric(AWSSections):
    """
    Sections of the services EC2, EBS, ELB, S3 and RDS: the summary
    sections distribute their results to the dependent sections.
    """

    def init_sections(self, services, region, config):
        """
        Create clients, distributors and sections, connect the sections to
        the distributors and register the sections of the requested
        services for execution.
        """
        #---clients---------------------------------------------------------
        ec2_client, elb_client, s3_client, rds_client, cloudwatch_client = [
            self._init_client(client_key)
            for client_key in ('ec2', 'elb', 's3', 'rds', 'cloudwatch')
        ]

        #---distributors----------------------------------------------------
        ec2_distributor = ResultDistributor()
        elb_distributor = ResultDistributor()
        ebs_distributor = ResultDistributor()
        s3_distributor = ResultDistributor()
        rds_distributor = ResultDistributor()

        #---sections with distributors--------------------------------------
        ec2_summary = EC2Summary(ec2_client, region, config, ec2_distributor)
        ebs_summary = EBSSummary(ec2_client, region, config, ebs_distributor)
        elb_summary = ELBSummary(elb_client, region, config, elb_distributor)
        s3_summary = S3Summary(s3_client, region, config, s3_distributor)
        rds_summary = RDSSummary(rds_client, region, config, rds_distributor)

        #---sections--------------------------------------------------------
        elb_health = ELBHealth(elb_client, region, config)
        ec2_security_groups = EC2SecurityGroups(ec2_client, region, config)
        ec2 = EC2(cloudwatch_client, region, config)
        ebs = EBS(cloudwatch_client, region, config)
        elb = ELB(cloudwatch_client, region, config)
        s3 = S3(cloudwatch_client, region, config)
        rds = RDS(cloudwatch_client, region, config)

        #---register sections to distributors-------------------------------
        receivers_by_distributor = (
            (ec2_distributor, (ec2_security_groups, ec2, ebs_summary, ebs)),
            (ebs_distributor, (ebs,)),
            (elb_distributor, (elb_health, elb)),
            (s3_distributor, (s3,)),
            (rds_distributor, (rds,)),
        )
        for distributor, receivers in receivers_by_distributor:
            for receiver in receivers:
                distributor.add(receiver)

        #---register sections for execution---------------------------------
        # Dependencies: First append sections which distribute their results:
        # --ec2_summary ('ec2')
        #   |
        #   |-- ec2 ('ec2')
        #   |
        #   |-- ebs_summary ('ec2', 'ebs')
        #   |       |
        #   |       |-- ebs ('ec2', 'ebs')
        #   |
        #   |-- ebs ('ec2')
        #
        # -- elb_summary
        #    |
        #    |-- elb_health
        #    |
        #    |-- elb
        sections_by_service = (
            ('ec2', (ec2_summary, ec2_security_groups, ec2)),
            ('ebs', (ebs_summary, ebs)),
            ('elb', (elb_summary, elb_health, elb)),
            ('s3', (s3_summary, s3)),
            ('rds', (rds_summary, rds)),
        )
        for service_key, sections in sections_by_service:
            if service_key in services:
                self._sections.extend(sections)


#.
#   .--main----------------------------------------------------------------.
#   |                                       _                              |
#   |                       _ __ ___   __ _(_)_ __                         |
#   |                      | '_ ` _ \ / _` | | '_ \                        |
#   |                      | | | | | | (_| | | | | |                       |
#   |                      |_| |_| |_|\__,_|_|_| |_|                       |
#   |                                                                      |
#   '----------------------------------------------------------------------'

# Known regions as pairs of (region ID, title)
AWSRegions = [
    # Asia Pacific
    ("ap-south-1", "Asia Pacific (Mumbai)"),
    ("ap-northeast-3", "Asia Pacific (Osaka-Local)"),
    ("ap-northeast-2", "Asia Pacific (Seoul)"),
    ("ap-southeast-1", "Asia Pacific (Singapore)"),
    ("ap-southeast-2", "Asia Pacific (Sydney)"),
    ("ap-northeast-1", "Asia Pacific (Tokyo)"),
    # Canada
    ("ca-central-1", "Canada (Central)"),
    # China
    ("cn-north-1", "China (Beijing)"),
    ("cn-northwest-1", "China (Ningxia)"),
    # EU
    ("eu-central-1", "EU (Frankfurt)"),
    ("eu-west-1", "EU (Ireland)"),
    ("eu-west-2", "EU (London)"),
    ("eu-west-3", "EU (Paris)"),
    ("eu-north-1", "EU (Stockholm)"),
    # South America
    ("sa-east-1", "South America (Sao Paulo)"),
    # US
    ("us-east-2", "US East (Ohio)"),
    ("us-east-1", "US East (N. Virginia)"),
    ("us-west-1", "US West (N. California)"),
    ("us-west-2", "US West (Oregon)"),
]

# Attributes of a service which can be monitored by this agent
AWSServiceAttributes = NamedTuple(
    "AWSServiceAttributes",
    [
        ("key", str),
        ("title", str),
        ("global_service", bool),
        ("filter_by_names_or_tags", bool),
    ],
)

# Known services; the order is kept as is
AWSServices = [
    AWSServiceAttributes(
        key="ce",
        title="Costs and usage",
        global_service=True,
        filter_by_names_or_tags=False,
    ),
    AWSServiceAttributes(
        key="ec2",
        title="Elastic Compute Cloud (EC2)",
        global_service=False,
        filter_by_names_or_tags=True,
    ),
    AWSServiceAttributes(
        key="ebs",
        title="Elastic Block Storage (EBS)",
        global_service=False,
        filter_by_names_or_tags=True,
    ),
    AWSServiceAttributes(
        key="s3",
        title="Simple Storage Service (S3)",
        global_service=False,
        filter_by_names_or_tags=True,
    ),
    AWSServiceAttributes(
        key="elb",
        title="Elastic Load Balancing (ELB)",
        global_service=False,
        filter_by_names_or_tags=True,
    ),
    AWSServiceAttributes(
        key="rds",
        title="Relational Database Service (RDS)",
        global_service=False,
        filter_by_names_or_tags=True,
    ),
]


def parse_arguments(argv):
    """Parse the command line of the special agent.

    Besides the fixed options, the three filter options --<key>-names,
    --<key>-tag-key and --<key>-tag-values are generated for every entry
    of AWSServices which has filter_by_names_or_tags set.

    argv is the list of command line arguments (without the program name).
    The namespace created by argparse is returned.
    """

    def as_table(rows):
        # One "<id> <title>" line per row, the id padded to 15 characters.
        return "\n".join(["%-15s %s" % row for row in rows])

    global_rows = [(svc.key, svc.title) for svc in AWSServices if svc.global_service]
    regional_rows = [(svc.key, svc.title) for svc in AWSServices if not svc.global_service]

    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    add_option = parser.add_argument

    # Diagnostics and caching
    add_option("--debug", action="store_true", help="Raise Python exceptions.")
    add_option(
        "--verbose",
        action="store_true",
        help="Log messages from AWS library 'boto3' and 'botocore'.")
    add_option(
        "--no-cache",
        action="store_true",
        help="Execute all sections, do not rely on cached data. Cached data will not be overwritten."
    )

    # Credentials
    add_option("--access-key-id", required=True, help="The access key for your AWS account.")
    add_option(
        "--secret-access-key", required=True, help="The secret key for your AWS account.")

    # What to monitor and where
    add_option("--regions", nargs='+', help="Regions to use:\n%s" % as_table(AWSRegions))
    add_option(
        "--global-services",
        nargs='+',
        help="Global services to monitor:\n%s" % as_table(global_rows))
    add_option(
        "--services",
        nargs='+',
        help="Services per region to monitor:\n%s" % as_table(regional_rows))

    # Per service restriction by names or tags. The tag options may be
    # repeated; every occurrence is appended as a list of its own.
    for svc in AWSServices:
        if not svc.filter_by_names_or_tags:
            continue
        add_option('--%s-names' % svc.key, nargs='+', help="Names for %s" % svc.title)
        add_option(
            '--%s-tag-key' % svc.key,
            nargs=1,
            action='append',
            help="Tag key for %s" % svc.title)
        add_option(
            '--%s-tag-values' % svc.key,
            nargs='+',
            action='append',
            help="Tag values for %s" % svc.title)

    # Tag restriction which is not bound to a single service
    add_option('--overall-tag-key', nargs=1, action='append', help="Overall tag key")
    add_option('--overall-tag-values', nargs='+', action='append', help="Overall tag values")

    add_option("--hostname", required=True)
    return parser.parse_args(argv)


def setup_logging(opt_debug, opt_verbose):
    """Configure the root logger according to the command line flags.

    The root logger stays disabled unless --debug or --verbose is given.
    --verbose lowers the level to DEBUG, otherwise INFO is used.
    """
    root_logger = logging.getLogger()
    root_logger.disabled = not (opt_verbose or opt_debug)
    level = logging.DEBUG if opt_verbose else logging.INFO
    logging.basicConfig(
        level=level,
        format='%(levelname)s: %(name)s: %(filename)s: %(lineno)s: %(message)s')


def create_session(access_key_id, secret_access_key, region):
    """Return a boto3 session for the given credentials and region name."""
    session_args = {
        "aws_access_key_id": access_key_id,
        "aws_secret_access_key": secret_access_key,
        "region_name": region,
    }
    return boto3.session.Session(**session_args)


class AWSConfig(object):
    """Holds the hostname and the per service name/tag restrictions.

    service_config maps a service key to a dict with the two entries
    "names" and "tags", both of which are None as long as nothing is
    configured for them.
    """

    def __init__(self, hostname, overall_tags):
        self.hostname = hostname
        # Used for every service which has no tag options of its own.
        self._overall_tags = self._prepare_tags(overall_tags)
        self.service_config = {}

    def add_service_config(self, key, names, tags):
        """Register the restrictions of the service 'key'.

        names is stored as given if it is not empty. tags is the pair
        (tag keys, tag values) from the command line. If at least one half
        of the pair is set, the converted pair is stored (which is None
        if the other half is missing). Only if both halves are None do the
        overall tags - if there are any - apply.
        """
        entry = self.service_config.setdefault(key, {"names": None, "tags": None})
        if names:
            entry["names"] = names

        if tags != (None, None):
            entry["tags"] = self._prepare_tags(tags)
            return

        if self._overall_tags:
            entry["tags"] = self._overall_tags

    def _prepare_tags(self, tags):
        """Convert tags via commandline input
        from
            ([['foo'], ['aaa'], ...], [['bar', 'baz'], ['bbb', 'ccc'], ...])
        to
            [{'Name': 'tag:foo', 'Values': ['bar', 'baz']},
             {'Name': 'tag:aaa', 'Values': ['bbb', 'ccc']}, ...]
        Returns None unless keys AND values are set.
        """
        tag_keys, tag_values = tags
        if not (tag_keys and tag_values):
            return None

        # Every key arrives wrapped in a one-element list.
        plain_keys = [wrapped[0] for wrapped in tag_keys]
        filters = []
        for tag_key, values_of_key in zip(plain_keys, tag_values):
            filters.append({
                'Name': 'tag:%s' % tag_key,
                'Values': values_of_key,
            })
        return filters


def main(args=None):
    """Run the special agent and return its exit code.

    If args is None, the arguments are taken from sys.argv after the
    password store replaced the passwords within them. Otherwise args is
    the list of command line arguments without the program name.

    Returns 0 if all sections of all regions ran without an exception and
    1 otherwise. With --debug the first failure aborts the run immediately.
    """
    if args is None:
        cmk.password_store.replace_passwords()
        args = sys.argv[1:]

    args = parse_arguments(args)
    setup_logging(args.debug, args.verbose)
    hostname = args.hostname

    aws_config = AWSConfig(hostname, (args.overall_tag_key, args.overall_tag_values))
    for service_key, service_names, service_tags in [
        ("ec2", args.ec2_names, (args.ec2_tag_key, args.ec2_tag_values)),
        ("ebs", args.ebs_names, (args.ebs_tag_key, args.ebs_tag_values)),
        ("s3", args.s3_names, (args.s3_tag_key, args.s3_tag_values)),
        ("elb", args.elb_names, (args.elb_tag_key, args.elb_tag_values)),
        ("rds", args.rds_names, (args.rds_tag_key, args.rds_tag_values)),
    ]:
        aws_config.add_service_config(service_key, service_names, service_tags)

    has_exceptions = False
    for aws_services, aws_regions, aws_sections in [
            # Global services are always queried via the region us-east-1.
        (args.global_services, ["us-east-1"], AWSSectionsUSEast),
        (args.services, args.regions, AWSSectionsGeneric),
    ]:
        if not aws_services:
            continue
        if not aws_regions:
            # --regions is optional, so it is None if --services was given
            # without it. Iterating over None would raise a TypeError outside
            # of the error handling below, so report the problem instead.
            logging.info("No regions given for services: %s", ", ".join(aws_services))
            has_exceptions = True
            continue
        for region in aws_regions:
            try:
                session = create_session(args.access_key_id, args.secret_access_key, region)
                sections = aws_sections(hostname, session, debug=args.debug)
                sections.init_sections(aws_services, region, aws_config)
                sections.run(use_cache=not args.no_cache)
            except AssertionError:
                # Failed assertions do not mark the run as failed; they only
                # abort it in debug mode.
                if args.debug:
                    return 1
            except Exception as e:
                logging.info(e)
                has_exceptions = True
                if args.debug:
                    return 1
    if has_exceptions:
        return 1
    return 0


# Script entry point: the return value of main() becomes the exit status.
if __name__ == "__main__":
    sys.exit(main())
