#!/usr/bin/python
"""Monitor MongoDB on Linux

This agent plugin creates various sections out of the MongoDB server status information.
Important: 1) If MongoDB runs as a single instance the agent data is assigned
              to the same host where the plugin resides.

           2) If MongoDB is deployed as a replica set the agent data is piggybacked
              to a different hostname, named after the replica set.
              You have to create a new host in the monitoring system matching the
              replica set name, or use the piggyback translation rule to modify the
              hostname according to your needs.
"""
import os
import pprint
import sys
import time
import inspect
from contextlib import contextmanager

import pymongo  # pylint: disable=import-error

# Directory that holds the logwatch state file ("mongodb.state"), read from the
# environment. None when MK_VARDIR is unset, in which case section_logwatch()
# writes nothing.
MK_VARDIR = os.getenv("MK_VARDIR")


@contextmanager
def piggyback(info):
    """Surround the output of the ``with`` body with piggyback markers.

    If ``info`` is a non-empty mapping with a truthy "setName" entry, a
    ``<<<<setName>>>>`` header is written before the body runs and the closing
    ``<<<<>>>>`` marker is written afterwards, even when the body raises.
    In every other case the body runs without any extra output.
    """
    target_host = None
    if info:
        target_host = info.get("setName")

    if not target_host:
        # Nothing to redirect: the body's output stays with the local host.
        yield
        return

    sys.stdout.write("<<<<%s>>>>\n" % target_host)
    try:
        yield
    finally:
        sys.stdout.write("<<<<>>>>\n")


def _call_first_available(obj, method_names):
    """Call the first name in ``method_names`` that is a real bound method of ``obj``.

    ``inspect.ismethod`` is used instead of ``hasattr`` because pymongo clients
    and databases answer unknown attribute names with database/collection
    objects rather than raising AttributeError.

    Returns the result of that call, or an empty list if none of the names is
    a method.
    """
    for method_name in method_names:
        candidate = getattr(obj, method_name, None)
        if inspect.ismethod(candidate):
            return candidate()
    return []


def get_database_info(client):
    """Collect collection names and statistics for every database of ``client``.

    Both the current pymongo API (``list_database_names`` /
    ``list_collection_names``) and the legacy one (``database_names`` /
    ``collection_names``, removed in PyMongo 4) are supported.

    Returns a dict mapping each database name to a dict with the keys
      "collections": the collection names of that database,
      "stats":       the result of the "dbstats" command,
      "collstats":   a dict mapping collection name to its "collstats" result.
    """
    databases = {}
    for name in _call_first_available(client, ("list_database_names", "database_names")):
        database = client[name]
        collections = _call_first_available(
            database, ("list_collection_names", "collection_names"))
        databases[name] = {
            "collections": collections,
            "stats": database.command("dbstats"),
            "collstats": {
                collection: database.command("collstats", collection)
                for collection in collections
            },
        }
    return databases


def section_instance(server_status):
    """Write the mongodb_instance section: version, pid and replication mode.

    The mode is "Single Instance" when ``server_status`` has no (or an empty)
    "repl" entry, "Primary" when "repl" reports "ismaster", "Secondary" when
    it reports "secondary", and "Arbiter" otherwise.
    """
    emit = sys.stdout.write
    emit("<<<mongodb_instance:sep(9)>>>\n")
    emit("version\t%s\n" % server_status.get("version", "n/a"))
    emit("pid\t%s\n" % server_status.get("pid", "n/a"))

    replication = server_status.get("repl")
    is_arbiter = False
    if not replication:
        emit("mode\tSingle Instance\n")
    elif replication.get("ismaster"):
        emit("mode\tPrimary\n")
    elif replication.get("secondary"):
        emit("mode\tSecondary\n")
    else:
        is_arbiter = True
        emit("mode\tArbiter\n")

    # NOTE(review): the address line is only written in the Arbiter case;
    # confirm that omitting it for the other modes is intended.
    if is_arbiter:
        own_address = replication.get("me")
        if own_address:
            emit("address\t%s\n" % own_address)


def section_flushing(server_status):
    """Write the mongodb_flushing section from the "backgroundFlushing" entry.

    Nothing is written when ``server_status`` has no such entry. Individual
    values missing from the entry are reported as "n/a".
    """
    # The "backgroundFlushing" key is deprecated for MongoDB 4.0, so it may be absent.
    flush_stats = server_status.get("backgroundFlushing")
    if flush_stats is not None:
        sys.stdout.write("<<<mongodb_flushing>>>\n")
        # (output line template, key read from the flushing statistics)
        for template, stat_key in (
            ("average_ms %s\n", "average_ms"),
            ("last_ms %s\n", "last_ms"),
            ("flushed %s\n", "flushes"),
        ):
            sys.stdout.write(template % flush_stats.get(stat_key, "n/a"))


