#!/usr/bin/python
# -*- encoding: utf-8; py-indent-offset: 4 -*-
# +------------------------------------------------------------------+
# |             ____ _               _        __  __ _  __           |
# |            / ___| |__   ___  ___| | __   |  \/  | |/ /           |
# |           | |   | '_ \ / _ \/ __| |/ /   | |\/| | ' /            |
# |           | |___| | | |  __/ (__|   <    | |  | | . \            |
# |            \____|_| |_|\___|\___|_|\_\___|_|  |_|_|\_\           |
# |                                                                  |
# | Copyright Mathias Kettner 2014             mk@mathias-kettner.de |
# +------------------------------------------------------------------+
#
# This file is part of Check_MK.
# The official homepage is at http://mathias-kettner.de/check_mk.
#
# check_mk is free software;  you can redistribute it and/or modify it
# under the  terms of the  GNU General Public License  as published by
# the Free Software Foundation in version 2.  check_mk is  distributed
# in the hope that it will be useful, but WITHOUT ANY WARRANTY;  with-
# out even the implied warranty of  MERCHANTABILITY  or  FITNESS FOR A
# PARTICULAR PURPOSE. See the  GNU General Public License for more de-
# tails. You should have  received  a copy of the  GNU  General Public
# License along with GNU Make; see the file  COPYING.  If  not,  write
# to the Free Software Foundation, Inc., 51 Franklin St,  Fifth Floor,
# Boston, MA 02110-1301 USA.

import urllib2, sys, os, socket, base64, ssl
from httplib import HTTPConnection, HTTPSConnection

# Prefer simplejson when it is installed, otherwise fall back to the json
# module of the standard library. Without any JSON library the plugin cannot
# parse Jolokia responses, so it reports the problem and terminates.
try:
    import simplejson as json
except ImportError:
    try:
        import json
    except ImportError:
        sys.stdout.write("<<<jolokia_info>>>\n")
        sys.stdout.write("Error: Missing JSON library for Agent Plugin mk_jolokia\n")
        sys.exit()

# Command line switches:
# --verbose  trace the fetched URLs and the raw responses on stderr
# --debug    re-raise exceptions instead of reporting and swallowing them
opt_verbose = '--verbose' in sys.argv
opt_debug = '--debug' in sys.argv


class PreemptiveBasicAuthHandler(urllib2.HTTPBasicAuthHandler):
    """
    Basic authentication handler that does not wait for the server to
    ask for credentials: the auth header is attached to the first request.
    """

    def http_request(self, req):
        """Attach a Basic auth header to req if a password is known for its URL.

        The credentials are looked up in the handler's password manager
        without a realm. The request object is returned in every case.
        """
        # No realm is passed to the lookup
        username, secret = self.passwd.find_user_password(None, req.get_full_url())
        if not secret:
            return req

        credentials = "%s:%s" % (username, secret)
        header_value = 'Basic %s' % base64.b64encode(credentials).strip()
        req.add_unredirected_header(self.auth_header, header_value)
        return req

    # HTTPS requests are prepared in exactly the same way
    https_request = http_request


class HTTPSValidatingConnection(HTTPSConnection):
    """HTTPS connection that verifies the peer certificate if a CA file is given.

    host      -- host to connect to, handed to HTTPSConnection
    ca_file   -- CA certificates passed to ssl.wrap_socket as ca_certs; if it
                 is empty or None the peer certificate is not required
    key_file  -- client key file, handed to HTTPSConnection
    cert_file -- client certificate file, handed to HTTPSConnection
    """

    def __init__(self, host, ca_file, key_file, cert_file):
        HTTPSConnection.__init__(self, host, key_file=key_file, cert_file=cert_file)
        self.__cert_file = cert_file
        self.__key_file = key_file
        self.__ca_file = ca_file

    def connect(self):
        """Open the plain connection and wrap its socket with SSL."""
        HTTPConnection.connect(self)

        # The certificate of the peer is only demanded if there is
        # a CA file to check it against.
        if self.__ca_file:
            requirement = ssl.CERT_REQUIRED
        else:
            requirement = ssl.CERT_NONE

        self.sock = ssl.wrap_socket(
            self.sock,
            keyfile=self.key_file,
            certfile=self.cert_file,
            ca_certs=self.__ca_file,
            cert_reqs=requirement)


