#!/usr/bin/python3
#
# Displays a summary of Taskomatic activities in progress
#
# Copyright (c) 2016 SUSE LLC
#
# This software is licensed to you under the GNU General Public License,
# version 2 (GPLv2). There is NO WARRANTY for this software, express or
# implied, including the implied warranties of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. You should have received a copy of GPLv2
# along with this software; if not, see
# http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt.
#

import datetime
import time
import struct
import signal
import sys
import curses
import logging
import argparse
import os.path

from spacewalk.server import rhnSQL

from io import BytesIO


DEFAULT_LOGFILE = './taskotop.log'

# Command line definition.  The two display modes (-e / -r) exclude each
# other; every other option can be combined freely.
parser = argparse.ArgumentParser(
    description="Taskotop is a tool to monitor what taskomatic "
                "is currently doing.")

mode_group = parser.add_mutually_exclusive_group()
mode_group.add_argument(
    "-e", "--each-task",
    dest="eachTask", action="store_true",
    help="Display most recent run for each task instead of recent task run history.")
mode_group.add_argument(
    "-r", "--recent-history",
    dest="recentHistory", action="store_true",
    help="Display recent history of task runs.  This is the default display mode.")

parser.add_argument(
    "-H", "--human-readable",
    dest="humanReadable", action="store_true",
    help="Use human readable time output.  Time will be displayed "
         "in the format [[days:]hours:]min:sec instead of total seconds.")
parser.add_argument(
    "-m", "--max-age",
    dest="maxAge", type=int, default=60,
    help="Retrieve past events up to this old (in seconds, default 60). "
         "This has no effect if -e/--each-task is specified.")
parser.add_argument(
    "-n",
    dest="numIterations", type=int, default=0,
    help="taskotop will iterate the specified number of times and then exit. "
         "If not specified or 0 (the default), taskotop will run until the user exits taskotop.")
parser.add_argument(
    "-t", "--taskomatic",
    dest="displayTaskomatic", action="store_true",
    help="Include taskomaticd process information in the output.")
parser.add_argument(
    "-v", "--verbose",
    dest="verbose", action="count", default=0,
    help="Increase log output verbosity.  Specify multiple times, up to 4 "
         "to increase verbosity.")
parser.add_argument(
    "--hide-elapsed",
    dest="hideElapsed", action="store_true",
    help="Hide the ELAPSED column in the display.")
parser.add_argument(
    "--show-start",
    dest="showStart", action="store_true",
    help="Include the START column in the display.")
parser.add_argument(
    "--logfile",
    dest="logfile", default=DEFAULT_LOGFILE,
    help=("Specify logfile to use if at least one verbose arg specified.  "
          "Default is %s") % DEFAULT_LOGFILE)

# Parsed once at import time; the rest of the script reads this namespace.
args = parser.parse_args()

# The two table layouts; DISPLAY_MODE holds the one currently in effect and
# starts out as the recent-history view.
DISPLAY_MODE_RECENT_HISTORY, DISPLAY_MODE_EACH_TASK = 0, 1
DISPLAY_MODE = DISPLAY_MODE_RECENT_HISTORY

# Settings taken straight from the command line.
DISPLAY_TASKOMATIC = args.displayTaskomatic
HUMAN_READABLE = args.humanReadable
MAXIMUM_AGE = args.maxAge

# Defaults that the option handling further down the script may override.
SHOW_ELAPSED_TIME = True
SHOW_START_TIME = False
LOGGING_ENABLED = False

# Lines of the interactive help page, shown one screen row per entry.
INTERACTIVE_HELP = [
    "Help for taskotop interactive commands",
    "",
    "  e    Change display mode to show each task's latest run.",
    "  h    Display this help page.",
    "  H    Toggle human readable format.  This toggles the time display between",
    "       [[days:]hours:]min:sec format and total seconds.",
    "  q    Quit taskotop.",
    "  r    Change display mode to show recent history of task runs.",
    "  t    Toggle taskomatic process information display.",
    "",
]