def sections_replica(client, server_status):
    """Write the mongodb_replica and mongodb_replstatus sections.

    Nothing is written when ``server_status`` has no (or an empty) "repl"
    entry. Otherwise the primary plus any hosts and arbiters listed in that
    entry are reported, followed by the pretty-printed result of the
    "replSetGetStatus" admin command issued through ``client``.
    """
    replication = server_status.get("repl")
    if not replication:
        return

    emit = sys.stdout.write
    emit("<<<mongodb_replica:sep(9)>>>\n")
    emit("primary\t%s\n" % replication.get("primary", "n/a"))

    members = replication.get("hosts")
    if members:
        emit("hosts\t%s\n" % " ".join(members))

    arbiters = replication.get("arbiters")
    if arbiters:
        emit("arbiters\t%s\n" % " ".join(arbiters))

    emit("<<<mongodb_replstatus>>>\n")
    replica_status = client.admin.command("replSetGetStatus")
    emit(pprint.pformat(replica_status))
    emit("\n")


def _count_chunks(col, query):
    """Return the number of documents in ``col`` that match ``query``.

    ``count_documents`` is preferred; ``Cursor.count()`` is only used as a
    fallback for old pymongo versions, because it was removed in PyMongo 4.
    ``inspect.ismethod`` is used for the probe because pymongo collections
    answer unknown attribute names with sub-collection objects.
    """
    if inspect.ismethod(getattr(col, "count_documents", None)):
        return col.count_documents(query)
    return col.find(query).count()


def section_chunks(client, databases):
    """Write the mongodb_chunks section from the ``config.chunks`` collection.

    For every database in ``databases`` (as returned by get_database_info) a
    "shardcount" line is written, followed, for each of its collections, by
    an "nscount" line with the number of chunks of that namespace and one
    "shardmatches" line per shard with the chunks of the namespace on it.
    """
    sys.stdout.write("<<<mongodb_chunks>>>\n")
    col = client.config.chunks
    if not databases:
        return

    # The list of shards does not depend on the database, so query it once.
    shards = col.distinct("shard")
    for db_name, db_data in databases.items():
        sys.stdout.write("shardcount %d\n" % len(shards))
        for collection in db_data.get("collections") or ():
            nsfilter = "%s.%s" % (db_name, collection)
            sys.stdout.write("nscount %s %s\n" % (nsfilter, _count_chunks(col, {"ns": nsfilter})))
            for shard in shards:
                matches = _count_chunks(col, {"ns": nsfilter, "shard": shard})
                sys.stdout.write("shardmatches %s#%s %s\n" % (nsfilter, shard, matches))


def section_locks(server_status):
    """Write the mongodb_locks section from the "globalLock" entry.

    The section header is always written. For each of the "activeClients"
    and "currentQueue" sub-entries that is present, one line per contained
    key/value pair follows.
    """
    sys.stdout.write("<<<mongodb_locks>>>\n")
    lock_stats = server_status.get("globalLock")
    if not lock_stats:
        return

    for category in ("activeClients", "currentQueue"):
        if category not in lock_stats:
            continue
        for name, count in lock_stats[category].items():
            sys.stdout.write("%s %s %s\n" % (category, name, count))


def section_by_keys(section_name, keys, server_status, output_key=False):
    """Write a ``mongodb_<section_name>`` section from several status entries.

    For every key in ``keys`` the mapping found under that key in
    ``server_status`` is written as one "name value" line per item; keys
    missing from ``server_status`` contribute nothing. With ``output_key``
    set, each line is prefixed with the key it came from.
    """
    sys.stdout.write("<<<mongodb_%s>>>\n" % section_name)
    for key in keys:
        if output_key:
            line_format = ("%s " % key) + "%s %s\n"
        else:
            line_format = "%s %s\n"

        entries = server_status.get(key, {})
        for name, value in entries.items():
            sys.stdout.write(line_format % (name, value))


def section_collections(databases):
    """Write the mongodb_collections section with size-related statistics.

    ``databases`` is the mapping built by get_database_info. For every
    collection listed under a database's "collstats" entry, each statistic
    whose name contains "size" (case-insensitively) is written as a
    tab-separated "database, collection, name, value" line.
    """
    sys.stdout.write("<<<mongodb_collections:sep(9)>>>\n")
    for dbname, dbdata in databases.items():
        per_collection = dbdata.get("collstats", {})
        for collname, stats in per_collection.items():
            for stat_name, stat_value in stats.items():
                if "size" not in stat_name.lower():
                    continue
                sys.stdout.write("%s\t%s\t%s\t%s\n" % (dbname, collname, stat_name, stat_value))