class HTTPSAuthHandler(urllib2.HTTPSHandler):
    """HTTPS handler for urllib2 that opens HTTPSValidatingConnection objects.

    ca_file -- CA certificates used to verify the server (may be None)
    key     -- client key file (may be None)
    cert    -- client certificate file (may be None)
    """

    def __init__(self, ca_file, key, cert):
        urllib2.HTTPSHandler.__init__(self)
        self.__ca_file = ca_file
        self.__key = key
        self.__cert = cert

    def https_open(self, req):
        """Open req via a connection created by getConnection."""
        # do_open expects a class as the first parameter but getConnection will act
        # as a factory function
        return self.do_open(self.getConnection, req)

    def getConnection(self, host, timeout):
        """Create the connection for host, honouring the given timeout."""
        connection = HTTPSValidatingConnection(
            host, ca_file=self.__ca_file, key_file=self.__key, cert_file=self.__cert)
        # The constructor of HTTPSValidatingConnection takes no timeout, so
        # the value handed over by do_open is set on the connection afterwards
        # instead of being dropped.
        connection.timeout = timeout
        return connection


def fetch_url_get(base_url, path, function):
    """Fetch a Jolokia response via HTTP GET and return the raw response body.

    base_url -- URL of the Jolokia agent without trailing slash
    path     -- MBean/attribute path; if empty, base_url + "/" is fetched
                and function is not used
    function -- Jolokia operation that is put in front of path (e.g. "read")

    Returns the response body as a string. If the request fails the error is
    written to stderr and an empty list is returned, unless --debug is set,
    in which case the exception is re-raised.
    """
    if path:
        url = "%s/%s/%s" % (base_url, function, path)
    else:
        url = base_url + "/"

    if opt_verbose:
        sys.stderr.write("DEBUG: Fetching: %s\n" % url)
    try:
        response = urllib2.urlopen(url)
        try:
            json_data = response.read()
        finally:
            # Release the connection even if reading the body fails
            response.close()
        if opt_verbose:
            sys.stderr.write("DEBUG: Result: %s\n\n" % json_data)
    except Exception as e:
        if opt_debug:
            raise
        sys.stderr.write("ERROR: %s\n" % e)
        return []
    return json_data


def fetch_url_post(base_url, path, service_url, service_user, service_password, function):
    """Fetch a Jolokia response via HTTP POST, addressing a target service.

    base_url         -- URL of the Jolokia agent the request is posted to
    path             -- "<mbean>/<attribute>[/<inner path>]"
    service_url      -- URL put into the "target" section of the request
    service_user     -- user for the target; user and password are only
                        sent if this is set
    service_password -- password for the target
    function         -- Jolokia operation, sent in upper case as "type"

    Returns the response body as a string. If the request fails the error is
    written to stderr and an empty list is returned, unless --debug is set,
    in which case the exception is re-raised.
    """
    segments = path.split("/")

    data = {
        "type": function.upper(),
        "mbean": segments[0],
        "target": {
            "url": service_url,
        },
    }
    # A path without attribute part must not abort with an IndexError
    if len(segments) > 1:
        data["attribute"] = segments[1]
    # Everything behind the attribute is the inner path. It may consist of
    # several segments, so none of them must be cut off.
    if len(segments) > 2:
        data["path"] = "/".join(segments[2:])

    if service_user:
        data["target"]["user"] = service_user
        data["target"]["password"] = service_password

    if opt_verbose:
        sys.stderr.write("DEBUG: Fetching: %s\n" % base_url)
    try:
        response = urllib2.urlopen(base_url, data=json.dumps(data))
        try:
            json_data = response.read()
        finally:
            # Release the connection even if reading the body fails
            response.close()
        if opt_verbose:
            sys.stderr.write("DEBUG: Result: %s\n\n" % json_data)
    except Exception as e:
        if opt_debug:
            raise
        sys.stderr.write("ERROR: %s\n" % e)
        return []
    return json_data


