#!/usr/bin/python3
from optparse import OptionParser
import shlex
import subprocess
import sys
import requests
import json
import re

if not hasattr(subprocess, "check_output"):
    # Compatibility shim: install a minimal check_output on interpreters whose
    # subprocess module does not ship one.
    def f(*popenargs, **kwargs):
        """Run a command and return what it wrote to stdout.

        Positional and keyword arguments are forwarded to ``subprocess.Popen``.
        Raises ``ValueError`` if the caller tries to supply ``stdout`` and
        ``subprocess.CalledProcessError`` if the command exits non-zero.
        """
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')

        process = subprocess.Popen(*popenargs, stdout=subprocess.PIPE, **kwargs)
        output = process.communicate()[0]
        retcode = process.poll()
        if not retcode:
            return output

        # Report the command the same way Popen received it: the ``args``
        # keyword when given, otherwise the first positional argument.
        cmd = kwargs.get("args")
        if cmd is None:
            cmd = popenargs[0]
        raise subprocess.CalledProcessError(retcode, cmd)

    subprocess.check_output = f


class RabbitCmdWrapper(object):
    """Run ``rabbitmqctl`` commands and return their output as parsed rows.

    Every command is invoked through ``sudo``, so this typically needs root
    privileges to work.  It is its own class so that it can be reused by
    other monitoring tools if desired.
    """

    # Banner / column-header lines that are not data rows.
    _FLUFF_LINES = frozenset((
        "Listing connections ...",
        "Listing queues ...",
        "name\tmessages",
    ))

    @classmethod
    def _run(cls, command):
        """Run *command* (a shell-style string) and return its parsed rows.

        Propagates whatever ``subprocess.check_output`` raises (for example
        ``subprocess.CalledProcessError`` on a non-zero exit status) and
        ``UnicodeDecodeError`` if the output is not valid UTF-8.
        """
        raw_output = subprocess.check_output(shlex.split(command))
        return cls._parse_list_results(raw_output.strip().decode('utf-8'))

    @classmethod
    def list_connections(cls):
        """Return one row (list of column strings) per listed connection."""
        return cls._run("sudo rabbitmqctl list_connections")

    @classmethod
    def list_queues(cls):
        """Return one row (list of column strings) per listed queue."""
        return cls._run('sudo rabbitmqctl list_queues -q')

    @classmethod
    def status(cls):
        """Return the lines of ``rabbitmqctl status`` as tab-split rows."""
        return cls._run('sudo rabbitmqctl status')

    @classmethod
    def _parse_list_results(cls, result_string):
        """Split tab-separated command output into a list of rows.

        Known banner/header lines and blank lines are dropped, so empty
        output yields ``[]`` rather than a single bogus ``['']`` row (which
        would otherwise be counted as a connection by callers using ``len``).
        """
        rows = []
        for line in result_string.strip().split('\n'):
            if not line.strip() or line in cls._FLUFF_LINES:
                continue
            rows.append(line.split('\t'))
        return rows


def check_connection_count(critical=0, warning=0):
    """Check that the number of RabbitMQ connections is within thresholds.

    Prints one Nagios-style status line (status text, then ``|``, then
    perfdata) and terminates the process via ``sys.exit``:

    * 2 (CRITICAL) when the count is >= ``critical`` or the connection
      listing fails,
    * 1 (WARNING) when the count is >= ``warning``,
    * 0 (OK) otherwise.
    """
    try:
        count = len(RabbitCmdWrapper.list_connections())
    except Exception as err:
        # A failed check must not look like success to the monitoring system.
        print("CRITICAL - %s" % err)
        sys.exit(2)

    if count >= critical:
        label, exit_code = "CRITICAL", 2
    elif count >= warning:
        label, exit_code = "WARNING", 1
    else:
        label, exit_code = "OK", 0

    # Perfdata fields: value;warn;crit;min;max (max is 1.5x the critical level).
    perfdata = "connection_count=%d;%d;%d;%d;%d" % (
        count, warning, critical, 0, critical * 1.5)
    # Emit status and perfdata on ONE line so the perfdata is reported for
    # every state, not only for OK.
    print("%s - Connection Count %d | %s" % (label, count, perfdata))
    sys.exit(exit_code)