def log_debug(msg, *args, **kwargs):
    """Pass a DEBUG message on to the logging module unless logging is switched off."""
    if not LOGGING_ENABLED:
        return
    logging.debug(msg, *args, **kwargs)

def log_info(msg, *args, **kwargs):
    """Pass an INFO message on to the logging module unless logging is switched off."""
    if not LOGGING_ENABLED:
        return
    logging.info(msg, *args, **kwargs)

def log_warning(msg, *args, **kwargs):
    """Pass a WARNING message on to the logging module unless logging is switched off."""
    if not LOGGING_ENABLED:
        return
    logging.warning(msg, *args, **kwargs)

def log_error(msg, *args, **kwargs):
    """Pass an ERROR message on to the logging module unless logging is switched off."""
    if not LOGGING_ENABLED:
        return
    logging.error(msg, *args, **kwargs)

class CursesDisplayBuilder:
    """Builder class to make laying out a curses display in a table format easier.

    Columns are registered with add_column() and drawn, heading row first,
    by output_to_screen().  A column width of -1 means "use whatever width
    is left on the line".
    """
    # Justification codes.  The value doubles as the multiplier applied to
    # half of the free space in a cell: 0 -> no padding (left), 1 -> half of
    # it (center), 2 -> all of it (right).
    JUSTIFY_LEFT = 0
    JUSTIFY_CENTER = 1
    JUSTIFY_RIGHT = 2

    def __init__(self):
        # The column list belongs to the instance; a class-level list would
        # be shared (and appended to) by every builder that is created.
        self.row = []

    def add_column(self, width, heading, heading_justify, data_justify, format_data_callable, data_key = ""):
        """Append a column definition to the table.

        format_data_callable is invoked as
        format_data_callable(rowdata, data_key, width) for every data row
        and must return the text to display in the cell.
        """
        self.row.append(CursesDisplayColumn(width,
                                            heading,
                                            heading_justify,
                                            data_justify,
                                            format_data_callable,
                                            data_key))

    def string_of_length(self, length, char="*"):
        """Return char repeated length times (empty for a length <= 0)."""
        return char * length

    def add_column_value_to_screen(self, screen, value, ypos, xpos, column_width, justify, maxy, maxx):
        """Draw a single cell and return the column width it occupies.

        A value longer than column_width is replaced by a run of '*'.
        The value is then positioned inside the cell according to justify
        (one of the JUSTIFY_* codes).  Always returns column_width so the
        caller can advance its x position.
        """
        if len(value) > column_width:
            value = self.string_of_length(column_width)
        addxpos = int(xpos + justify * (column_width - len(value)) / 2)
        # need to skip writing to the last character on the last line
        # so the cursor has a place to exist, even though its not
        # visible
        if ypos == maxy - 1 and addxpos + len(value) == maxx:
            value = value[0:-1]
        # Lazy logging arguments: nothing is formatted when logging is off.
        log_debug('y=%d x=%d value \'%s\'  value length %d', ypos, addxpos, value, len(value))
        if len(value) > 0:
            screen.addstr(ypos, addxpos, value)
        return column_width

    def _write_row(self, screen, columns, ypos, maxy, maxx, rowdata=None, heading=False):
        """Draw one table line: the headings if heading is true, else rowdata."""
        current_x = 0
        for column in columns:
            width = column.width
            if width == -1:
                # flexible column: take the rest of the line
                width = maxx - current_x
            if heading:
                value = column.heading
                justify = column.heading_justify
            else:
                value = column.format_data_callable(rowdata, column.data_key, width)
                justify = column.data_justify
            # +1 leaves a one character gap between neighbouring columns
            current_x += self.add_column_value_to_screen(
                screen, value, ypos, current_x, width, justify, maxy, maxx) + 1

    def output_to_screen(self, screen, data, starty = 0):
        """Draw the heading row and as many data rows as fit below starty.

        screen -- curses window (needs getmaxyx() and addstr())
        data   -- iterable of row objects handed to each column's
                  format_data_callable
        starty -- screen line on which the heading row is placed

        Returns the number of data rows drawn, or -1 when not even the
        heading row fits on the screen.
        """
        maxy, maxx = screen.getmaxyx()
        log_debug('maxy, maxx is %d %d', maxy, maxx)

        # Only the leading columns that fit within the screen width are
        # shown; the first one that does not fit ends the selection.
        visible = []
        current_width = 0
        for column in self.row:
            if current_width + column.width > maxx:
                break
            visible.append(column)
            current_width += column.width + 1
            log_debug('column \'%s\': width %d, next column starts at %d',
                      column.heading, column.width, current_width)

        current_row = 0
        if starty < maxy:
            self._write_row(screen, visible, starty, maxy, maxx, heading=True)
            current_row += 1

        for rowdata in data:
            if current_row + starty >= maxy:
                break
            self._write_row(screen, visible, current_row + starty, maxy, maxx, rowdata=rowdata)
            current_row += 1

        return current_row - 1