def fetch_var(protocol,
              server,
              port,
              path,
              suburi,
              itemspec,
              service_url,
              service_user,
              service_password,
              function="read"):
    """Query one path from a Jolokia agent and return its value as item list.

    protocol, server, port, suburi -- form the base URL
                                      "<protocol>://<server>:<port>/<suburi>"
    path             -- MBean/attribute path to query (may be empty)
    itemspec         -- passed on to make_item_list
    service_url      -- if not None the request is sent via POST for this
                        target, otherwise via GET
    service_user     -- user for the target service (POST only)
    service_password -- password for the target service (POST only)
    function         -- Jolokia operation, "read" by default

    Returns the result of make_item_list for the "value" of the response.
    An empty list is returned if the request failed, the response is not
    valid, its status is not 200 or it contains no value.
    """
    base_url = "%s://%s:%d/%s" % (protocol, server, port, suburi)

    if service_url is not None:
        json_data = fetch_url_post(base_url, path, service_url, service_user, service_password,
                                   function)
    else:
        json_data = fetch_url_get(base_url, path, function)

    # The fetch functions return a list instead of a string if the request
    # failed. They have reported the error already, so do not try to parse it.
    if isinstance(json_data, list):
        return []

    try:
        obj = json.loads(json_data)
    except Exception as e:
        sys.stderr.write('ERROR: Invalid json code (%s)\n' % e)
        sys.stderr.write('       Response %s\n' % json_data)
        return []

    # Anything but a JSON object cannot be a valid answer
    if not isinstance(obj, dict) or obj.get('status', 200) != 200:
        sys.stderr.write('ERROR: Invalid response when fetching url %s\n' % base_url)
        sys.stderr.write('       Response: %s\n' % json_data)
        return []

    # Only take the value of the object. If the value is an object
    # take the first items first value.
    # {'Catalina:host=localhost,path=\\/test,type=Manager': {'activeSessions': 0}}
    if 'value' not in obj:
        if opt_verbose:
            sys.stderr.write("ERROR: not found: %s\n" % path)
        return []
    return make_item_list((), obj['value'], itemspec)


# convert single values into lists of items in
# case value is a 1-levelled or 2-levelled dict
def make_item_list(path, value, itemspec):
    """Flatten a (nested) dict into a list of (path, value) pairs.

    path     -- tuple of the keys leading to value
    value    -- value to flatten; everything but a dict is a leaf
    itemspec -- sequence of specs for the keys of the first level: leading
                entries containing "=" are filters that all have to be part
                of a key, the remaining entries are passed to extract_item
                in order to shorten the key

    Returns a list of (path tuple, leaf value) pairs. In leaf strings the
    sequence backslash-slash is replaced by a plain slash.
    """
    if not isinstance(value, dict):
        # type(u"") is unicode on Python 2 and str on Python 3, so all
        # kinds of text are unescaped
        if isinstance(value, (str, type(u""))):
            value = value.replace(r'\/', '/')
        return [(path, value)]

    # Split the spec once for all keys. Consuming the filters while looping
    # over the keys would switch off the filtering as soon as the first key
    # has matched.
    filters = []
    remaining = itemspec
    while remaining and '=' in remaining[0]:
        filters.append(remaining[0])
        remaining = remaining[1:]

    result = []
    for key, subvalue in value.items():
        # Handle filtering via itemspec
        if any(wanted not in key for wanted in filters):
            continue
        # Without a spec extract_item has nothing to extract
        item = extract_item(key, remaining) if remaining else ()
        if not item:
            item = (key,)
        result += make_item_list(path + item, subvalue, [])
    return result


# Example:
# key = 'Catalina:host=localhost,path=\\/,type=Manager'
# itemspec = [ "path" ]
# --> ("/",)