def get_timestamp(text):
    """Parse the timestamp at the start of ``text`` into seconds since the epoch.

    Everything from the first '.' onwards (fractional seconds and whatever
    follows) is ignored. Accepted layouts are 'Tue Nov  6 13:44:09.345' and
    '2015-10-17T05:35:24.234'. The first layout carries no year information.

    Returns the ``time.mktime`` value for the first layout that matches, or
    None when neither matches.
    """
    head = text.partition('.')[0]
    for layout in ("%a %b %d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
        try:
            return time.mktime(time.strptime(head, layout))
        except ValueError:
            continue
    return None


def read_statefile(state_file):
    """Read the timestamp of the last reported log line from ``state_file``.

    Returns a ``(last_timestamp, output_all)`` tuple. ``output_all`` tells
    the caller to report every log line regardless of its timestamp. It is
    True, with ``last_timestamp`` None, when the file cannot be read or does
    not contain an integer.
    """
    try:
        with open(state_file) as handle:
            stored = int(handle.read())
    except (IOError, ValueError):
        return None, True

    if time.localtime(stored).tm_year < 2015:
        # The stored timestamp came from a log line without year information.
        # As a workaround the year of the state file's ctime is compared with
        # the current year; if they differ, everything is reported again.
        year_of_statefile = time.localtime(os.stat(state_file).st_ctime).tm_year
        return stored, year_of_statefile != time.localtime().tm_year

    return stored, False


def update_statefile(state_file, startup_warnings):
    """Store the timestamp of the newest startup warning in ``state_file``.

    The newest warning is the last entry of ``startup_warnings["log"]``.
    Nothing is done when that list is missing or empty. Failures to write
    the file are ignored.
    """
    log_lines = startup_warnings.get("log")
    if log_lines:
        newest = get_timestamp(log_lines[-1])
        try:
            with open(state_file, 'w') as handle:
                # Formatting happens only after the file was opened for
                # writing: when no timestamp could be parsed (None) this
                # raises TypeError, but at least the file's ctime is updated.
                handle.write("%d" % newest)
        except (IOError, TypeError):
            pass


def section_logwatch(client):
    """Write MongoDB's startup warnings as a logwatch section.

    The warnings are fetched with the "getLog" admin command. Only lines
    newer than the timestamp remembered in ``<MK_VARDIR>/mongodb.state`` are
    written, unless read_statefile() asks for all of them. Afterwards the
    state file is updated. Nothing is written when MK_VARDIR is not set.

    Each line is prefixed with a state: "W" if it contains "** WARNING:",
    "." if it has no text after the "]" of its prefix or that text starts
    with "**  ", and "C" otherwise.
    """
    if not MK_VARDIR:
        return

    sys.stdout.write("<<<logwatch>>>\n")
    sys.stdout.write("[[[MongoDB startupWarnings]]]\n")
    startup_warnings = client.admin.command({"getLog": "startupWarnings"})

    state_file = "%s/mongodb.state" % MK_VARDIR

    last_timestamp, output_all = read_statefile(state_file)

    # A reply without "log" is treated like an empty log, which matches how
    # update_statefile() reads the same reply.
    for line in startup_warnings.get("log") or []:
        state = "C"
        state_index = line.find("]") + 2
        if len(line) == state_index or line[state_index:].startswith("**  "):
            state = "."

        if "** WARNING:" in line:
            state = "W"

        if not output_all:
            timestamp = get_timestamp(line)
            # Lines without a parsable timestamp are skipped. Comparing None
            # with a number raises TypeError on Python 3, while Python 2
            # evaluated it to False, so this keeps the Python 2 result.
            if timestamp is None or not timestamp > last_timestamp:
                continue

        sys.stdout.write("%s %s\n" % (state, line))

    update_statefile(state_file, startup_warnings)


def main():
    """Query the local MongoDB instance and write all agent sections.

    When the server cannot be reached, only an mongodb_instance section with
    an error line is written. The instance section is always written for the
    local host. All remaining sections are only produced when the instance
    is not part of a replica set or reports itself as "ismaster"; in a
    replica set they are wrapped by piggyback().
    """
    mongo = pymongo.MongoClient(read_preference=pymongo.ReadPreference.SECONDARY)
    try:
        # The client connects lazily, so a connection failure only shows up
        # with the first command.
        status = mongo.admin.command("serverStatus")
    except pymongo.errors.ConnectionFailure:
        sys.stdout.write("<<<mongodb_instance:sep(9)>>>\n")
        sys.stdout.write("error\tInstance is down\n")
        return

    section_instance(status)

    replication = status.get("repl")
    if not replication or replication.get("ismaster"):
        with piggyback(replication):
            potentially_piggybacked_sections(mongo, status)


def potentially_piggybacked_sections(client, server_status):
    """Write all sections that main() emits inside the piggyback() context.

    ``client`` is used by the sections that need further queries (replica
    set status, database and chunk information, startup warnings).
    ``server_status`` is the "serverStatus" result that main() already
    fetched and feeds the remaining sections.
    """
    sections_replica(client, server_status)
    section_by_keys("asserts", ("asserts",), server_status)
    section_by_keys("connections", ("connections",), server_status)
    # The database information is collected once and shared by the chunks
    # and collections sections below.
    databases = get_database_info(client)
    section_chunks(client, databases)
    section_locks(server_status)
    section_flushing(server_status)
    section_by_keys("mem", ("mem", "extra_info"), server_status)
    # output_key=True prefixes each counter line with the key it came from.
    section_by_keys("counters", ("opcounters", "opcountersRepl"), server_status, output_key=True)
    section_collections(databases)
    section_logwatch(client)


# Run the plugin only when executed as a script. main() returns None on every
# path, so the exit status is 0 unless an exception escapes.
if __name__ == "__main__":
    sys.exit(main())