class CursesDisplayColumn:
    """Plain record holding the definition of one CursesDisplayBuilder column"""
    def __init__(self, width, heading, heading_justify, data_justify, format_data_callable, data_key = ""):
        # geometry and caption
        self.width, self.heading = width, heading
        # justification codes for the caption and for the data cells
        self.heading_justify, self.data_justify = heading_justify, data_justify
        # how a cell's text is produced from a data row
        self.format_data_callable, self.data_key = format_data_callable, data_key


def get_tasko_runs_newer_than_age(maximum_age):
    """Fetch started Taskomatic runs that are unfinished or ended recently.

    A run is returned when it has a start time and either has no end time
    or ended less than maximum_age seconds before now.  Each row is a dict
    with the keys name, id, start_time, end_time and data; data holds the
    schedule BLOB as returned by rhnSQL.read_lob().
    """
    cursor = rhnSQL.prepare("""
        SELECT
            task.name AS name,
            run.id AS id,
            run.start_time AS start_time,
            run.end_time AS end_time,
            schedule.data AS data

            FROM rhnTaskoRun run
                JOIN rhnTaskoSchedule schedule ON schedule.id = run.schedule_id
                JOIN rhnTaskoTemplate template ON template.id = run.template_id
                JOIN rhnTaskoTask task ON task.id = template.task_id

            WHERE
                run.start_time IS NOT NULL
                    AND (run.end_time IS NULL OR run.end_time > :timelimit)

            ORDER BY end_time DESC NULLS FIRST, start_time ASC
    """)
    # runs that ended before this instant are left out
    cutoff = datetime.datetime.now() - datetime.timedelta(seconds=maximum_age)
    cursor.execute(timelimit=cutoff)

    # HACK: rows are pulled one at a time instead of with fetchall_dict()
    # so that every BLOB is read exactly once (otherwise we get exceptions)
    runs = []
    while True:
        record = cursor.fetchone_dict()
        if record is None:
            break
        record["data"] = rhnSQL.read_lob(record["data"])
        runs.append(record)
    return runs


def get_tasko_runs_latest_each_task():
    """Fetch the latest run of every Taskomatic template.

    Per template the subquery yields NULL when any of its runs has no end
    time and the greatest end time otherwise; the runs matching that value
    are returned.  Each row is a dict with the keys name, id, start_time,
    end_time and data; data holds the schedule BLOB as returned by
    rhnSQL.read_lob().
    """
    cursor = rhnSQL.prepare("""
        SELECT
            task.name AS name,
            run.id AS id,
            run.start_time AS start_time,
            run.end_time AS end_time,
            schedule.data AS data

            FROM (SELECT template_id,
                         CASE WHEN MAX(CASE WHEN end_time IS NULL THEN 1 ELSE 0 END) = 0
                              THEN MAX(end_time)
                         END AS end_time
                      FROM rhnTaskoRun GROUP BY template_id) m
                JOIN rhnTaskoRun run ON run.template_id = m.template_id
                    AND (run.end_time = m.end_time OR (run.end_time IS NULL AND m.end_time IS NULL))
                JOIN rhnTaskoSchedule schedule ON schedule.id = run.schedule_id
                JOIN rhnTaskoTemplate template ON template.id = run.template_id
                JOIN rhnTaskoTask task ON task.id = template.task_id

            ORDER BY end_time DESC NULLS FIRST, start_time ASC
    """)
    cursor.execute()

    # HACK: rows are pulled one at a time instead of with fetchall_dict()
    # so that every BLOB is read exactly once (otherwise we get exceptions)
    runs = []
    while True:
        record = cursor.fetchone_dict()
        if record is None:
            break
        record["data"] = rhnSQL.read_lob(record["data"])
        runs.append(record)
    return runs


