#!/usr/bin/env python

import urllib2

from nagioscheck import NagiosCheck, UsageError
from nagioscheck import PerformanceMetric, Status


try:
    import json
except ImportError:
    import simplejson as json

# ElasticSearch health colours mapped to integer ranks.  Lower numbers are
# worse: the check compares ranks with `<` when downgrading its own health.
HEALTH = dict(red=0, yellow=1, green=2)

# Convenience aliases for the three ranks.
RED, YELLOW, GREEN = HEALTH['red'], HEALTH['yellow'], HEALTH['green']

# Health rank -> Nagios status name.
HEALTH_MAP = {RED: 'critical', YELLOW: 'warning', GREEN: 'ok'}

# Shard routing state name (upper case) -> small integer code.
SHARD_STATE = dict(UNASSIGNED=1, INITIALIZING=2, STARTED=3, RELOCATING=4)


class ESShard(object):
    """One copy of a shard, remembered only by its state code.

    Instances use the default identity-based hashing, so two shards
    with the same state are still distinct dictionary keys.
    """

    def __init__(self, state):
        # Stored verbatim; no validation is performed here.
        self.state = state


class ESIndex(object):
    """An index: its name plus its shard and replica counts."""

    def __init__(self, name, n_shards, n_replicas):
        # All three values are stored exactly as given.
        self.name, self.n_shards, self.n_replicas = (
            name, n_shards, n_replicas)


class ESNode(object):
    """A cluster node: its display name, ElasticSearch ID and attributes."""

    def __init__(self, name=None, esid=None, attributes=None):
        """Store the node's identity.

        `attributes` is kept by reference when supplied.  When omitted,
        each instance gets its own fresh empty dict (a shared `{}`
        default would leak mutations between unrelated nodes).
        """
        self.esid = esid
        self.name = name
        self.attributes = {} if attributes is None else attributes


