#!/usr/bin/env python
# -----------------------------------------------------------------------
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# -----------------------------------------------------------------------

# Informational summary of what this version of the script supports.
features = [
    'stop_ducc from head node only',
    'support --head which stops non-agent daemons only on local head',
    'support --agents which stops on any stated node',
    '--help explains that stop_ducc disables autostart',
    'employ broadcast for --stop and --quiesce of ducc daemons',
    'employ kill -15 for --stop of broker and database',
    'show ssh before and after for kill -15',
]

import sys

# Refuse to run on an interpreter older than the minimum (major, minor).
version_min = [2, 7]
version_info = sys.version_info
# lexicographic comparison of (major, minor) against the minimum
version_error = tuple(version_info[:2]) < tuple(version_min)
if version_error:
    print('Python minimum requirement is version '+str(version_min[0])+'.'+str(version_min[1]))
    sys.exit(1)

import argparse
import datetime
import textwrap
import time
import traceback

from threading import BoundedSemaphore
from threading import Lock
from threading import Thread

from argparse import RawDescriptionHelpFormatter

from ducc_util  import DuccUtil

# lock for messages
lock_print = Lock()

# print message
def output(msg):
    """Print msg to stdout while holding the module-wide print lock.

    The lock keeps lines written by concurrent worker threads from
    interleaving.  The parenthesized single-argument form of print is
    used so the statement behaves the same on Python 2.7 and parses on
    Python 3, matching the print call in the version check above.
    """
    with lock_print:
        print(msg)

# produce a time stamp
def get_timestamp():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    now = time.localtime(time.time())
    return time.strftime('%Y-%m-%d %H:%M:%S', now)

_flag_debug = False

# record debug message
def debug(mn,text):
    """Emit a timestamped debug line when debugging is enabled.

    mn is the name of the reporting method and text is the message.
    Nothing is printed while the module-level _flag_debug is False.
    """
    if(not _flag_debug):
        return
    level = 'D'
    output(' '.join([get_timestamp(), level, mn, text]))