def get_channel_names(ids):
    """Gets the channel names corresponding to channel ids from the database.

    ids -- sequence of channel ids (ints or strings of digits)

    Returns the distinct matching channel labels sorted by label, or an
    empty list when no ids are given.  Raises ValueError/TypeError if an
    id cannot be converted to an integer.
    """
    if not ids:
        return []

    # The ids are handed over as named bind variables instead of being
    # pasted into the statement text; int() additionally guarantees that
    # nothing but a number ever reaches the database.
    binds = {}
    for position, channel_id in enumerate(ids):
        binds["id{0}".format(position)] = int(channel_id)
    placeholders = ", ".join(":" + name for name in binds)

    query = rhnSQL.prepare("""
        SELECT DISTINCT label
            FROM rhnChannel
            WHERE id IN ({0})
            ORDER BY label
    """.format(placeholders))
    query.execute(**binds)

    return [record[0] for record in query.fetchall() or ()]

def get_current_repodata_channel_names():
    """Gets the channel names of currently running repodata tasks from the database.

    These are the distinct channel labels of rhnRepoRegenQueue entries
    whose next_action is NULL, sorted by label.
    """
    cursor = rhnSQL.prepare("""
        SELECT DISTINCT channel_label
            FROM rhnRepoRegenQueue
            WHERE next_action IS NULL
            ORDER BY channel_label
    """)
    cursor.execute()

    labels = []
    for record in cursor.fetchall():
        labels.append(record[0])
    return labels

def extract_channel_ids(bytes):
    """Extracts channel ids from a Java Map in serialized form.

    Returns the ids as a list of decimal strings, in the order in which
    they are found in the data.
    """
    # HACK: this heuristically looks for strings, which are marked with 't',
    # two bytes for the length (big-endian) and the string chars themselves.
    # If they represent numbers, we assume they are channel_ids
    # (currently this is the case)
    raw = BytesIO(bytes).getvalue()
    total = len(raw)
    candidates = []

    marker = raw.find(b"t")
    while marker != -1:
        body_start = marker + 3
        # a marker too close to the end has no complete length field
        if body_start <= total:
            (size,) = struct.unpack_from(">H", raw, marker + 1)
            body = raw[body_start:body_start + size]
            # a body cut short by the end of the data is not a real string
            if len(body) == size:
                candidates.append(body)
        # resume right after the marker so strings hidden inside a
        # candidate are examined as well
        marker = raw.find(b"t", marker + 1)

    # of those found, keep the ones looking like a number
    return [body.decode() for body in candidates if body.isdigit()]

# Field positions in a whitespace-split line of the ps output requested by
# taskomaticd_ps() (pid,ppid,pcpu,pmem,rss,time,comm,lstart).
#
# comm can have whitespace, but taskomaticd does not and is the only
# process for which we parse start time.  As such, comm must be the last
# column before lstart.
(IX_PID, IX_PPID, IX_PCPU, IX_PMEM, IX_RSS, IX_TIME, IX_COMM) = range(7)

# lstart displays in format [weekday month day time year]
# ex Mon Mar 13 06:56:22 2017
# The weekday (position 7) is not used.
(IX_START_MON, IX_START_DAY, IX_START_TIME, IX_START_YEAR) = range(8, 12)