def extract_item(key, itemspec):
    """Extract the values of the properties named in itemspec from an MBean key.

    key      -- MBean name of the form "<domain>:<name>=<value>,<name>=<value>,..."
    itemspec -- names of the properties to extract, in the wanted order

    Returns a tuple with the values of those properties that are present in
    key. Escaped slashes are unescaped and a value containing slashes is
    reduced to "/" plus the part behind its last slash.
    """
    properties = key.split(":", 1)[-1]
    comp_dict = {}
    for comp in properties.split(","):
        # Split at the first "=" only, so values that contain a "="
        # themselves are not dropped
        left, separator, right = comp.partition("=")
        if separator:
            comp_dict[left] = right

    item = ()
    for pathkey in itemspec:
        if pathkey in comp_dict:
            right = comp_dict[pathkey].replace(r'\/', '/')
            if '/' in right:
                right = '/' + right.split('/')[-1]
            item = item + (right,)
    return item


def fetch_metric(inst, path, title, itemspec, inst_add=None):
    """Query path on the instance and yield (item, title, value) per result.

    inst     -- instance dict providing the connection settings, the
                service_* settings and the instance name
    path     -- MBean/attribute path handed to fetch_var
    title    -- name of the metric; if it is empty the last element of a
                result's path is used instead, otherwise that element is
                appended to the title with a dot
    itemspec -- handed to fetch_var
    inst_add -- optional suffix for the item name, used if the path of a
                result has at most one element

    Results whose path contains "threadStatus" or "threadParam" are skipped.
    Spaces in the item name are replaced by underscores.
    """
    results = fetch_var(inst["protocol"], inst["server"], inst["port"], path, inst["suburi"],
                        itemspec, inst["service_url"], inst["service_user"],
                        inst["service_password"])

    for subinstance, value in results:
        # Without a path and without a title the metric cannot be named
        if not (subinstance or title):
            sys.stdout.write("INTERNAL ERROR: %s\n" % value)
            continue

        if any(marker in subinstance for marker in ("threadStatus", "threadParam")):
            continue

        # The item name starts with the instance, followed either by the
        # leading elements of the path or by the optional addition
        name_parts = [inst["instance"]]
        if len(subinstance) > 1:
            name_parts.extend(subinstance[:-1])
        elif inst_add is not None:
            name_parts.append(inst_add)
        item = ",".join(name_parts)

        if not title:
            tit = subinstance[-1]
        elif subinstance:
            tit = "%s.%s" % (title, subinstance[-1])
        else:
            tit = title

        yield (item.replace(" ", "_"), tit, value)


def _expand_queries(inst, mbean, path, title, itemspec, do_search, search_cache):
    """Return the list of (inst, mbean_path, title, itemspec) queries for one variable.

    Without do_search this is the single query "<mbean>/<path>". With
    do_search the MBean pattern is resolved by a Jolokia "search" first and
    one query per found MBean is created, using path as title. The result of
    each search is remembered in search_cache, so a pattern is only looked
    up once per instance.
    """
    if not do_search:
        return [(inst, mbean + "/" + path, title, itemspec)]

    if mbean not in search_cache:
        try:
            found = fetch_var(
                inst["protocol"],
                inst["server"],
                inst["port"],
                mbean,
                inst["suburi"],
                "",
                None,
                None,
                None,
                function="search")[0][1]
        except IndexError:
            # The search did not return anything
            found = []
        search_cache[mbean] = found

    return [(inst, "%s/%s" % (urllib2.quote(mbean_exp), path), path, itemspec)
            for mbean_exp in search_cache[mbean]]


def _write_metrics(queries, format_line):
    """Fetch the metrics of all queries and write one line per metric to stdout.

    format_line converts an (item, title, value) tuple into the line to
    write. Returns False if an IOError or a socket timeout occurred, which
    means that the rest of the instance should be skipped, and True otherwise.
    """
    for query in queries:
        try:
            for metric in fetch_metric(*query):
                sys.stdout.write(format_line(metric))
        except (IOError, socket.timeout):
            return False
        except Exception:
            # Only real errors are ignored here. KeyboardInterrupt and
            # SystemExit are no longer swallowed.
            if opt_debug:
                raise
            # Simply ignore exceptions. Need to be removed for debugging
            continue
    return True


