#!/usr/bin/python3.13 -s
# Copyright (C) 2015 Alexander Harvey Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Followup foreground events
"""
import os
import sys
import argparse
import logging
import re

import pycbc.workflow as wf
from pycbc import init_logging, add_common_pycbc_options
from pycbc.results import layout
from pycbc.types import MultiDetOptionAction
from pycbc.events import select_segments_by_definer, coinc
from pycbc.io import get_all_subkeys, HFile
import pycbc.workflow.minifollowups as mini
from pycbc.workflow.core import resolve_url_to_file

# Command-line interface; the description is the module docstring with its
# leading newline stripped.
parser = argparse.ArgumentParser(description=__doc__[1:])
add_common_pycbc_options(parser)
parser.add_argument('--bank-file',
                    help="HDF format template bank file")
parser.add_argument('--statmap-file',
                    help="HDF format clustered coincident trigger result file")
parser.add_argument('--single-detector-triggers', nargs='+',
                    action=MultiDetOptionAction,
                    help="HDF format merged single detector trigger files")
parser.add_argument('--inspiral-segments',
                    help="xml segment files containing the inspiral analysis times")
parser.add_argument('--inspiral-data-read-name',
                    help="Name of inspiral segmentlist containing data read in "
                         "by each analysis job.")
parser.add_argument('--inspiral-data-analyzed-name',
                    help="Name of inspiral segmentlist containing data "
                         "analyzed by each analysis job.")
# BUGFIX: the original help text ran the three choices together without
# punctuation ("foreground triggers background triggers ...").
parser.add_argument('--analysis-category', type=str, required=False,
                    default='background_exc',
                    choices=['foreground', 'background', 'background_exc'],
                    help='Designates whether to look at foreground triggers, '
                         'background triggers (including "little dogs"), '
                         'or background triggers (with "little dogs" removed)')
parser.add_argument('--sort-variable', default='ifar',
                    help='Which subgroup of --analysis-category to use for '
                         'sorting. Default=ifar')
parser.add_argument('--sort-order', default='descending',
                    choices=['ascending', 'descending'],
                    help='Which direction to use when sorting on '
                         '--sort-variable. Default=descending')

# Standard workflow options (config files/overrides, output dir, tags, ...)
wf.add_workflow_command_line_group(parser)
wf.add_workflow_settings_cli(parser, include_subdax_opts=True)
args = parser.parse_args()

# Default logging level is info: --verbose adds to this
init_logging(args.verbose, default_level=1)

workflow = wf.Workflow(args)

wf.makedir(args.output_dir)

# create a FileList that will contain all output files
layouts = []

# Resolve the main input files to workflow File objects
tmpltbank_file = resolve_url_to_file(os.path.abspath(args.bank_file))
coinc_file = resolve_url_to_file(os.path.abspath(args.statmap_file))
insp_segs = resolve_url_to_file(os.path.abspath(args.inspiral_segments))

# Per-detector inputs: resolved trigger Files, open HDF handles, and the
# segment lists of data read / analysed by each inspiral job
single_triggers = []
fsdt = {}
insp_data_seglists = {}
insp_analysed_seglists = {}
for det in args.single_detector_triggers:
    trig_path = args.single_detector_triggers[det]
    single_triggers.append(
        resolve_url_to_file(os.path.abspath(trig_path), attrs={'ifos': det}))
    fsdt[det] = HFile(trig_path, 'r')
    # Fetch both the "data read" and "data analysed" segment lists in one go
    for seg_store, seg_name in (
            (insp_data_seglists, args.inspiral_data_read_name),
            (insp_analysed_seglists, args.inspiral_data_analyzed_name)):
        seg_store[det] = select_segments_by_definer(
            args.inspiral_segments,
            segment_name=seg_name,
            ifo=det)
        # NOTE: make_singles_timefreq needs a coalesced set of segments. If
        #       this is being used to determine command-line options for
        #       other codes, please think if that code requires coalesced,
        #       or not, segments.
        seg_store[det].coalesce()

# Number of loudest events to follow up, from the workflow configuration
num_events = int(workflow.cp.get_opt_tags('workflow-minifollowups', 'num-events', ''))
f = HFile(args.statmap_file, 'r')
file_val = args.analysis_category
stat = f['{}/stat'.format(file_val)][:]

# Fail early with an actionable message if the sort variable is missing
if args.sort_variable not in f[file_val]:
    # List the datasets that *are* available under the category group
    all_datasets = [re.sub(file_val, '', ds).strip('/')
                    for ds in get_all_subkeys(f, file_val)]
    # BUGFIX: the original two-part f-string rendered as "sortchoices";
    # a space was missing between the concatenated literals.
    raise KeyError(f'Sort variable {args.sort_variable} not in {file_val}: '
                   f'sort choices in {file_val} are ' + ', '.join(all_datasets))

# In case we are doing background minifollowup with repeated events,
# we must include the ordering / template / trigger / time info for
# _many_ more events to make sure we get enough
events_to_read = num_events * 100

# We've asked for more events than there are!
if len(stat) < num_events:
    num_events = len(stat)
if len(stat) < events_to_read:
    events_to_read = len(stat)

# Get the indices of the events we are considering in the order specified
sorting = f[file_val + '/' + args.sort_variable][:].argsort()
if args.sort_order == 'descending':
    sorting = sorting[::-1]
event_idx = sorting[0:events_to_read]
stat = stat[event_idx]

# Save the time / trigger / template ids for the events
times = {}
tids = {}
ifo_list = f.attrs['ifos'].split(' ')
f_cat = f[file_val]
for ifo in ifo_list:
    times[ifo] = f_cat[ifo]['time'][:][event_idx]
    tids[ifo] = f_cat[ifo]['trigger_id'][:][event_idx]
bank_ids = f_cat['template_id'][:][event_idx]
f.close()

# Open the template bank once; per-event parameters are looked up below
bank_data = HFile(args.bank_file, 'r')

# loop over number of loudest events to be followed up
event_times = {}     # per-ifo integer trigger times already followed up
skipped_data = []    # (ifo, time) pairs skipped as background duplicates
event_count = 0      # events actually followed up so far
curr_idx = -1        # index into the events read from the statmap file
while event_count < num_events and curr_idx < (events_to_read - 1):
    curr_idx += 1
    files = wf.FileList([])
    duplicate = False

    # Times and trig ids for trigger timeseries plot
    ifo_times_strings = []
    ifo_tids_strings = []

    # Mean time to use as reference for zerolag
    # NOTE(review): -1 sentinel times (detector not in the coinc) appear to
    # be excluded by mean_if_greater_than_zero, per its name — confirm there.
    mtime = coinc.mean_if_greater_than_zero(
        [times[ifo][curr_idx] for ifo in times])[0]

    for ifo in times:
        plottime = times[ifo][curr_idx]
        if plottime == -1:  # Ifo is not in coinc
            # Use mean time instead and don't plot special trigger
            plottime = mtime
        else:  # Do plot special trigger
            ifo_tids_strings += ['%s:%s' % (ifo, tids[ifo][curr_idx])]
        ifo_times_strings += ['%s:%s' % (ifo, plottime)]

        # For background do not want to follow up 10 background coincs with the
        # same event in ifo 1 and different events in ifo 2, so record times
        if ifo not in event_times:
            event_times[ifo] = []
        # Only do this for background coincident triggers & not sentinel time -1
        if 'background' in args.analysis_category and \
                times[ifo][curr_idx] != -1 and \
                int(times[ifo][curr_idx]) in event_times[ifo]:
            skipped_data.append((ifo, int(times[ifo][curr_idx])))
            duplicate = True
            break

    # Duplicates consume an index but do not count towards num_events
    if duplicate:
        continue
    # Record times only for events we actually follow up
    for ifo in times:
        event_times[ifo].append(int(times[ifo][curr_idx]))

    ifo_times = ' '.join(ifo_times_strings)
    ifo_tids = ' '.join(ifo_tids_strings)

    event_count += 1
    # Report any duplicates skipped since the last followed-up event, then
    # reset the list so each skip is reported once
    if skipped_data:
        layouts += (mini.make_skipped_html(
                        workflow,
                        skipped_data,
                        args.output_dir,
                        tags=[f'SKIP_{event_count}']),)
        skipped_data = []

    bank_id = bank_ids[curr_idx]

    # Summary table for this coincident event
    layouts += (mini.make_coinc_info(workflow, single_triggers, tmpltbank_file,
                              coinc_file, args.output_dir, n_loudest=curr_idx,
                              sort_order=args.sort_order, sort_var=args.sort_variable,
                              tags=args.tags + [str(event_count)]),)
    # Trigger timeseries around the event, highlighting the special trigger
    # ids collected above
    files += mini.make_trigger_timeseries(workflow, single_triggers,
                             ifo_times, args.output_dir, special_tids=ifo_tids,
                             tags=args.tags + [str(event_count)])

    # Template/trigger parameters needed by the single-template plot jobs
    params = mini.get_single_template_params(
        curr_idx,
        times,
        bank_data,
        bank_ids[curr_idx],
        fsdt,
        tids
    )

    _, sngl_tmplt_plots = mini.make_single_template_plots(
        workflow,
        insp_segs,
        args.inspiral_data_read_name,
        args.inspiral_data_analyzed_name,
        params,
        args.output_dir,
        data_segments=insp_data_seglists,
        tags=args.tags + [str(event_count)]
    )
    files += sngl_tmplt_plots


    # Per-detector time-frequency and qscan plots
    for single in single_triggers:
        time = times[single.ifo][curr_idx]
        if time==-1:
            # If this detector did not trigger, still make the plot, but use
            # the average time of detectors which did trigger
            time = coinc.mean_if_greater_than_zero([times[sngl.ifo][curr_idx]
                                                    for sngl in single_triggers])[0]
        # Only make the plots if the time lies in an analysed segment of
        # this detector; the for/else runs the else when no segment matched
        for seg in insp_analysed_seglists[single.ifo]:
            if time in seg:
                files += mini.make_singles_timefreq(workflow, single, tmpltbank_file,
                                time, args.output_dir,
                                data_segments=insp_data_seglists[single.ifo],
                                tags=args.tags + [str(event_count)])
                files += mini.make_qscan_plot\
                    (workflow, single.ifo, time, args.output_dir,
                     data_segments=insp_data_seglists[single.ifo],
                     tags=args.tags + [str(event_count)])
                break
        else:
            logging.info(f'Trigger time {time} is not valid in '
                         f'{single.ifo}, skipping singles plots')

    # Arrange this event's plots two-per-row on the results page
    layouts += list(layout.grouper(files, 2))

# Finalize the workflow and write the two-column results page
workflow.save()
layout.two_column_layout(args.output_dir, layouts)