class StopDucc(DuccUtil):

    def _fn(self):
        fpath = __file__.split('/')
        flen = len(fpath)
        return fpath[flen-1]

    # return class name
    def _cn(self):
        return self.__class__.__name__
    
    # return method name
    def _mn(self):
        return traceback.extract_stack(None,2)[0][2]
    
    # component names accepted on the command line: long forms
    # together with their two-letter abbreviations
    c_agent = 'agent'
    c_ag = 'ag'
    c_broker = 'broker'
    c_br = 'br'
    c_database = 'database'
    c_db = 'db'
    c_orchestrator = 'orchestrator'
    c_or = 'or'
    c_pm = 'pm'
    c_rm = 'rm'
    c_sm = 'sm'
    c_ws = 'ws'
    
    # short component names; these are what the methods below compare
    # against the component field of each [ host, component, state ] entry
    n_ag = 'ag'
    n_br = 'br'
    n_db = 'db'
    n_or = 'or'
    n_pm = 'pm'
    n_rm = 'rm'
    n_sm = 'sm'
    n_ws = 'ws'
    
    # component names listed in --component help; kill_threads walks this
    # list in order, so agents are handled first and broker/database last
    components = [ c_agent, c_pm, c_rm, c_sm, c_or, c_ws, c_broker, c_database, ]
    
    # maps every accepted component name (long or short) to its short name
    shortname = { c_agent:n_ag, 
                  c_broker:n_br, 
                  c_database:n_db, 
                  c_or:n_or, 
                  c_pm:n_pm,
                  c_rm:n_rm,
                  c_sm:n_sm,
                  c_ws:n_ws,
                  c_ag:n_ag,
                  c_br:n_br,
                  c_db:n_db,
                  c_orchestrator:n_or,
                }
    
    # maps a short name back to the long name; stop() uses it to build
    # the component@node targets and to recognize broker/database
    longname = { n_ag:c_agent, 
                 n_br:c_broker, 
                 n_db:c_database, 
                 n_or:c_orchestrator, 
                 n_pm:c_pm,
                 n_rm:c_rm,
                 n_sm:c_sm,
                 n_ws:c_ws,
               }
    
    # command-line option strings (registered with argparse in get_args)
    option_agents = '--agents'
    option_all = '--all'
    option_component = '--component'
    option_debug = '--debug'
    option_head = '--head'
    option_kill = '--kill'
    option_maxthreads = '--maxthreads'
    option_nodelist = '--nodelist'
    option_stop = '--stop'
    option_quiesce = '--quiesce-then-stop'
    
    # command names; within this class they appear only in help/epilog text
    cmd_kill_9 = 'kill -9'
    cmd_kill_15 = 'kill -15'
    cmd_ssh = 'ssh'
    cmd_start_ducc = 'start_ducc'
    
    # product keyword used in help text
    kw_DUCC = 'DUCC'
    
    # default thread limit applied for --kill when --maxthreads is absent
    maxthreads = 10
    # default number of seconds for --stop when no value is given
    default_stop = 60
    
    # signal numbers handed to kill_threads (15 from stop(), 9 from main())
    sig15 = 15
    sig9 = 9
    
    def _exit(self):
        sys.exit(1)

    def _help(self):
        self.parser.print_help()
        self._exit
    
    def get_epilog(self):
        epilog = ''
        epilog = epilog+'Notes:'
        epilog = epilog+'\n'
        epilog = epilog+'N1. '+self._fn()+' '+'is limited to running on a head node.'
        epilog = epilog+'\n'
        epilog = epilog+'N2. '+self._fn()+' '+'updates database autostart table with "stop" status.'
        epilog = epilog+'\n'
        epilog = epilog+'N3. '+self._fn()+' '+self.option_kill+' option employs '+self.cmd_ssh+' with '+self.cmd_kill_9+'.'
        epilog = epilog+'\n'
        epilog = epilog+'N4. '+self._fn()+' '+self.option_stop+' and '+self.option_quiesce+' options employ broadcast via broker.'\
                       +'\n'\
                       +'    '+'The broker and database are exceptions, whereby '+self.cmd_ssh+' with '+self.cmd_kill_15+' is employed.'
        epilog = epilog+'\n\n'
        epilog = epilog+'Examples:'
        epilog = epilog+'\n\n'
        epilog = epilog+'E1. kill all daemons that were started, as recorded in the database autostart table'
        epilog = epilog+'\n'
        epilog = epilog+'> '+self._fn()+' '+self.option_all+' '+self.option_kill
        epilog = epilog+'\n\n'
        epilog = epilog+'E2. stop all head node daemons on the present node'
        epilog = epilog+'\n'
        epilog = epilog+'> '+self._fn()+' '+self.option_head+' '+self.option_stop
        epilog = epilog+'\n\n'
        epilog = epilog+'E3. stop all agents via broadcast, each will issue '+self.cmd_kill_15+' to children'\
                       +'\n'\
                       +'    '+'then exit after a maximum of '+str(self.default_stop)+' seconds, by default'
        epilog = epilog+'\n'
        epilog = epilog+'> '+self._fn()+' '+self.option_agents+' '+self.option_stop
        epilog = epilog+'\n\n'
        epilog = epilog+'E4. quiesce all agents, each will issue '+self.cmd_kill_15+' to children then exit only'\
                       +'\n'\
                       +'    '+'after all children have exited'
        epilog = epilog+'\n'
        epilog = epilog+'> '+self._fn()+' '+self.option_agents+' '+self.option_quiesce
        epilog = epilog+'\n\n'
        epilog = epilog+'E5. quiesce agents listed in groupA.nodes and groupB.nodes, each will issue '+self.cmd_kill_15\
                       +'\n'\
                       +'    '+'to children then exit only after all children have exited'
        epilog = epilog+'\n'
        epilog = epilog+'> '+self._fn()+' '+self.option_nodelist+' groupA.nodes '+self.option_nodelist+' groupB.nodes '+self.option_quiesce
        epilog = epilog+'\n\n'
        epilog = epilog+'E6. stop agents on nodes nodeC8 and nodeD5, each will issue '+self.cmd_kill_15\
                       +'\n'\
                       +'    '+'to children then exit after a maximum of '+str(90)+' seconds.'
        epilog = epilog+'\n'
        epilog = epilog+'> '+self._fn()+' '+self.option_component+' '+self.c_agent+'@nodeC8 '+self.option_component+' '+self.c_agent+'@nodeD5 '+self.option_stop+' '+str(90)
        epilog = epilog+'\n\n'
        epilog = epilog+'E7. stop orchestrator'
        epilog = epilog+'\n'
        epilog = epilog+'> '+self._fn()+' '+self.option_component+' '+self.c_or+' '+self.option_stop
        epilog = epilog+'\n\n'
        epilog = epilog+'E8. kill orchestrator on alternate head node nodeX3'
        epilog = epilog+'\n'
        epilog = epilog+'> '+self._fn()+' '+self.option_component+' '+self.c_or+'@nodeX3'+' '+self.option_kill
        return epilog
    
    # help text for each command-line option; get_args hands these to
    # argparse as the help= value of the corresponding argument
    help_all        = 'Stop all DUCC management and agent processes by using database entries recorded by start_ducc.'
    help_head       = 'Stop the DUCC  management processes on the present head node by using database entries recorded by start_ducc.'
    help_agents     = 'Stop the DUCC agents processes on all nodes by using database entries recorded by '+cmd_start_ducc+'.'
    help_nodelist   = 'Stop agents on the nodes in the nodefile.  Multiple nodefiles may be specified.'
    help_component  = 'Stop a specific component.  The component may be qualified with the node name using the @ symbol: component@node.'\
                    + '  If node is not specified, then the localhost is presumed.  Multiple components may be specified. components = '+str(components)+'.'\
                    + '  Specification of a head node component other than on the present head node is disallowed unless '+option_kill+' option is also specified.'\
                    + '  Specification of broker or database is disallowed unless that component is automanaged by '+kw_DUCC+'.'   
    help_kill       = 'Stop the component(s) forcibly and immediately using '+cmd_ssh+' with '+cmd_kill_9+'.  Use this only if a normal stop does not work (e.g. the process may be hung).'
    help_stop       = 'Stop the component(s) gracefully using broadcast.  Agents allow children specified time (in seconds) to exit.  Default is '+str(default_stop)+'.'\
                    + '  Broadcast is not used for broker, database, and remote head node daemons; instead a direct kill -15 is employed.'
    help_quiesce    = 'Stop the component(s) gracefully using broadcast.  Agents exit only when no children exist.  Children are given infinite time to exit.'
    help_maxthreads = 'Maximum concurrent threads.  Default = '+str(maxthreads)+'.'
    help_debug      = 'Display debugging messages.'
    
    # get user specified command line args
    def get_args(self):
        """Parse the command line into self.args and apply defaults.

        Exactly one target selector (--all, --head, --agents, --nodelist,
        --component) is required.  At most one action (--kill, --stop,
        --quiesce-then-stop) may be given; when none is given, --stop with
        the default timeout is assumed.  --maxthreads defaults to the class
        maximum for --kill and to 2 otherwise.

        argparse exits the script itself on invalid usage.
        """
        global _flag_debug
        self.parser = argparse.ArgumentParser(formatter_class=RawDescriptionHelpFormatter,epilog=self.get_epilog())
        # target selection: exactly one required
        group1 = self.parser.add_mutually_exclusive_group(required=True)
        group1.add_argument(self.option_all, '-a', action='store_true', help=self.help_all)
        group1.add_argument(self.option_head, action='store_true', help=self.help_head)
        group1.add_argument(self.option_agents, action='store_true', help=self.help_agents)
        group1.add_argument(self.option_nodelist, '-n', action='append', help=self.help_nodelist)
        group1.add_argument(self.option_component, '-c', action='append',  help=self.help_component)
        # action: at most one
        group2 = self.parser.add_mutually_exclusive_group()
        group2.add_argument(self.option_kill, '-k', action='store_true',  help=self.help_kill)
        group2.add_argument(self.option_stop, '-s', action='store', type=int, nargs='?', const=self.default_stop, help=self.help_stop)
        group2.add_argument(self.option_quiesce, '-q', action='store_true', help=self.help_quiesce)
        self.parser.add_argument(self.option_maxthreads, '-m', action='store', type=int, default=None, help=self.help_maxthreads)
        self.parser.add_argument(self.option_debug, '-d', action='store_true', help=self.help_debug)
        self.args = self.parser.parse_args()
        # mutual choice: with no action given, default to --stop
        if(not self.args.kill and not self.args.quiesce_then_stop):
            if(self.args.stop is None):
                self.args.stop = self.default_stop
        # thread limit: always resolved to a usable value.  Previously an
        # explicit "--stop 0" was treated as "no stop", which left
        # maxthreads as None and crashed later in threads_prep.
        if(self.args.maxthreads is None):
            if(self.args.kill):
                self.args.maxthreads = self.maxthreads
            else:
                self.args.maxthreads = 2
        elif(self.args.maxthreads < 1):
            # zero would block forever on the thread pool semaphore and a
            # negative value is rejected by BoundedSemaphore
            self.parser.error(self.option_maxthreads+' must be a positive integer')
        # debug
        if(self.args.debug):
            _flag_debug = True
        text = str(self.args)
        debug(self._mn(),text)
    
    # cached result of db_acct_query(); None until get_db_list first runs
    db_list = None
    
    # fetch and cache list of tuples comprising 
    # daemon@node from database autostart table
    def get_db_list(self):
        """Return the daemon list from the database, querying only once.

        The first call stores the result of db_acct_query() in
        self.db_list; later calls return the stored value.
        """
        if(self.db_list is None):
            self.db_list = self.db_acct_query()
        cached = self.db_list
        debug(self._mn(),'list='+str(cached))
        return cached
    
    # --all
    def all(self):
        """Handle --all: return every entry recorded in the database.

        Entries are of the form [ host, component, state ].
        """
        debug(self._mn(),str(self.args.all))
        return self.get_db_list()
    
    # --head
    def head(self):
        """Handle --head: select the non-agent daemons on this node.

        Database and broker entries are kept only when the matching
        automanage flag is set.  Entries for other nodes are dropped.

        Returns:
            The selected [ host, component, state ] entries.
        """
        debug(self._mn(),str(self.args.head))
        db_list = self.get_db_list()
        this_node = self.get_node_name()
        selected = []
        for item in db_list:
            node, component = item[0], item[1]
            skip = (
                component == self.n_ag
                or (component == self.n_db and not self.automanage_database)
                or (component == self.n_br and not self.automanage_broker)
                or node != this_node
            )
            if(not skip):
                selected.append(item)
        return selected
    
    # --agents
    def agents(self):
        """Handle --agents: select the agent entries from the database.

        Returns:
            The [ host, component, state ] entries whose component is the
            agent short name.  Each entry examined is reported through
            debug() as either added or skipped.
        """
        debug(self._mn(),str(self.args.agents))
        selected = []
        for item in self.get_db_list():
            node, component = item[0], item[1]
            is_agent = (component == self.n_ag)
            if(is_agent):
                selected.append(item)
            verb = 'add: ' if is_agent else 'skip: '
            debug(self._mn(),verb+'node:'+node+' '+'component:'+component)
        return selected
    
    # --nodelist
    def nodelist(self):
        """Handle --nodelist: build agent entries from the given nodefiles.

        Each nodefile is passed to read_nodefile() together with an
        accumulating dict.  The script exits when that call reports a
        negative first value.
        NOTE(review): the dict is presumed to map nodefile name to its
        list of nodes -- confirm against read_nodefile in DuccUtil.

        Returns:
            A list of [ host, 'ag', '' ] entries, one per node read.
        """
        debug(self._mn(),str(self.args.nodelist))
        component = 'ag'
        state = ''
        nodes_by_file = {}
        for nodefile in self.args.nodelist:
            status, nodes_by_file = self.read_nodefile(nodefile,nodes_by_file)
            if(status < 0):
                self._exit()
        # flatten into [ host, component, state ] entries
        entries = []
        for nodes in nodes_by_file.values():
            entries.extend([node, component, state] for node in nodes)
        return entries
    
    # --component
    def complist(self):
        """Handle --component: turn command-line specs into entries.

        Each spec is either 'component' (this node is assumed) or
        'component@node'.  The component may be given by long or short
        name and is always normalized to its short name, because the rest
        of this class (validate_automanage, validate_db, kill_threads,
        stop) compares against and looks up short names.  Previously a
        long name such as 'agent' was passed through unchanged.

        The script exits with a message on a malformed spec, an unknown
        component, or a node given for a component starting with 'all'.

        Returns:
            A list of [ node, short_component, '' ] entries.
        """
        text = str(self.args.component)
        debug(self._mn(),text)
        entries = []
        # validate components specified on cmdline
        for c in self.args.component:
            parts = c.split('@')
            if len(parts) == 1:
                dn = self.get_node_name()
                dc = parts[0]
            elif len(parts) == 2:
                dn = parts[1]
                dc = parts[0]
                if(dc.startswith('all')):
                    msg = 'node specification disallowed for: '+dc
                    output(msg)
                    self._exit()
            else:
                msg = 'invalid component: '+c
                output(msg)
                self._exit()
            # shortname holds every accepted spelling, long and short
            if(dc not in self.shortname):
                msg = 'invalid component: '+c
                output(msg)
                self._exit()
            component = self.shortname[dc]
            text = dc+'.'+dn
            debug(self._mn(),text)
            entries.append([ dn, component, '' ])
        return entries
    
    # disallow br/db unless automanaged
    def validate_automanage(self,component):
        if(component == 'br'):
            if(not self.automanage_broker):
                msg = 'component='+component+' '+'not automanaged.'
                output(msg)
                self._exit()
        elif(component == 'db'):
            if(not self.automanage_database):
                msg = 'component='+component+' '+'not automanaged.'
                output(msg)
                self._exit()
    
    # disallow unless in db
    def validate_db(self,node,component):
        list = self.get_db_list()
        for item in list:
            db_node = item[0]
            db_component = item[1]
            if(db_node == node):
                if(db_component == component):
                    return
        msg = 'node='+node+' '+'component='+component+' not found in database autostart table'
        output(msg)
        self._exit()
    
    # validate user specified list
    def validate_list(self,list):
        for item in list:
            node = item[0]
            component = item[1]
            self.validate_automanage(component)
            self.validate_db(node, component)
    
    # in: tuples of (component,pid,user) and a desired component
    # out: the pid of the desired component, if found 
    def find_pid(self,tuples,component):
        pid = None
        for tuple in tuples:
            t_component = tuple[0]
            t_pid = tuple[1]
            t_user = tuple[2]
            if(t_user == self.ducc_uid):
                if(self.shortname.has_key(t_component)):
                    t_comp = self.shortname[t_component]
                    if(t_comp == component):
                        pid = t_pid
                        break
        return pid
    
    def acct_stop(self,node,component):
        """Report the stop and record it via db_acct_stop().

        The message goes through output() rather than a bare print: this
        method is called from kill(), which runs in worker threads, and
        output() holds the print lock so lines do not interleave.
        """
        output('stop: '+component+'@'+node)
        self.db_acct_stop(node,component)
        
    # target=kill
    def kill(self,count,tid,node,component,signal):
        """Thread target: signal one daemon on one node over ssh.

        Args:
            count: running number of this daemon, used in messages.
            tid: thread id taken from the pool by kill_threads.
            node: host on which the daemon runs.
            component: short name of the daemon.
            signal: signal number to send.

        A 'pending' line is printed first and a final line with the
        outcome afterwards.  The thread id and the pool slot are always
        handed back, even when a step raises; otherwise kill_threads could
        block forever waiting for a free slot.
        """
        try:
            self.acct_stop(node,component)
            verbosity=False
            ssh = self.ssh_operational(node,verbosity)
            state = 'state=pending'
            pfx = 'kill'+' '+'daemon='+str(count)+' '+'thread='+str(tid)+' '+'node='+node+' '+'component='+component+' '
            msg = pfx+state
            output(msg)
            process=''
            if(ssh):
                state='state=success'
                status, tuples = self.find_ducc_process(node)
                if(status):
                    pid = self.find_pid(tuples,component)
                    if(pid is None):
                        state='state=component not found'
                    else:
                        self.ssh(node, True, 'kill', '-'+str(signal), str(pid))
                        process='pid='+str(pid)+' '
                else:
                    state='state=find DUCC process failed'
            else:
                state = 'state=ssh failure'
            msg = pfx+process+state
            output(msg)
        finally:
            # release in the same order as the original success path
            self.put_tid(tid)
            self.pool.release()
    
    # launch threads to perform kills
    def kill_threads(self,list,signal):
        """Start one kill() thread per entry, limited by the thread pool.

        Entries are processed component by component in the order of
        self.components.  A pool slot is acquired before each thread is
        started; kill() releases it when done.  Threads are not joined
        here.
        """
        output('daemons='+str(len(list)))
        launched = 0
        for name in self.components:
            wanted = self.shortname[name]
            for entry in list:
                node, component = entry[0], entry[1]
                if(component != wanted):
                    continue
                launched = launched+1
                # blocks while the maximum number of threads is running
                self.pool.acquire()
                tid = self.get_tid()
                worker = Thread(target=self.kill, args=(launched,tid,node,component,signal))
                worker.start()
    
    def filter_remote_head(self,list):
        list_remote_head = []
        list_remainder = []
        this_node = self.get_node_name()
        for item in list:
            node = item[0]
            com = item[1]
            if(com == self.c_ag):
                list_remainder.append(item)
            elif(node == this_node):
                list_remainder.append(item)
            else:
                list_remote_head.append(item)
        return list_remote_head, list_remainder
    
    # target=stop
    def stop(self,list,qflag):
        """Stop or quiesce the daemons in list.

        Args:
            list: [ host, component, state ] entries with short
                component names.
            qflag: True to quiesce, False to stop.

        Non-agent daemons on other nodes are signalled with kill -15
        through kill_threads.  Every remaining entry is recorded as
        stopped, then each one other than broker and database is passed
        to ducc_admin() as 'component@node'.  Broker and database, when
        present, are stopped last through stop_broker() and db_stop().

        Progress lines go through output() because the remote-head kill
        threads may still be printing concurrently.
        """
        text = 'list='+str(list)
        debug(self._mn(),text)
        # get 2 lists
        list_remote_head, list_local = self.filter_remote_head(list)
        # stop remote head(s)
        if(len(list_remote_head)>0):
            self.kill_threads(list_remote_head, self.sig15)
        # update database + collect admin targets
        targets = []
        stop_db = False
        stop_broker = False
        for item in list_local:
            node = item[0]
            com = item[1]
            self.acct_stop(node,com)
            component = self.longname[com]
            if(component == self.c_broker):
                stop_broker = True
            elif(component == self.c_database):
                stop_db = True
            else:
                targets.append(component+'@'+node)
        # issue command
        if(qflag):
            label = 'quiesce: '
            option = '--quiesce'
        else:
            label = 'stop: '
            option = '--stop'
        for target in targets:
            output(label+target)
            self.ducc_admin(option,target)
        # stop broker
        if(stop_broker):
            self.stop_broker()
        # stop database
        if(stop_db):
            self.db_stop()
    
    # only agent component allowed
    def agent_only(self,list):
        for item in list:
            component = item[1]
            if(component != self.n_ag):
                'invalid component='+component
                self._exit()
            
    # if this command is not running from a head node, 
    # then complain and exit
    def enforce_location_limits(self,list):
        if(not self.is_head_node()):
            msg = 'cannot run from non-head node.'
            output(msg)
            self._exit()
    
    # lock guarding self.tids; held by get_tid and put_tid while they
    # take an id from, or return an id to, that list
    lock_tid = Lock()
    
    # get thread id
    def get_tid(self):
        with self.lock_tid:
            tid = self.tids.pop(0)
            return tid
    
    # return thread id
    def put_tid(self,tid):
        with self.lock_tid:
            self.tids.append(tid)
    
    # initialize for threading
    def threads_prep(self):
        maxthreads = self.args.maxthreads
        self.tids = range(1,maxthreads+1)
        self.pool = BoundedSemaphore(value=maxthreads)
    
    # main
    def main(self,argv):
        """Run the command: parse arguments, select targets, validate, act.

        The argv parameter is not read here; get_args() lets argparse
        take the arguments from the process command line.
        """
        self.get_args()
        self.threads_prep()
        args = self.args
        # get list of nodes+daemons
        if(args.all):
            targets = self.all()
        elif(args.head):
            targets = self.head()
        elif(args.agents):
            targets = self.agents()
        elif(args.nodelist is not None):
            targets = self.nodelist()
        elif(args.component is not None):
            targets = self.complist()
        else:
            self._help()
        debug(self._mn(),str(targets))
        # disallow br/db unless DUCC managed
        self.validate_list(targets)
        # allow only from head node
        self.enforce_location_limits(targets)
        # perform action
        if(args.kill):
            self.kill_threads(targets,self.sig9)
        elif(args.stop is not None):
            self.stop(targets,False)
        elif(args.quiesce_then_stop):
            self.stop(targets,True)
        else:
            self._help()
        
if __name__ == '__main__':
    # script entry point: build the command object, then run it with
    # the arguments that follow the program name
    instance = StopDucc()
    cli_args = sys.argv[1:]
    instance.main(cli_args)