def taskomaticd_ps():
    """Collect process statistics for taskomaticd and its descendants via ps.

    Returns the tuple
    (pids, cputimeseconds, cpupercent, memsize, mempercent, upsince):

    pids           -- pids (strings) of taskomaticd and of every process
                      whose parent is already in the list; the taskomaticd
                      pid comes first
    cputimeseconds -- summed cpu time in seconds
    cpupercent     -- summed pcpu values
    memsize        -- summed rss values
    mempercent     -- summed pmem values
    upsince        -- "day month year time" start of the first matching
                      process, "" when nothing matched
    """
    import subprocess
    pids = []
    upsince = ""
    cputimeseconds = 0
    cpupercent = 0.0
    mempercent = 0.0
    memsize = 0
    # -H lists processes as a hierarchy, so children follow their parent
    # and can be recognised by a ppid that was already collected.
    raw = subprocess.Popen(['ps', '--no-headers', '-eHo', 'pid,ppid,pcpu,pmem,rss,time,comm,lstart'],
                           stdout=subprocess.PIPE).communicate()[0]
    # communicate() returns bytes here; without decoding, the comparison
    # with the str "taskomaticd" below could never succeed.
    out = raw.decode("utf-8", "replace").splitlines()
    for line in out:
        values = line.split()
        if len(values) <= IX_COMM:
            # blank or truncated line, nothing to look at
            continue
        if (values[IX_COMM] == "taskomaticd" or values[IX_PPID] in pids):
            pids.append(values[IX_PID])
            if not upsince:
                upsince = '%s %s %s %s' % (values[IX_START_DAY], values[IX_START_MON], values[IX_START_YEAR], values[IX_START_TIME])
            cputimeseconds += seconds_from_time(values[IX_TIME])
            cpupercent += float(values[IX_PCPU])
            mempercent += float(values[IX_PMEM])
            memsize += int(values[IX_RSS])
    return pids, cputimeseconds, cpupercent, memsize, mempercent, upsince

def seconds_from_time(time):
    """Turn a [DD-]hh:mm:ss time string into its total number of seconds."""
    fields = time.split(':')
    minutes = int(fields[1])
    secs = int(fields[2])
    leading = fields[0]
    if '-' not in leading:
        day_count, hour_count = 0, int(leading)
    else:
        # "DD-hh": day count and hours share the first field
        pieces = leading.split('-')
        day_count, hour_count = int(pieces[0]), int(pieces[1])
    total_hours = day_count * 24 + hour_count
    return (total_hours * 60 + minutes) * 60 + secs

def add_line(screen, line, ypos):
    """Write line at the start of screen row ypos, cut to the screen width.

    Returns 1 when the line was written and 0 when ypos lies below the
    last screen row (nothing is written then).
    """
    rows, cols = screen.getmaxyx()
    if ypos >= rows:
        return 0
    # need to skip writing to the last character on the last line
    # so the cursor has a place to exist, even though its not
    # visible; on that row the usable width is one column less.
    usable = cols - 1 if ypos == rows - 1 else cols
    text = line[0:usable] if len(line) > usable else line
    screen.addstr(ypos, 0, text)
    return 1

def show_taskomatic_header(screen):
    """Write the taskomaticd process summary at the top of the screen.

    Three lines are produced (status, cpu, memory); if anything fails
    along the way a single failure notice is added instead of the
    remaining lines.  Returns the number of screen lines written.
    """
    ypos = 0
    try:
        pids, cpu_seconds, cpu_pct, mem_total, mem_pct, started = taskomaticd_ps()

        if (len(pids) > 0):
            status = 'taskomaticd pid: %s  up since: %s  child processes: %d' % (pids[0], started, len(pids) - 1)
        else:
            status = 'taskomaticd is not running'
        ypos += add_line(screen, status, ypos)

        if not HUMAN_READABLE:
            cpu_text = 'cpu %%: %2.1f  cpu total seconds: %d' % (cpu_pct, cpu_seconds)
        else:
            cpu_text = 'cpu %%: %2.1f  cpu total time: %s' % (cpu_pct, seconds_to_hms_string(cpu_seconds))
        ypos += add_line(screen, cpu_text, ypos)

        mem_text = 'mem %%: %2.1f  mem total (kB): %d' % (mem_pct, mem_total)
        ypos += add_line(screen, mem_text, ypos)
    except Exception:
        failure = 'failed to issue ps command to retrieve taskomaticd information'
        log_error('failed to issue ps command to retrieve taskomaticd information')
        ypos += add_line(screen, failure, ypos)
    return ypos