class ElasticSearchCheck(NagiosCheck):
    """Nagios check that inspects an ElasticSearch cluster over HTTP.

    The check fetches the cluster's "about", health, state and local
    node stats documents, derives its own health rating from the shard
    layout, and reports the outcome by raising `Status`.
    """
    version = '1.1.0'

    def __init__(self):
        """Register command-line options and start out at 'green'."""
        NagiosCheck.__init__(self)

        # Our own view of cluster health; only ever lowered, via
        # downgrade_health().
        self.health = HEALTH['green']

        self.add_option('f', 'failure-domain', 'failure_domain',
                        "A comma-separated list of ElasticSearch "
                        "attributes that make up your cluster's "
                        "failure domain[0]. This should be the same list "
                        "of attributes that ElasticSearch's location-"
                        "aware shard allocator has been configured "
                        "with. If this option is supplied, additional "
                        "checks are carried out to ensure that primary "
                        "and replica shards are not stored in the same "
                        "failure domain. "
                        "[0]: http://www.elasticsearch.org/guide/en/"
                        "elasticsearch/reference/current/modules-cluster.html")

        self.add_option('H', 'host', 'host',
                        "Hostname or network address to probe. "
                        "The ElasticSearch API should be listening here. "
                        "Defaults to 'localhost'.")

        self.add_option('m', 'master-nodes', 'master_nodes',
                        "Issue a warning if the number of master-eligible "
                        "nodes in the cluster drops below this "
                        "number. By default, do not monitor the "
                        "number of nodes in the cluster.")

        self.add_option('p', 'port', 'port',
                        "TCP port to probe. "
                        "The ElasticSearch API should be listening "
                        "here. Defaults to 9200.")

        self.add_option('r', 'prefix', 'prefix',
                        "Optional prefix (e.g. 'es') for the "
                        "ElasticSearch API. Defaults to ''.")

    def check(self, opts, args):
        """Run the check and report the result by raising `Status`.

        `opts` carries the options registered in __init__ (host, port,
        prefix, failure_domain, master_nodes); `args` is unused.

        Raises UsageError if -m/--master-nodes is not a positive
        integer.  Otherwise always finishes by raising Status with the
        computed health, a summary message, a detail line and perfdata.
        """
        host = opts.host or "localhost"
        port = int(opts.port or '9200')
        prefix = opts.prefix or ""
        if not prefix.endswith('/'):
            prefix += '/'

        failure_domain = []

        is_failure_domain_str = isinstance(opts.failure_domain, str)
        if is_failure_domain_str and len(opts.failure_domain) > 0:
            failure_domain.extend(opts.failure_domain.split(","))

        # Parse -m once; the parsed value is reused by the assertion
        # further down.
        expected_mnodes = None
        if opts.master_nodes is not None:
            try:
                expected_mnodes = int(opts.master_nodes)
                if expected_mnodes < 1:
                    raise ValueError("'master_nodes' must be greater "
                                     "than zero")
            except ValueError:
                raise UsageError("Argument to -m/--master-nodes must "
                                 "be a natural number")
        #
        # Data retrieval
        #

        # Request "about" info, so we can figure out the ES version,
        # to allow for version-specific API changes.
        es_about = get_json(r'http://%s:%d/%s' % (host, port, prefix))
        es_version = es_about['version']['number']

        # Request cluster 'health'.  /_cluster/health is a summary of
        # /_cluster/state (see below).  We are primarily interested in
        # the cluster's 'health colour': the rating ES gives itself.
        es_health = get_json(r'http://%s:%d/%s_cluster/health' %
                             (host, port, prefix))

        their_health = HEALTH[es_health['status'].lower()]

        # Request cluster 'state'.  This lists all nodes, indexes and
        # shards in the cluster, plus a routing table detailing where
        # every shard lives at this point in time.
        es_state = get_json(r'http://%s:%d/%s_cluster/state' %
                            (host, port, prefix))

        # Request a bunch of useful numbers that we export as perfdata.
        # Details like the number of get, search, and indexing
        # operations come from here.
        es_stats = get_json(r'http://%s:%d/%s_nodes/_local/'
                            'stats?all=true' % (host, port, prefix))

        # The _local stats document is keyed by node ID; take the first
        # (presumably the only) entry as the ID of the probed node.
        myid = list(es_stats['nodes'])[0]

        n_nodes = es_health['number_of_nodes']
        n_dnodes = es_health['number_of_data_nodes']

        # Unlike n_dnodes (the number of data nodes), we must compute
        # the number of master-eligible nodes ourselves.
        n_mnodes = 0
        for esid in es_state['nodes']:
            master_elig = True

            # ES will never elect 'client' nodes as masters.
            try:
                master_elig = not (booleanise(
                    es_state['nodes'][esid]['attributes']
                    ['client']))
            except KeyError as e:
                # Only a missing 'client' attribute is expected here.
                if e.args[0] != 'client':
                    raise

            # An explicit 'master' attribute overrides the above.
            try:
                master_elig = (booleanise(
                    es_state['nodes'][esid]['attributes']
                    ['master']))
            except KeyError as e:
                if e.args[0] != 'master':
                    raise

            if master_elig:
                n_mnodes += 1

        n_active_shards = es_health['active_shards']
        n_relocating_shards = es_health['relocating_shards']
        n_initialising_shards = es_health['initializing_shards']
        n_unassigned_shards = es_health['unassigned_shards']
        n_shards = (n_active_shards + n_relocating_shards +
                    n_initialising_shards + n_unassigned_shards)

        #
        # Map construction
        #

        # String all the dumb ES* objects into a bunch of transitive
        # associations so that we may make some useful assertions about
        # them.
        esid_node_map = {}  # ESID       : <ESNode>
        index_primary_map = {}  # <ESIndex>  : { 0: <ESShard>, ... }
        name_index_map = {}  # 'bar'      : <ESIndex>
        name_node_map = {}  # 'foo'      : <ESNode>
        node_esid_map = {}  # <ESNode>   : ESID
        node_location_map = {}  # <ESNode>   : ('mars',)
        node_shard_map = {}  # <ESNode>   : [ <ESShard>, ... ]
        primary_replica_map = {}  # <ESShard>  : [ <ESShard>, ... ]
        shard_location_map = {}  # <ESShard>  : ('mars',)

        # Build node maps:
        #
        # - esid_node_map
        # - name_node_map
        # - node_esid_map
        # - node_location_map (data nodes only)
        #
        nodes = es_state['nodes']
        for n in nodes:
            name = nodes[n]['name']
            attrs = nodes[n]['attributes']
            node = ESNode(name, n, attrs)

            name_node_map[name] = node
            esid_node_map[n] = node
            node_esid_map[node] = n
            node_shard_map[node] = []

            if len(failure_domain) > 0:
                node_location_map[node] = tuple()
                try:
                    node_location_map[node] = tuple(
                        attrs[a] for a in failure_domain)
                except KeyError as e:
                    # Nodes that do not store shards (e.g.: 'client'
                    # nodes) cannot be expected to have been configured
                    # with locational attributes.
                    if 'data' not in attrs or booleanise(attrs['data']):
                        missing_attr = e.args[0]
                        raise Status('warning',
                                     ("Node '%s' missing location "
                                      "attribute '%s'" %
                                      (name, missing_attr),))

        # Build index maps:
        #
        # - name_index_map
        #
        indices = es_state['metadata']['indices']
        n_indices = len(indices)
        n_closed_indices = 0
        for i in indices:
            # Closed indices are counted but otherwise ignored.
            if indices[i]["state"] == "close":
                n_closed_indices += 1
                continue
            idx_stns = indices[i]['settings']
            # Before 1.0.0 the settings use flat dotted keys; from
            # 1.0.0 on they are nested under 'index'.
            if version(es_version) < version("1.0.0"):
                idx = ESIndex(i,
                              int(idx_stns['index.number_of_shards']),
                              int(idx_stns['index.number_of_replicas']))
            else:
                idx = ESIndex(i,
                              int(idx_stns['index']['number_of_shards']),
                              int(idx_stns['index']['number_of_replicas']))

            name_index_map[i] = idx

        # Build shard maps:
        #
        # - index_primary_map
        # - node_shard_map
        # - primary_replica_map
        #     - shard_location_map
        #
        for i in name_index_map:
            idx = name_index_map[i]

            if idx not in index_primary_map:
                # Every expected shard number starts without a primary.
                index_primary_map[idx] = dict.fromkeys(
                    range(idx.n_shards))

            idx_shards = (es_state['routing_table']['indices']
                          [i]['shards'])
            for d in idx_shards:
                primary = None
                replicas = []
                for s in idx_shards[d]:
                    shard = ESShard(SHARD_STATE[s['state'].upper()])

                    if s['primary']:
                        primary = shard
                    else:
                        replicas.append(shard)

                    # Unassigned shards live on no node, so they get
                    # no node or location association.
                    if s['state'] != 'UNASSIGNED':
                        node = esid_node_map[s['node']]

                        node_shard_map[node].append(shard)

                        if len(failure_domain) > 0:
                            loc = node_location_map[esid_node_map[s['node']]]
                            shard_location_map[shard] = loc

                index_primary_map[idx][int(d)] = primary

                if primary is not None:
                    primary_replica_map[primary] = replicas

        #
        # Perfdata
        #

        perfdata = []

        def dict2perfdata(base, metrics):
            """Append metrics looked up by dotted path inside `base`.

            Each metric is (label, path, unit) or (label, path); entries
            of any other length are skipped, as are paths that cannot
            be resolved.
            """
            for metric in metrics:
                if len(metric) == 3:
                    label, path, unit = metric
                elif len(metric) == 2:
                    label, path = metric
                    unit = ''
                else:
                    continue

                value = base
                for key in path.split("."):
                    try:
                        value = value[key]
                    except (KeyError, TypeError):
                        value = None
                        break

                if value is not None:
                    metric = PerformanceMetric(label=label,
                                               value=value,
                                               unit=unit)
                    perfdata.append(metric)

        def other2perfdata(metrics):
            """Append metrics given directly as (label, value[, unit])."""
            for metric in metrics:
                if len(metric) == 2:
                    label, value = metric
                    unit = ""
                elif len(metric) > 2:
                    label, value, unit = metric
                else:
                    continue

                if value is not None:
                    metric = PerformanceMetric(label=label,
                                               value=value,
                                               unit=unit)
                    perfdata.append(metric)

        # Add cluster-wide metrics first.  If you monitor all of your ES
        # cluster nodes with this plugin, they should all report the
        # same figures for these labels.  Not ideal, but 'tis better to
        # graph this data multiple times than not graph it at all.
        metrics = [["cluster_nodes", n_nodes],
                   ["cluster_master_eligible_nodes", n_mnodes],
                   ["cluster_data_nodes", n_dnodes],
                   ["cluster_active_shards", n_active_shards],
                   ["cluster_relocating_shards", n_relocating_shards],
                   ["cluster_initialising_shards", n_initialising_shards],
                   ["cluster_unassigned_shards", n_unassigned_shards],
                   ["cluster_total_shards", n_shards],
                   ["cluster_total_indices", n_indices],
                   ["cluster_closed_indices", n_closed_indices]]

        other2perfdata(metrics)

        # Per-node metrics.  Each label appears exactly once so that no
        # duplicate perfdata labels are emitted.
        metrics = (
            ("storesize", 'indices.store.size_in_bytes', "B"),
            ("documents", 'indices.docs.count', "c"),

            ("index_ops", 'indices.indexing.index_total', "c"),
            ("index_time", 'indices.indexing.index_time_in_millis', "ms"),

            ("flush_ops", 'indices.flush.total', "c"),
            ("flush_time", 'indices.flush.total_time_in_millis', "ms"),

            ("throttle_time", "indices.store.throttle_time_in_millis", "ms"),

            ("delete_ops", "indices.indexing.delete_total", "c"),
            ("delete_time", "indices.indexing.delete_time_in_millis", "ms"),

            ("get_ops", "indices.get.total", "c"),
            ("get_time", "indices.get.time_in_millis", "ms"),
            ("exists_ops", "indices.get.exists_total", "c"),
            ("exists_time", "indices.get.exists_time_in_millis", "ms"),
            ("missing_ops", "indices.get.missing_total", "c"),
            ("missing_time", "indices.get.missing_time_in_millis", "ms"),

            ("query_ops", 'indices.search.query_total', "c"),
            ("query_time", 'indices.search.query_time_in_millis', "ms"),
            ("fetch_ops", "indices.search.fetch_total", "c"),
            ("fetch_time", "indices.search.fetch_time_in_millis", "ms"),

            ("merge_ops", "indices.merges.total", "c"),
            ("merge_time", "indices.merges.time_in_millis", "ms"),

            ("refresh_ops", "indices.refresh.total", "c"),
            ("refresh_time", "indices.refresh.total_time_in_millis", "ms"),

            ("gc_old_count", "jvm.gc.collectors.old.collection_count", "c"),
            (
                "gc_young_count", "jvm.gc.collectors.young.collection_count",
                "c"),
            ("heap_used", "jvm.mem.heap_used_percent", "%"),
        )

        dict2perfdata(es_stats['nodes'][myid], metrics)

        #
        # Assertions
        #

        detail = []  # Collect error messages into this list

        msg = "Monitoring cluster '%s'" % es_health['cluster_name']

        # Assertion:  Each shard has one primary in STARTED or
        # RELOCATING state.
        downgraded = False

        for idx_name, idx in name_index_map.items():
            for shard_no in range(idx.n_shards):
                primary = index_primary_map[idx][shard_no]
                if primary is None:
                    downgraded |= self.downgrade_health(RED)
                    detail.append("Index '%s' missing primary on "
                                  "shard %d" % (idx_name, shard_no))
                else:
                    if primary.state not in (
                            SHARD_STATE['STARTED'],
                            SHARD_STATE['RELOCATING']):
                        downgraded |= self.downgrade_health(RED)
                        detail.append("Index '%s' primary down on "
                                      "shard %d" % (idx_name, shard_no))

        if downgraded:
            msg = ("One or more indexes are missing primary shards.  "
                   "Use -vv to list them.")

        # Assertion:  Each primary has replicas in STARTED state.
        downgraded = False

        for idx_name, idx in name_index_map.items():
            expect_replicas = idx.n_replicas

            for shard_no in range(idx.n_shards):
                primary = index_primary_map[idx][shard_no]

                if primary is None:
                    continue

                has_replicas = len(primary_replica_map[primary])

                if has_replicas < expect_replicas:
                    downgraded |= self.downgrade_health(YELLOW)
                    detail.append("Index '%s' missing replica on "
                                  "shard %d" % (idx_name, shard_no))

                for replica in primary_replica_map[primary]:
                    if replica.state not in (
                            SHARD_STATE['STARTED'],
                            SHARD_STATE['RELOCATING']):
                        downgraded |= self.downgrade_health(YELLOW)
                        detail.append("Index '%s' replica down on "
                                      "shard %d" % (idx_name, shard_no))

        if downgraded:
            msg = ("One or more indexes are missing replica shards.  "
                   "Use -vv to list them.")

        # Assertion:  You have as many master-eligible nodes in the
        # cluster as you think you ought to.
        #
        # To be of any use in detecting split-brains, this value must be
        # set to the *total* number of master-eligible nodes in the
        # cluster, not whatever you set in ElasticSearch's
        # 'discovery.zen.minimum_master_nodes' configuration parameter.
        # (See ES issue #2488.)  Of course, this will trip whenever a
        # node is taken down for maintenance, so we raise only a warning
        # -- not a critical -- status condition.
        downgraded = False

        if expected_mnodes is not None:
            if n_mnodes < expected_mnodes:
                downgraded |= self.downgrade_health(YELLOW)
                detail.append("Expected to find %d master-eligible "
                              "nodes in the cluster but only found %d" %
                              (expected_mnodes, n_mnodes))

        if downgraded:
            msg = ("Missing master-eligible nodes")

        # Assertion:  Replicas are not stored in the same failure domain
        # as their primary.
        downgraded = False

        if len(failure_domain) > 0:
            for idx_name, idx in name_index_map.items():

                # Suppress this test if the index has not been
                # configured with replicas.
                if idx.n_replicas == 0:
                    continue

                for shard_no in range(idx.n_shards):
                    loc_redundancy = set()
                    vulnerable_shards = set()

                    primary = index_primary_map[idx][shard_no]

                    if primary is None:
                        continue

                    # Shards without a recorded location (unassigned
                    # ones) cannot be judged; skip them.
                    try:
                        loc = shard_location_map[primary]
                    except KeyError:
                        continue
                    loc_redundancy.add(loc)
                    vulnerable_shards.add(primary)

                    for replica in primary_replica_map[primary]:
                        try:
                            loc = shard_location_map[replica]
                        except KeyError:
                            continue
                        loc_redundancy.add(loc)
                        vulnerable_shards.add(replica)

                    # Suppress the problem unless at least one of the
                    # vulnerable shards is on this data node.
                    my_shards = set(node_shard_map[esid_node_map[myid]])
                    if vulnerable_shards.isdisjoint(my_shards):
                        continue

                    if len(loc_redundancy) == 1:
                        downgraded |= self.downgrade_health(YELLOW)
                        loc = ",".join(list(loc_redundancy)[0])
                        detail.append("Index '%s' shard %d only exists "
                                      "in location '%s'" %
                                      (idx_name, shard_no, loc))

        if downgraded:
            msg = ("One or more index shards are not being replicated "
                   "across failure domains.  Use -vv to list them.")

        # ES detected a problem that we did not.  This should never
        # happen.  (If it does, you should work out what happened, then
        # fix this code so that we can detect the problem if it happens
        # again.)  Obviously, in this case, we cannot provide any useful
        # output to the operator.
        if their_health < self.health:
            raise Status('critical',
                         ("Cluster reports degraded health: '%s'" %
                          es_health['status'],),
                         perfdata)

        raise Status(HEALTH_MAP[self.health],
                     (msg, None, "%s %s" % (msg, " ".join(detail))),
                     perfdata)

    def downgrade_health(self, new_health):
        """Lower self.health to `new_health` if that is strictly worse.

        Returns True when the health was lowered, False otherwise.
        Health is never raised by this method.
        """
        if new_health < self.health:
            self.health = new_health
            return True
        return False