def check_queues_count(critical=1000, warning=1000, skip=0):
    """Check that every queue's message count is within thresholds.

    Queues whose count is below ``skip`` are ignored entirely (no alert and
    no perfdata).  Rows that do not have exactly two columns are ignored.

    Prints one Nagios-style line (status text, ``|``, space-separated
    perfdata) and terminates the process via ``sys.exit``:

    * 2 (CRITICAL) when any queue count is >= ``critical`` or the listing /
      parsing fails,
    * 1 (WARNING) when any queue count is >= ``warning``,
    * 0 (OK) otherwise.

    TODO: Possibly break this out so test can be done on individual queues.
    """
    try:
        critical_q = []
        warning_q = []
        perf_q = []
        for queue in RabbitCmdWrapper.list_queues():
            if len(queue) != 2:
                continue
            name, count = queue[0], int(queue[1])
            if count < skip:
                continue
            if count >= critical:
                critical_q.append("%s: %s" % (name, count))
            elif count >= warning:
                warning_q.append("%s: %s" % (name, count))
            perf_q.append("'%s'=%s" % (name, count))
    except Exception as err:
        print("CRITICAL - %s" % err)
        sys.exit(2)

    if critical_q:
        summary, exit_code = "CRITICAL - %s" % ", ".join(critical_q), 2
    elif warning_q:
        summary, exit_code = "WARNING - %s" % ", ".join(warning_q), 1
    else:
        summary, exit_code = "OK - NO QUEUES EXCEED THRESHOLDS", 0

    # Status and perfdata go on a single line; perfdata items are separated
    # by spaces, as the Nagios plugin output format expects.
    print("%s | %s" % (summary, " ".join(perf_q)))
    sys.exit(exit_code)

def check_mem_usage(critical=75, warning=50):
    """Check RabbitMQ memory usage as a percentage of its memory limit.

    Scans the ``rabbitmqctl status`` output for ``rss,<n>`` and
    ``vm_memory_limit,<n>`` (the last occurrence of each wins) and compares
    ``100 * rss / vm_memory_limit`` against the thresholds.  Both values are
    divided by 1024 and labelled KB in the output, so they are presumably
    reported in bytes -- TODO confirm against the rabbitmqctl version in use.

    Prints one Nagios-style line with perfdata and terminates the process via
    ``sys.exit``:

    * 2 (CRITICAL) when usage is > ``critical`` percent, or when the status
      cannot be obtained or parsed,
    * 1 (WARNING) when usage is > ``warning`` percent,
    * 0 (OK) otherwise.
    """
    rss_pattern = re.compile(r'rss,([0-9]+)')
    limit_pattern = re.compile(r'vm_memory_limit,([0-9]+)')
    used_raw = None
    limit_raw = None
    try:
        for line in RabbitCmdWrapper.status():
            text = str(line)
            rss = rss_pattern.search(text)
            vml = limit_pattern.search(text)
            if rss:
                used_raw = int(rss.group(1))
            if vml:
                limit_raw = int(vml.group(1))

        # Fail with a readable message instead of an unbound-variable or
        # division-by-zero error when the expected fields are absent/unusable.
        if used_raw is None or limit_raw is None:
            raise ValueError(
                "could not find rss / vm_memory_limit in rabbitmqctl status output")
        if limit_raw == 0:
            raise ValueError("vm_memory_limit is reported as 0")

        memory_used = used_raw // 1024
        memory_limit = limit_raw // 1024
        percent_usage = 100 * used_raw // limit_raw
    except Exception as err:
        print("CRITICAL - %s" % err)
        sys.exit(2)

    if percent_usage > critical:
        label, exit_code = "CRITICAL", 2
    elif percent_usage > warning:
        label, exit_code = "WARNING", 1
    else:
        label, exit_code = "OK", 0

    print("%s - RABBITMQ RAM USAGE at %i%% of max|TOTAL=%iKB;;;; USED=%iKB;;;; mem_usage=%s%%;%s;%s;" % (
        label, percent_usage, memory_limit, memory_used, percent_usage, warning, critical))
    sys.exit(exit_code)

def check_aliveness(username, password, timeout, cluster):
    """Run the management API aliveness test against the ``/`` vhost.

    Issues ``GET http://<cluster>:15672/api/aliveness-test/%2F`` with HTTP
    basic auth.  The endpoint declares a test queue, then publishes and
    consumes a message; it is intended for use by monitoring tools and
    answers with HTTP status 200 when everything is working correctly.

    Prints one Nagios-style line and terminates the process via ``sys.exit``:

    * 0 (OK) when the response status is 200,
    * 2 (CRITICAL) when the request fails or any other status is returned.
    """
    try:
        r = requests.get("http://%s:15672/api/aliveness-test/%%2F" % cluster,
                         auth=(username, password), timeout=timeout)
    except requests.exceptions.RequestException as e:
        # No usable response at all (connection refused, timeout, ...).
        print("CRITICAL - %s" % e)
        sys.exit(2)

    if r.status_code == 200:
        print("OK - RABBITMQ Aliveness Test Returns: %s" % r)
        sys.exit(0)

    # r.text is the decoded body; r.content is bytes and would be printed
    # as b'...' on Python 3.
    print("CRITICAL - RabbitMQ Error: %s" % r.text)
    sys.exit(2)