def display_interactive_help(screen):
    """Show the interactive help page until a key is pressed.

    The page is redrawn on every pass, so it follows changes of the
    screen size; the loop ends once getch() delivers a code from 0 to 255.
    """
    log_debug("Displaying interactive help")
    dismissed = False
    while not dismissed:
        screen.erase()
        height = screen.getmaxyx()[0]
        ypos = 0
        for help_line in INTERACTIVE_HELP:
            # keep one row free for the prompt below
            if ypos > height - 2:
                break
            ypos += add_line(screen, help_line, ypos)
        ypos += add_line(screen, "Press any key ", ypos)
        screen.refresh()
        try:
            code = screen.getch()
            dismissed = -1 < code < 256
        except Exception:
            # ignore and show the page again
            pass

def process_interactive_input(c, screen):
    """React to one key code read from the user.

    q quits, e / r select the display mode, H and t toggle the human
    readable format and the taskomatic header, h shows the help page.
    Any other key is ignored.
    """
    global DISPLAY_MODE
    global HUMAN_READABLE
    global DISPLAY_TASKOMATIC

    if c == ord('q'):
        system_exit(0)
        return

    if c == ord('e'):
        DISPLAY_MODE = DISPLAY_MODE_EACH_TASK
        log_debug('Display Mode is now Each Task')
        return

    if c == ord('r'):
        DISPLAY_MODE = DISPLAY_MODE_RECENT_HISTORY
        log_debug('Display Mode is now Recent History')
        return

    if c == ord('H'):
        HUMAN_READABLE = not HUMAN_READABLE
        log_debug('HUMAN_READABLE is now %s' % HUMAN_READABLE)
        return

    if c == ord('t'):
        DISPLAY_TASKOMATIC = not DISPLAY_TASKOMATIC
        log_debug('DISPLAY_TASKOMATIC is now %s' % DISPLAY_TASKOMATIC)
        return

    if c == ord('h'):
        # turn off half delay with nocbreak(), then
        # turn cbreak back on so no need to hit 'enter'
        curses.nocbreak()
        curses.cbreak()
        # cursor visible while the help page is up, hidden again afterwards
        curses.curs_set(1)
        display_interactive_help(screen)
        curses.curs_set(0)
        # back to getch() giving up after ten tenths of a second
        curses.halfdelay(10)

def seconds_to_hms_string(seconds):
    """Format a number of seconds as [[days:]hours:]min:sec.

    Minutes and seconds are always present.  Hours are added when they are
    non-zero or when a day count is shown, so a value with days always has
    all four fields and cannot be mistaken for hours:min:sec.  A negative
    input is rendered as the formatted magnitude with a leading "-".
    """
    sign = "-" if seconds < 0 else ""
    # work on the magnitude: divmod() floors, which would turn a negative
    # value into misleading positive fields
    minutes, s = divmod(abs(seconds), 60)
    hours, m = divmod(minutes, 60)
    days, h = divmod(hours, 24)

    if days > 0:
        prefix = '%d:%02d:' % (days, h)
    elif h > 0:
        prefix = '%02d:' % h
    else:
        prefix = ''
    return '%s%s%02d:%02d' % (sign, prefix, m, s)

def format_elapsed_time(rowdata, data_key, width):
    """Format how long a run has taken so far (or took in total).

    The span runs from rowdata["start_time"] to rowdata["end_time"], or to
    the current time while there is no end time.  The result is a
    [[days:]hours:]min:sec string in human readable mode and "<n>s"
    otherwise.  data_key and width are not used.
    """
    finish = rowdata["end_time"] or datetime.datetime.now()
    # time zone information is dropped from both ends before subtracting
    span = finish.replace(tzinfo=None) - rowdata["start_time"].replace(tzinfo=None)
    whole_seconds = int(span.total_seconds())
    if not HUMAN_READABLE:
        return "{0}s".format(whole_seconds)
    return seconds_to_hms_string(whole_seconds)