def query_instance(inst):
    """Query one Jolokia instance and write its agent sections to stdout.

    inst is a dict with the complete configuration of the instance
    (protocol, server, port, user, password, mode, suburi, instance,
    cert_path, client_cert, client_key, service_*, product).

    The section <<<jolokia_info>>> is followed by <<<jolokia_metrics>>> with
    the global and the product specific variables and, if custom_vars is not
    empty, by <<<jolokia_generic:sep(0)>>>. If the server information cannot
    be fetched both of the first two sections contain an ERROR line only.
    """
    # Prepare user/password authentication via HTTP Auth
    password_mngr = urllib2.HTTPPasswordMgrWithDefaultRealm()
    if inst.get("password"):
        password_mngr.add_password(None,
                                   "%s://%s:%d/" % (inst["protocol"], inst["server"], inst["port"]),
                                   inst["user"], inst["password"])

    handlers = []
    if inst["protocol"] == "https":
        if inst["mode"] == 'https' and (inst["client_key"] is None or inst["client_cert"] is None):
            sys.stdout.write('<<<jolokia_info>>>\n')
            sys.stderr.write("ERROR: https set up as authentication method but certificate "
                             "wasn't provided\n")
            return
        handlers.append(
            HTTPSAuthHandler(inst["cert_path"], inst["client_key"], inst["client_cert"]))
    # NOTE(review): mode "basic" adds no authentication handler if the
    # protocol is https - confirm that this is intended.
    if inst["mode"] == 'digest':
        handlers.append(urllib2.HTTPDigestAuthHandler(password_mngr))
    elif inst["mode"] == "basic_preemptive":
        handlers.append(PreemptiveBasicAuthHandler(password_mngr))
    elif inst["mode"] == "basic" and inst["protocol"] != "https":
        handlers.append(urllib2.HTTPBasicAuthHandler(password_mngr))

    if handlers:
        opener = urllib2.build_opener(*handlers)
        urllib2.install_opener(opener)

    # Determine type of server
    server_info = fetch_var(inst["protocol"], inst["server"], inst["port"], "", inst["suburi"], "",
                            None, None, None)

    sys.stdout.write('<<<jolokia_info>>>\n')
    if server_info:
        d = dict(server_info)
        version = d.get(('info', 'version'), "unknown")
        product = d.get(('info', 'product'), "unknown")
        # A configured product has precedence over the detected one
        if inst.get("product"):
            product = inst["product"]
        agentversion = d.get(('agent',), "unknown")
        sys.stdout.write("%s %s %s %s\n" % (inst["instance"], product, version, agentversion))
    else:
        sys.stdout.write("%s ERROR\n" % (inst["instance"],))
        sys.stdout.write('<<<jolokia_metrics>>>\n')
        sys.stdout.write("%s ERROR\n" % (inst["instance"],))
        return

    # Results of MBean searches, shared by the global and the custom variables
    mbean_search_results = {}

    sys.stdout.write('<<<jolokia_metrics>>>\n')
    # Fetch the general information first
    for var in global_vars + specific_vars.get(product, []):
        # support old and new configuration format so we stay compatible with older
        # configuration files
        if len(var) == 3:
            path, title, itemspec = var
            mbean, path = path.split("/", 1)
            do_search = False
        else:
            mbean, path, title, itemspec, do_search = var

        queries = _expand_queries(inst, mbean, path, title, itemspec, do_search,
                                  mbean_search_results)
        if not _write_metrics(queries, lambda metric: "%s %s %s\n" % metric):
            return

    if custom_vars:
        sys.stdout.write('<<<jolokia_generic:sep(0)>>>\n')
    for var in custom_vars:
        mbean, path, title, itemspec, do_search, value_type = var
        queries = _expand_queries(inst, mbean, path, title, itemspec, do_search,
                                  mbean_search_results)
        # The fields of the generic section are separated by NUL bytes and
        # the value type is appended as last field
        format_generic = lambda metric, value_type=value_type: (
            chr(0).join(str(i) for i in metric + (value_type,)) + "\n")
        if not _write_metrics(queries, format_generic):
            return