def check_cluster(username, password, timeout, cluster):
    """Check cluster health from the management API's node list.

    Issues ``GET http://<cluster>:15672/api/nodes`` with HTTP basic auth and
    treats every node whose ``running`` field is false or missing as failed.

    Prints one Nagios-style line and terminates the process via ``sys.exit``:

    * 0 (OK) when no node is failed,
    * 1 (WARNING) when exactly one node is failed,
    * 2 (CRITICAL) when two or more nodes are failed, or when the request
      fails, returns a non-200 status, or returns something other than a
      JSON list.
    """
    url = "http://%s:15672/api/nodes" % cluster
    try:
        r = requests.get(url, auth=(username, password), timeout=timeout)
    except requests.exceptions.RequestException as e:
        # No usable response at all (connection refused, timeout, ...).
        print("CRITICAL - %s" % e)
        sys.exit(2)

    # An error response (e.g. bad credentials) is not a node list; report it
    # instead of crashing with a traceback while iterating over it.
    if r.status_code != 200:
        print("CRITICAL - RabbitMQ Error: HTTP %s %s" % (r.status_code, r.text))
        sys.exit(2)

    try:
        nodes = json.loads(r.text)
    except ValueError as e:
        print("CRITICAL - invalid JSON from %s: %s" % (url, e))
        sys.exit(2)
    if not isinstance(nodes, list):
        print("CRITICAL - unexpected response from %s: %s" % (url, r.text))
        sys.exit(2)

    running_nodes = [node['name'] for node in nodes if node.get('running')]
    failed_nodes = [node['name'] for node in nodes if not node.get('running')]

    if len(failed_nodes) == 1:
        print("WARNING: RabbitMQ cluster is degraded: Not running %s" % failed_nodes[0])
        sys.exit(1)
    elif len(failed_nodes) >= 2:
        print("CRITICAL: RabbitMQ cluster is critical: Not running %s" % failed_nodes)
        sys.exit(2)
    else:
        print("OK: RabbitMQ cluster members: %s" % (" ".join(running_nodes)))
        sys.exit(0)


# Help text: passed to OptionParser as the usage string and printed when the
# requested action is not recognised.
USAGE = """Usage: ./check_rabbitmq -a [action] -C [critical] -W [warning]
           Actions:
           - connection_count
             checks the number of connections in rabbitmq's list_connections
           - queues_count
             checks the count in each of the queues in rabbitmq's list_queues
           - mem_usage
             checks that mem usage of the rabbitmq process stays below the -W/-C percentages of its limit
           - aliveness
             Use the /api/aliveness-test API to send/receive a message. (requires -u username -p password args)
           - cluster_status
             Parse /api/nodes to check the cluster status. (requires -u username -p password args)"""

if __name__ == "__main__":
    # Command-line entry point: parse the options, then run the selected check.
    # Each check prints its own status line and sets the process exit code.
    parser = OptionParser(USAGE)
    parser.add_option("-a", "--action", dest="action",
                      help="Action to Check")
    # NOTE(review): -C/-W default to 0 and are always passed through, so the
    # checks' own default thresholds are never used; with 0, the ">=" style
    # comparisons in the count checks are always true. Confirm this is intended.
    parser.add_option("-C", "--critical", dest="critical", default=0,
                      type="int", help="Critical Threshold")
    parser.add_option("-W", "--warning", dest="warning", default=0,
                      type="int", help="Warning Threshold")
    parser.add_option("-u", "--username", dest="username", default="guest",
                      type="string", help="RabbitMQ username, Default guest")
    parser.add_option("-p", "--password", dest="password", default="guest",
                      type="string", help="RabbitMQ password, Default guest")
    parser.add_option("-t", "--timeout", dest="timeout", default=1,
                      type="int", help="Request Timeout, defaults to 1 second")
    parser.add_option("-c", "--cluster", dest="cluster", default="localhost",
                      type="string", help="Cluster IP/DNS name, defaults to localhost")
    parser.add_option("-s", "--skip", dest="skip", default=0,
                      type="int", help="skip items with less than SKIP entries")
    (options, args) = parser.parse_args()

    # Map each action name to a zero-argument callable that runs the check.
    actions = {
        "connection_count": lambda: check_connection_count(
            options.critical, options.warning),
        "queues_count": lambda: check_queues_count(
            options.critical, options.warning, options.skip),
        "mem_usage": lambda: check_mem_usage(
            options.critical, options.warning),
        "aliveness": lambda: check_aliveness(
            options.username, options.password, options.timeout, options.cluster),
        "cluster_status": lambda: check_cluster(
            options.username, options.password, options.timeout, options.cluster),
    }

    run_check = actions.get(options.action)
    if run_check is None:
        # A missing or misspelled action must not be reported as OK (exit 0);
        # exit status 3 is the Nagios UNKNOWN state.
        print("UNKNOWN - Invalid action: %s" % options.action)
        print(USAGE)
        sys.exit(3)
    run_check()