def format_datetime(rowdata, data_key, width):
    """Format rowdata[data_key] as YYYY-MM-DD HH:MM:SS, or "" when it is empty."""
    stamp = rowdata[data_key]
    return stamp.strftime('%Y-%m-%d %H:%M:%S') if stamp else ""

def format_start_time(rowdata, data_key, width):
    """Format the run's start time, or return "" when it has none."""
    if not rowdata["start_time"]:
        return ""
    return format_datetime(rowdata, "start_time", width)

def format_end_time(rowdata, data_key, width):
    """Format the run's end time, or a "(running ...)" note for a started run.

    The elapsed time is put into the note only when the ELAPSED column is
    not shown.  A row with neither end nor start time gives "".
    """
    if rowdata["end_time"]:
        return format_datetime(rowdata, "end_time", width)
    if not rowdata["start_time"]:
        return ""
    if SHOW_ELAPSED_TIME:
        return "(running)"
    return "(running {0})".format(format_elapsed_time(rowdata, "", width))

def format_channel_names(rowdata, data_key, width):
    """Build the comma separated channel list for a run, cut to width.

    The names come from the channel ids found in rowdata["data"]; for an
    unfinished "channel-repodata" run they are replaced by the channels
    currently being regenerated.
    """
    names = []
    if rowdata["data"]:
        found_ids = extract_channel_ids(rowdata["data"])
        names = get_channel_names(found_ids)
    if rowdata["name"] == "channel-repodata" and not rowdata["end_time"]:
        names = get_current_repodata_channel_names()
    joined = ", ".join(names)
    return joined[0:width] if len(joined) > width else joined

def format_int(rowdata, data_key, width):
    """Format an integer field of rowdata for display.

    rowdata  -- mapping holding the row's values
    data_key -- key of the value to format; when it is empty or not
                present in rowdata, the "id" entry is used
    width    -- column width (not used)
    """
    key = data_key if data_key in rowdata else "id"
    return "{0:d}".format(rowdata[key])

def format_string(rowdata, data_key, width):
    """Return the rowdata entry stored under data_key, unchanged."""
    text = rowdata[data_key]
    return text


def main(screen):
    """Redraw the task run table about once per second until told to stop.

    Runs for args.numIterations passes, or without limit when that is 0.
    After every redraw one key press is awaited and handed to
    process_interactive_input().
    """
    # getch() gives up after ten tenths of a second
    curses.halfdelay(10)

    rhnSQL.initDB()

    # exit gracefully on ctrl-c
    signal.signal(signal.SIGINT, lambda signum, frame: sys.exit(0))

    # hide cursor
    curses.curs_set(0)

    # describe the table, left to right
    builder = CursesDisplayBuilder()
    left = builder.JUSTIFY_LEFT
    center = builder.JUSTIFY_CENTER
    right = builder.JUSTIFY_RIGHT
    builder.add_column(11, "RUN ID", right, left, format_int, "id")
    builder.add_column(30, "TASK NAME", right, right, format_string, "name")
    if SHOW_START_TIME:
        builder.add_column(19, "START", center, right, format_start_time)
    if SHOW_ELAPSED_TIME:
        builder.add_column(9, "ELAPSED", right, right, format_elapsed_time)
    builder.add_column(19, "END", center, right, format_end_time)
    # width -1: the channel list gets the rest of the line
    builder.add_column(-1, "CHANNEL", center, right, format_channel_names)

    # -1 stands for "no limit"; a positive count is decremented per pass
    remaining = args.numIterations if args.numIterations != 0 else -1
    while remaining != 0:
        if remaining > 0:
            remaining -= 1
            log_debug('updating screen (%d remaining)' % remaining)
        else:
            log_debug("updating screen")

        screen.erase()
        table_top = 0
        if DISPLAY_TASKOMATIC:
            # one blank line between the header and the table
            table_top += show_taskomatic_header(screen) + 1

        if DISPLAY_MODE == DISPLAY_MODE_EACH_TASK:
            runs = get_tasko_runs_latest_each_task()
        else:
            runs = get_tasko_runs_newer_than_age(MAXIMUM_AGE)
        builder.output_to_screen(screen, runs, table_top)
        screen.refresh()

        try:
            # given up to a 1 second to retrieve user input
            # because of halfdelay(10) call earlier
            key = screen.getch()
            if key != curses.ERR:
                process_interactive_input(key, screen)
        except Exception:
            log_warning('getch() exception %s' % sys.exc_info()[1])