def booleanise(b):
    """Turn a 'stringified' Boolean into a real Python Boolean.

    ElasticSearch sometimes answers with the strings "true" and
    "false" where JSON `true` and `false` would be expected.  The
    argument is converted with str() and compared case-insensitively:
    "true" yields True and "false" yields False.

    Anything else raises ValueError.

    """
    lookup = {"true": True, "false": False}
    text = str(b).lower()

    if text not in lookup:
        raise ValueError("I don't know how to coerce %r to a bool" % b)

    return lookup[text]


def get_json(uri):
    """Fetch `uri` over HTTP and return the decoded JSON document.

    Failures are reported by raising Status:

    - 'unknown' if the server answers with an HTTP error;
    - 'critical' if the request cannot be made at all;
    - 'unknown' if the response body is not valid JSON.
    """
    try:
        f = urllib2.urlopen(uri)
    except urllib2.HTTPError as e:
        # HTTPError must be handled before URLError (which also has a
        # handler below) so that it gets its own status.
        raise Status('unknown', ("API failure",
                                 None,
                                 "API failure:\n\n%s" % str(e)))
    except urllib2.URLError as e:
        # The server could be down; make this CRITICAL.
        raise Status('critical', (e.reason,))

    # Always release the connection, even if reading fails.
    try:
        body = f.read()
    finally:
        f.close()

    try:
        j = json.loads(body)
    except ValueError:
        raise Status('unknown', ("API returned nonsense",))

    return j


def version(version_string):
    """Parse a dotted version string into a tuple of ints for comparison.

    Plain versions map directly: "1.0.1" -> (1, 0, 1).  Parsing stops at
    the first component that is not purely numeric, keeping any leading
    digits of that component, so qualifiers are tolerated instead of
    crashing: "1.0.0.RC1" -> (1, 0, 0), "2.0.0-beta1" -> (2, 0, 0).

    NOTE(review): a qualified pre-release therefore compares equal to
    its final release -- confirm that is acceptable for callers.

    Raises ValueError if the string does not start with a number.
    """
    parts = []
    for component in version_string.split('.'):
        digits = ''
        for ch in component:
            if not ch.isdigit():
                break
            digits += ch

        if digits:
            parts.append(int(digits))

        # A non-numeric tail (or a wholly non-numeric component) marks
        # the start of the qualifier; nothing after it is a version
        # number.
        if len(digits) != len(component) or not digits:
            break

    if not parts:
        raise ValueError("No numeric version components in %r"
                         % (version_string,))

    return tuple(parts)


# Script entry point: when executed directly (not imported), build the
# check and call its run() method.  run() is not defined in this file;
# presumably it is inherited from NagiosCheck.
if __name__ == '__main__':
    ElasticSearchCheck().run()
