#!/usr/bin/python3.13 -s
# Copyright (C) 2015 Alexander Harvey Nitz, Ian Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""
Followup single-detector triggers which do not contribute to foreground events
"""
import os
import argparse
import logging
import numpy

from ligo.lw import table
from ligo.lw import utils as ligolw_utils

from pycbc import init_logging, add_common_pycbc_options
from pycbc.results import layout
from pycbc.types.optparse import MultiDetOptionAction
from pycbc.events import select_segments_by_definer
import pycbc.workflow.minifollowups as mini
import pycbc.workflow as wf
import pycbc.events
from pycbc.workflow.core import resolve_url_to_file
from pycbc.events import stat, veto
from pycbc.io import hdf

# Command-line interface.
# Options that the script uses unconditionally further down are marked as
# required, so that omitting one produces a clear argparse error up front
# instead of an obscure failure later (e.g. os.path.abspath(None)).
parser = argparse.ArgumentParser(description=__doc__[1:])
add_common_pycbc_options(parser)

# Input files and the instrument being followed up
parser.add_argument('--bank-file', required=True,
                    help="HDF format template bank file")
parser.add_argument('--single-detector-file', required=True,
                    help="HDF format merged single detector trigger files")
parser.add_argument('--instrument', required=True,
                    help="Name of interferometer e.g. H1")

# Optional vetoes
parser.add_argument('--foreground-censor-file',
                    help="The censor file to be used if vetoing triggers "
                         "in the foreground of the search (optional).")
parser.add_argument('--foreground-segment-name',
                    help="If using foreground censor file must also provide "
                         "the name of the segment to use as a veto.")
parser.add_argument('--veto-file',
                    help="The veto file to be used if vetoing triggers "
                         "(optional).")
parser.add_argument('--veto-segment-name',
                    help="If using veto file must also provide the name of "
                         "the segment to use as a veto.")
parser.add_argument("--gating-veto-windows", nargs='+',
                    action=MultiDetOptionAction,
                    help="Seconds to be vetoed before and after the central time "
                         "of each gate. Given as detector-values pairs, e.g. "
                         "H1:-1,2.5 L1:-1,2.5 V1:0,0")

# Inspiral analysis segments
parser.add_argument('--inspiral-segments', required=True,
                    help="xml segment file containing the inspiral analysis "
                         "times")
parser.add_argument('--inspiral-data-read-name',
                    help="Name of inspiral segmentlist containing data read in "
                         "by each analysis job.")
parser.add_argument('--inspiral-data-analyzed-name',
                    help="Name of inspiral segmentlist containing data "
                         "analyzed by each analysis job.")

# Trigger selection
parser.add_argument('--min-sngl-ranking', type=float, default=6.5,
                    help="Minimum sngl-ranking to consider for loudest "
                         "triggers. Useful for efficiency savings. "
                         "Default=6.5.")
parser.add_argument('--non-coinc-time-only', action='store_true',
                    help="If given remove (veto) single-detector triggers "
                         "that occur during a time when at least one other "
                         "instrument is taking science data.")
parser.add_argument('--vetoed-time-only', action='store_true',
                    help="If given, only report on single-detector triggers "
                         "that occur during vetoed times.")
parser.add_argument('--minimum-duration', default=None, type=float,
                    help="If given only consider single-detector triggers "
                         "with template duration larger than this.")
parser.add_argument('--maximum-duration', default=None, type=float,
                    help="If given only consider single-detector triggers "
                         "with template duration smaller than this.")
parser.add_argument('--cluster-window', type=float, default=10,
                    help="Window (seconds) over which to cluster triggers "
                         "when finding the loudest-ranked. Default=10")

# Option groups supplied by the workflow and ranking-statistic modules
wf.add_workflow_command_line_group(parser)
wf.add_workflow_settings_cli(parser, include_subdax_opts=True)
stat.insert_statistic_option_group(parser,
    default_ranking_statistic='single_ranking_only')
args = parser.parse_args()

# Default logging level is info: --verbose adds to this
init_logging(args.verbose, default_level=1)

# The workflow built here concerns a single instrument only
workflow = wf.Workflow(args)
workflow.ifos = [args.instrument]
workflow.ifo_string = args.instrument

wf.makedir(args.output_dir)

# Plain list collecting the layout entries for every followed-up event;
# it is handed to layout.two_column_layout at the end of the script
layouts = []

tmpltbank_file = resolve_url_to_file(os.path.abspath(args.bank_file))
sngl_file = resolve_url_to_file(
    os.path.abspath(args.single_detector_file),
    attrs={'ifos': args.instrument}
)

# Flatten the statistic_files option (a list of lists) and resolve each path
statfiles = [
    resolve_url_to_file(os.path.abspath(stat_path))
    for stat_path in sum(args.statistic_files, [])
]
# Truthiness test: comparing with `is not []` is an identity check against a
# brand new list, which is always True and made the None fallback unreachable
statfiles = wf.FileList(statfiles) if statfiles else None

if args.veto_file is not None:
    veto_file = resolve_url_to_file(
        os.path.abspath(args.veto_file),
        attrs={'ifos': args.instrument}
    )
else:
    veto_file = None

insp_segs = resolve_url_to_file(os.path.abspath(args.inspiral_segments))
# Segments of data read in by the inspiral jobs for this instrument,
# coalesced in place
insp_data_seglists = select_segments_by_definer(
    args.inspiral_segments,
    segment_name=args.inspiral_data_read_name,
    ifo=args.instrument
)
insp_data_seglists.coalesce()

# Number of events to follow up, taken from the workflow configuration
num_events = int(
    workflow.cp.get_opt_tags(
        'workflow-sngl_minifollowups',
        'num-sngl-events',
        ''
    )
)

# Load the single-detector triggers for this instrument. The foreground
# censor file / segment name and the sngl-ranking threshold are passed
# straight through to the trigger loader.
trigger_load_opts = {
    'bank_file': args.bank_file,
    'veto_file': args.foreground_censor_file,
    'segment_name': args.foreground_segment_name,
    'filter_rank': args.sngl_ranking,
    'filter_threshold': args.min_sngl_ranking,
}
trigs = hdf.SingleDetTriggers(
    args.single_detector_file,
    args.instrument,
    **trigger_load_opts
)

# Include gating vetoes
# Start from "no triggers vetoed by gating" so that gveto_idx is always
# defined, including when windows are supplied but both are zero.
gveto_idx = numpy.array([], dtype=numpy.uint64)
if args.gating_veto_windows:
    logging.info("Getting gating vetoes")
    # Window for this instrument is given as "before,after" in seconds
    gating_veto = args.gating_veto_windows[args.instrument].split(',')
    gveto_before = float(gating_veto[0])
    gveto_after = float(gating_veto[1])
    if gveto_before > 0 or gveto_after < 0:
        raise ValueError("Gating veto window values must be negative before "
                         "gates and positive after gates.")
    if not (gveto_before == 0 and gveto_after == 0):
        # Read the gate times from the merged trigger file itself. The
        # previous code indexed a stale loop variable (a path string) here.
        with hdf.HFile(args.single_detector_file, 'r') as trig_f:
            gate_group = trig_f[args.instrument + '/gating/']
            autogate_times = numpy.unique(gate_group['auto/time'][:])
            if 'file' in gate_group:
                detgate_times = gate_group['file/time'][:]
            else:
                detgate_times = []
        gate_times = numpy.concatenate((autogate_times, detgate_times))
        # Indices of triggers whose end time lies inside a window around a gate
        gveto_idx = veto.indices_within_times(
            trigs.end_time,
            gate_times + gveto_before,
            gate_times + gveto_after
        )
        logging.info('%i triggers in gating vetoes', gveto_idx.size)

# Indices of triggers that lie inside the segments of the veto file, if any
veto_file_idx = numpy.array([], dtype=numpy.uint64)
if args.veto_file:
    logging.info('Getting file vetoes')
    veto_file_idx, _ = veto.indices_within_segments(
        trigs.end_time,
        [args.veto_file],
        ifo=args.instrument,
        segment_name=args.veto_segment_name
    )

    logging.info('%i triggers in file-vetoed segments',
                 veto_file_idx.size)

# Merge file and gating vetoes into one ascending, duplicate-free index array
all_vetoed = numpy.concatenate((veto_file_idx, gveto_idx))
vetoed_idx = numpy.sort(numpy.unique(all_vetoed)).astype(numpy.uint64)

if vetoed_idx.size > 0:
    if args.vetoed_time_only:
        # Keep exactly the vetoed triggers
        logging.info("Applying mask to keep only triggers within vetoed time")
        trigs.apply_mask(vetoed_idx)
    else:
        # Keep everything except the vetoed triggers
        logging.info("Applying mask to keep only triggers outwith vetoed time")
        keep_mask = numpy.ones(trigs.end_time.size, dtype=bool)
        keep_mask[vetoed_idx] = False
        trigs.apply_mask(keep_mask)
elif args.vetoed_time_only:
    # Nothing is vetoed: only a warning is emitted, no mask is applied
    logging.warning("No triggers exist inside vetoed times")

# Optionally restrict to times when no other instrument was being analysed
if args.non_coinc_time_only:
    from pycbc.io.ligolw import LIGOLWContentHandler as h

    segs_doc = ligolw_utils.load_filename(args.inspiral_segments,
                                          contenthandler=h)
    seg_def_table = table.Table.get_table(segs_doc, 'segment_definer')
    def_ifos = {str(ifo) for ifo in seg_def_table.getColumnByName('ifos')}
    # Every instrument named in the segment file except our own. A set
    # difference does not fail if our instrument has no entry, and sorting
    # gives a reproducible processing order.
    ifo_list = sorted(def_ifos - {args.instrument})
    for ifo in ifo_list:
        # Indices of triggers lying outside this instrument's analysed
        # segments; the returned segment list is not needed here
        curr_veto_mask, _ = pycbc.events.veto.indices_outside_segments(
            trigs.end_time, [args.inspiral_segments],
            ifo=ifo, segment_name=args.inspiral_data_analyzed_name)
        curr_veto_mask.sort()
        trigs.apply_mask(curr_veto_mask)

# Optional cuts on template duration. Each bound is strict and is applied
# independently, lower bound first.
min_duration = args.minimum_duration
if min_duration is not None:
    logging.info('applying minimum duration')
    trigs.apply_mask(trigs.template_duration > min_duration)
    logging.info('remaining triggers: %s', trigs.mask.sum())

max_duration = args.maximum_duration
if max_duration is not None:
    logging.info('applying maximum duration')
    trigs.apply_mask(trigs.template_duration < max_duration)
    logging.info('remaining triggers: %s', trigs.mask.sum())

logging.info('Finding loudest clustered events')

# Build the ranking statistic from the command-line options and reduce the
# trigger set to the requested number of loudest clustered events
rank_method = stat.get_statistic_from_opts(args, [args.instrument])
extra_kwargs = stat.parse_statistic_keywords_opt(args.statistic_keywords)
trigs.mask_to_n_loudest_clustered_events(
    rank_method,
    n_loudest=num_events,
    cluster_window=args.cluster_window,
    statistic_kwargs=extra_kwargs,
)

times = trigs.end_time

# The mask may be either a boolean array or an array of indices; in both
# cases end up with the indices of the selected triggers
selected_mask = trigs.mask
mask_is_boolean = (
    isinstance(selected_mask, numpy.ndarray)
    and selected_mask.dtype == bool
)
trigger_ids = numpy.flatnonzero(selected_mask) if mask_is_boolean \
    else selected_mask

trig_stat = trigs.stat
# Follow up the selected events in order of decreasing statistic value
order = trig_stat.argsort()[::-1]
for rank, num_event in enumerate(order):
    logging.info('Processing event: %s', rank)

    # Tag identifying this event; a fresh tag list is built for every call
    rank_tag = [str(rank)]

    files = wf.FileList([])
    trig_time = times[num_event]
    ifo_time = '%s:%s' % (args.instrument, str(trig_time))
    tid = trigger_ids[num_event]
    ifo_tid = '%s:%s' % (args.instrument, str(tid))

    # The single-ifo summary is added to the layouts as its own entry
    layouts.append(
        mini.make_sngl_ifo(workflow, sngl_file, tmpltbank_file,
                           tid, args.output_dir, args.instrument,
                           statfiles=statfiles,
                           tags=args.tags + rank_tag)
    )
    files += mini.make_trigger_timeseries(workflow, [sngl_file],
                              ifo_time, args.output_dir, special_tids=ifo_tid,
                              tags=args.tags + rank_tag)

    # Template parameters passed to the single-template and waveform plots
    curr_params = {
        'mass1': trigs.mass1[num_event],
        'mass2': trigs.mass2[num_event],
        'spin1z': trigs.spin1z[num_event],
        'spin2z': trigs.spin2z[num_event],
        'f_lower': trigs.f_lower[num_event],
        args.instrument + '_end_time': trig_time,
        'mean_time': trig_time,
    }
    # don't require precessing template info if not present
    try:
        curr_params['spin1x'] = trigs.spin1x[num_event]
        curr_params['spin2x'] = trigs.spin2x[num_event]
        curr_params['spin1y'] = trigs.spin1y[num_event]
        curr_params['spin2y'] = trigs.spin2y[num_event]
        curr_params['inclination'] = trigs.inclination[num_event]
    except KeyError:
        pass
    try:
        # Only present for precessing search
        curr_params['u_vals'] = trigs.u_vals[num_event]
    except (KeyError, AttributeError):
        # Missing column/attribute is expected for non-precessing searches.
        # A bare except here would also hide KeyboardInterrupt/SystemExit.
        pass

    _, sngl_plot = mini.make_single_template_plots(
        workflow, insp_segs,
        args.inspiral_data_read_name,
        args.inspiral_data_analyzed_name,
        curr_params,
        args.output_dir,
        data_segments={args.instrument: insp_data_seglists},
        tags=args.tags + rank_tag
    )
    files += sngl_plot

    files += mini.make_plot_waveform_plot(workflow, curr_params,
                                        args.output_dir, [args.instrument],
                                        tags=args.tags + rank_tag)

    files += mini.make_singles_timefreq(workflow, sngl_file, tmpltbank_file,
                            trig_time, args.output_dir,
                            data_segments=insp_data_seglists,
                            tags=args.tags + rank_tag)

    files += mini.make_qscan_plot(workflow, args.instrument, trig_time,
                                  args.output_dir,
                                  data_segments=insp_data_seglists,
                                  tags=args.tags + rank_tag)

    # Remaining plots for this event are laid out two per row
    layouts += list(layout.grouper(files, 2))

workflow.save()
layout.two_column_layout(args.output_dir, layouts)