# Default configuration for all instances
# Each of these values is used for every instance that does not define a
# value of its own. They can be changed in the configuration file jolokia.cfg.

# protocol, server, port and suburi form the URL of the Jolokia agent:
# <protocol>://<server>:<port>/<suburi>
# If server is set to "use fqdn" the fully qualified name of this host is used.
protocol = "http"
server = "localhost"
port = 8080
# Credentials for the HTTP authentication. They are only used if a password is set.
user = "monitoring"
password = None
# Authentication method. The code distinguishes "digest", "basic",
# "basic_preemptive" and "https" (requires client_key and client_cert).
mode = "digest"
suburi = "jolokia"
# Name of the instance in the output. If it is not set the port is used.
instance = None
# CA certificates for verifying the server when using https. Without them
# the certificate of the server is not checked.
cert_path = None
# Client certificate and key for https connections
client_cert = None
client_key = None
# If service_url is set, the metrics are requested via POST with this URL
# (and optionally user and password) as "target" of the request.
service_url = None
service_user = None
service_password = None
# If set, this product name is used instead of the one reported by the agent
product = None

# Variables that are queried on every instance, independent of the product.
# Each entry is a tuple (mbean, path, title, itemspec, do_search).
_EHCACHE_STATISTICS = "net.sf.ehcache:CacheManager=CacheManagerApplication*,*,type=CacheStatistics"

global_vars = [
    ("java.lang:type=Memory", "NonHeapMemoryUsage/used", "NonHeapMemoryUsage", [], False),
    ("java.lang:type=Memory", "NonHeapMemoryUsage/max", "NonHeapMemoryMax", [], False),
    ("java.lang:type=Memory", "HeapMemoryUsage/used", "HeapMemoryUsage", [], False),
    ("java.lang:type=Memory", "HeapMemoryUsage/max", "HeapMemoryMax", [], False),
    ("java.lang:type=Threading", "ThreadCount", "ThreadCount", [], False),
    ("java.lang:type=Threading", "DaemonThreadCount", "DeamonThreadCount", [], False),
    ("java.lang:type=Threading", "PeakThreadCount", "PeakThreadCount", [], False),
    ("java.lang:type=Threading", "TotalStartedThreadCount", "TotalStartedThreadCount", [], False),
    ("java.lang:type=Runtime", "Uptime", "Uptime", [], False),
    ("java.lang:type=GarbageCollector,name=*", "CollectionCount", "", [], False),
    ("java.lang:type=GarbageCollector,name=*", "CollectionTime", "", [], False),
    ("java.lang:name=CMS%20Perm%20Gen,type=MemoryPool", "Usage/used", "PermGenUsage", [], False),
    ("java.lang:name=CMS%20Perm%20Gen,type=MemoryPool", "Usage/max", "PermGenMax", [], False),
]

# All ehcache statistics use the same MBean pattern, an empty title and are
# resolved via search. They only differ in the attribute.
global_vars.extend((_EHCACHE_STATISTICS, attribute, "", [], True) for attribute in (
    "OffHeapHits",
    "OnDiskHits",
    "InMemoryHitPercentage",
    "CacheMisses",
    "OnDiskHitPercentage",
    "MemoryStoreObjectCount",
    "DiskStoreObjectCount",
    "CacheMissPercentage",
    "CacheHitPercentage",
    "OffHeapHitPercentage",
    "InMemoryMisses",
    "OffHeapStoreObjectCount",
    "WriterQueueLength",
    "WriterMaxQueueSize",
    "OffHeapMisses",
    "InMemoryHits",
    "AssociatedCacheName",
    "ObjectCount",
    "OnDiskMisses",
    "CacheHits",
))