def system_exit(code, msgs=None):
    """Write the optional message(s) to stderr, then exit with code.

    msgs may be a single value or a list/tuple of values; each one is
    written on a line of its own.
    """
    if msgs:
        # anything that is not exactly a list or a tuple counts as one message
        lines = msgs if type(msgs) in (list, tuple) else (msgs,)
        for item in lines:
            sys.stderr.write(str(item) + '\n')
    sys.exit(code)

# Refuse unusable option values before anything else happens; the first
# violated rule ends the program with exit code 2.
for _violated, _complaint in (
        (args.numIterations < 0,
         "ERROR: NUMITERATIONS must not be a negative value"),
        (args.maxAge < 0,
         "ERROR: MAXAGE must not be a negative value"),
        (args.logfile != DEFAULT_LOGFILE and args.verbose == 0,
         "ERROR: --logfile command line option requires use of -v or --verbose option")):
    if _violated:
        system_exit(2, _complaint)

if args.verbose > 0:
    LOGGING_ENABLED = True
    # -v: errors, -vv: warnings, -vvv: info, four or more: debug
    loglevel = {
        1: logging.ERROR,
        2: logging.WARNING,
        3: logging.INFO,
    }.get(args.verbose, logging.DEBUG)

    dirname = os.path.dirname(args.logfile)
    if dirname != '' and not os.path.isdir(dirname):
        system_exit(1, "ERROR: Directory %s in specified logfile doesn't exist" % dirname)

    # Write a start banner directly to the file; this also proves that the
    # logfile is writable before logging is pointed at it.
    try:
        with open(args.logfile, 'a') as lfile:
            lfile.write("Logging started with verbosity %d on %s\n" % (args.verbose,  time.strftime("%c")))
            lfile.write("taskotop command args: %s\n" % ' '.join(sys.argv[1:]))
    except Exception:
        system_exit(1, "ERROR: Failed to open and write to logfile %s" % args.logfile)

    logging.basicConfig(filename=args.logfile, level=loglevel, format='%(asctime)s %(levelname)s:%(message)s')
    log_debug('HUMAN_READABLE is %s' % HUMAN_READABLE)
    log_debug('MAXIMUM_AGE is %d' % MAXIMUM_AGE)

# Display mode chosen on the command line (the options exclude each other).
for _chosen, _mode, _note in (
        (args.eachTask, DISPLAY_MODE_EACH_TASK, 'Display Mode is Each Task'),
        (args.recentHistory, DISPLAY_MODE_RECENT_HISTORY, 'Display Mode is Recent History')):
    if _chosen:
        DISPLAY_MODE = _mode
        log_debug(_note)

# Optional columns.
if args.showStart:
    log_debug('START column will be displayed')
    SHOW_START_TIME = True
if args.hideElapsed:
    log_debug('ELAPSED column will be hidden')
    SHOW_ELAPSED_TIME = False

# Run the display loop; curses.wrapper hands main() the screen to draw on.
try:
    curses.wrapper(main)
except rhnSQL.SQLConnectError as e:
    system_exit(20, ["ERROR: Can't connect to the database: %s" % str(e), "Check if your database is running."])
except Exception as e:
    system_exit(1, "ERROR: Some problems occurred during getting information about tasks: %s" % str(e))