# Variables that are only queried if the instance runs the product given
# as key. The entries have the same layout as those of global_vars:
# (mbean, path, title, itemspec, do_search)
specific_vars = dict(
    weblogic=[
        ("*:*", "CompletedRequestCount", None, ["ServerRuntime"], False),
        ("*:*", "QueueLength", None, ["ServerRuntime"], False),
        ("*:*", "StandbyThreadCount", None, ["ServerRuntime"], False),
        ("*:*", "PendingUserRequestCount", None, ["ServerRuntime"], False),
        ("*:Name=ThreadPoolRuntime,*", "ExecuteThreadTotalCount", None, ["ServerRuntime"], False),
        ("*:*", "ExecuteThreadIdleCount", None, ["ServerRuntime"], False),
        ("*:*", "HoggingThreadCount", None, ["ServerRuntime"], False),
        ("*:Type=WebAppComponentRuntime,*", "OpenSessionsCurrentCount", None,
         ["ServerRuntime", "ApplicationRuntime"], False),
    ],
    tomcat=[
        ("*:type=Manager,*", "activeSessions,maxActiveSessions", None, ["path", "context"], False),
        ("*:j2eeType=Servlet,name=default,*", "stateName", None, ["WebModule"], False),
        # Check not yet working
        ("*:j2eeType=Servlet,name=default,*", "requestCount", None, ["WebModule"], False),
        ("*:name=*,type=ThreadPool", "maxThreads", None, [], False),
        ("*:name=*,type=ThreadPool", "currentThreadCount", None, [], False),
        ("*:name=*,type=ThreadPool", "currentThreadsBusy", None, [], False),
        # too wide location for addressing the right info
        # ( "*:j2eeType=Servlet,*", "requestCount", None, [ "WebModule" ] , False),
    ],
    jboss=[
        ("*:type=Manager,*", "activeSessions,maxActiveSessions", None, ["path", "context"], False),
    ],
)

#   ( '*:j2eeType=WebModule,name=/--/localhost/-/%(app)s,*/state', None, [ "name" ]),
#   ( '*:j2eeType=Servlet,WebModule=/--/localhost/-/%(app)s,name=%(servlet)s,*/requestCount', None, [ "WebModule", "name" ]),
#   ( "Catalina:J2EEApplication=none,J2EEServer=none,WebModule=*,j2eeType=Servlet,name=*", None, [ "WebModule", "name" ]),

# List of instances to monitor. Each instance is a dict where
# the global configuration values can be overridden.
instances = [{}]

# Additional variables defined by the user. Each entry is a tuple
# (mbean, path, title, itemspec, do_search, value_type).
custom_vars = []

conffile = os.getenv("MK_CONFDIR", "/etc/check_mk") + "/jolokia.cfg"

# The configuration file is executed as Python code in the namespace of
# this script, so it can redefine all of the settings above.
# NOTE(review): because it is executed, the file must only be writable by
# trusted users.
if os.path.exists(conffile):
    execfile(conffile)

if server == "use fqdn":
    server = socket.getfqdn()

if instance is None:
    instance = str(port)

# We have to deal with socket timeouts. urlopen is called without an
# explicit timeout, so all requests use the default socket timeout.
# We are local here so set it to 1 second.
socket.setdefaulttimeout(1.0)

# Compute list of instances to monitor. If the user has defined
# instances in his configuration, we will use this (a list
# of dicts).
for inst in instances:
    # Fill in the global value for everything the instance does not define
    for varname, value in [
        ("protocol", protocol),
        ("server", server),
        ("port", port),
        ("user", user),
        ("password", password),
        ("mode", mode),
        ("suburi", suburi),
        ("instance", instance),
        ("cert_path", cert_path),
        ("client_cert", client_cert),
        ("client_key", client_key),
        ("service_url", service_url),
        ("service_user", service_user),
        ("service_password", service_password),
        ("product", product),
    ]:
        inst.setdefault(varname, value)
    # NOTE(review): the global instance name is never empty at this point,
    # so this fallback to the port of the instance only applies if the
    # instance itself defines an empty name.
    if not inst["instance"]:
        inst["instance"] = str(inst["port"])
    inst["instance"] = inst["instance"].replace(" ", "_")

    if inst.get("server") == "use fqdn":
        inst["server"] = socket.getfqdn()

    query_instance(inst)