#!/usr/bin/python3.13 -s
# Copyright (C) 2015 Alexander Harvey Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Create a workflow for following up missed loud injections."""

import os
import argparse
import logging
import copy
import numpy

from pycbc import init_logging, add_common_pycbc_options
import pycbc.workflow as wf
import pycbc.workflow.minifollowups as mini
from pycbc.types import MultiDetOptionAction
from pycbc.events import select_segments_by_definer, coinc
from pycbc.results import layout
from pycbc.detector import Detector
from pycbc.workflow.core import resolve_url_to_file
from pycbc.io.hdf import SingleDetTriggers, HFile


# Metrics accepted by `sort_injections` (and offered as --distance-type choices)
legal_distance_types = [
    'decisive_optimal_snr',
    'comb_optimal_snr',
    'dec_chirp_distance'
]


def sort_injections(args, inj_group, missed):
    """Return an array of indices to sort the missed injections from most to
    least likely to be detected, according to a metric of choice.

    Parameters
    ----------
    args : object
        CLI arguments parsed by argparse. Must have a `distance_type` attribute
        to specify how to sort, which must be one of the values in
        `legal_distance_types`. For `decisive_optimal_snr` an optional
        `maximum_decisive_snr` attribute is honoured; for `dec_chirp_distance`
        the detector names are taken by iterating over
        `single_detector_triggers`.
    inj_group : h5py group object
        HDF5 group object containing the injection definition.
    missed : array
        Array of indices of missed injections into `inj_group`.

    Returns
    -------
    missed_sorted : array
        Array of indices of missed injections sorted as requested.

    Raises
    ------
    ValueError
        If the distance type is absent or not in `legal_distance_types`, or
        if an optimal-SNR metric is requested but `inj_group` holds no
        `optimal_snr_*` datasets.
    IndexError
        If a decisive metric is requested with fewer than two detectors.
    """
    if not hasattr(args, 'distance_type'):
        raise ValueError('Distance type not provided')
    if args.distance_type not in legal_distance_types:
        raise ValueError(
            f'Invalid distance type "{args.distance_type}", '
            f'allowed types are {", ".join(legal_distance_types)}'
        )

    if 'optimal_snr' in args.distance_type:
        # One row per detector, one column per missed injection
        optimal_snrs = [
            inj_group[dsn][:][missed] for dsn in inj_group
            if dsn.startswith('optimal_snr_')
        ]
        # Explicit raise rather than assert: asserts vanish under `python -O`
        if not optimal_snrs:
            raise ValueError('These injections do not have optimal SNRs')

    if args.distance_type == 'decisive_optimal_snr':
        # descending order of decisive (2nd largest) optimal SNR
        dec_snr = numpy.sort(numpy.vstack(optimal_snrs), axis=0)[-2]
        max_dec_snr = getattr(args, 'maximum_decisive_snr', None)
        if max_dec_snr is not None:
            # Setting these to 0 pushes them to the end of the ordering
            dec_snr[dec_snr > max_dec_snr] = 0
        sorter = dec_snr.argsort()[::-1]
        return missed[sorter]

    if args.distance_type == 'comb_optimal_snr':
        # descending order of network optimal SNR
        optimal_snrs = numpy.vstack(optimal_snrs)
        net_opt_snrs_squared = (optimal_snrs ** 2).sum(axis=0)
        sorter = net_opt_snrs_squared.argsort()[::-1]
        return missed[sorter]

    if args.distance_type == 'dec_chirp_distance':
        # ascending order of decisive (2nd smallest) chirp distance
        from pycbc.conversions import mchirp_from_mass1_mass2, chirp_distance

        # Read each injection parameter once, not once per detector
        inj_pars = [
            inj_group[name][:][missed] for name in
            ('distance', 'ra', 'dec', 'polarization', 'tc', 'inclination')
        ]
        # One row per detector, one column per missed injection
        eff_dists = numpy.vstack([
            Detector(ifo).effective_distance(*inj_pars)
            for ifo in args.single_detector_triggers
        ])
        # Sort across detectors separately for every injection and take the
        # 2nd smallest effective distance, the counterpart of the decisive
        # (2nd largest) SNR used above
        dec_eff_dist = numpy.sort(eff_dists, axis=0)[1]
        mchirp = mchirp_from_mass1_mass2(
            inj_group['mass1'][:][missed],
            inj_group['mass2'][:][missed]
        )
        dec_chirp_dist = chirp_distance(dec_eff_dist, mchirp)
        sorter = dec_chirp_dist.argsort()
        return missed[sorter]


# Command-line interface: the input files, the time windows used to associate
# triggers with injections, and the metric used to rank missed injections.
parser = argparse.ArgumentParser(description=__doc__)
add_common_pycbc_options(parser)
parser.add_argument('--bank-file',
                    help="HDF format template bank file")
parser.add_argument('--injection-file',
                    help="HDF format injection results file")
parser.add_argument('--injection-xml-file',
                    help="XML format injection file")
parser.add_argument('--single-detector-triggers', nargs='+', action=MultiDetOptionAction,
                    help="HDF format merged single detector trigger files")
parser.add_argument('--inspiral-segments',
                    help="xml segment files containing the inspiral analysis times")
parser.add_argument('--inspiral-data-read-name',
                    help="Name of inspiral segmentlist containing data read in "
                         "by each analysis job.")
parser.add_argument('--inspiral-data-analyzed-name',
                    help="Name of inspiral segmentlist containing data "
                         "analyzed by each analysis job.")
# The window is a duration in seconds with a fractional default, so it must be
# parsed as a float: with type=int a value such as "0.5" is rejected.
parser.add_argument('--inj-window', type=float, default=0.5,
                    help="Time window in which to look for injection triggers")
parser.add_argument('--ifar-threshold', type=float, default=None,
                    help="If given also followup injections with ifar smaller "
                         "than this threshold.")
parser.add_argument('--maximum-decisive-snr', type=float, default=None,
                    help="If given, only followup injections where the "
                         "decisive SNR is smaller than this value.")
parser.add_argument('--nearby-triggers-window', type=float, default=0.05,
                    help="Maximum time difference between the missed "
                         "injection and the loudest SNR nearby trigger to "
                         "display, seconds. Default=0.05")
parser.add_argument('--distance-type',
                    required=True,
                    choices=legal_distance_types,
                    help="How to sort missed injections from most to least "
                         "likely to be detected")
wf.add_workflow_command_line_group(parser)
wf.add_workflow_settings_cli(parser, include_subdax_opts=True)
args = parser.parse_args()


# Logging starts at info level; --verbose increases it from there
init_logging(args.verbose, default_level=1)

workflow = wf.Workflow(args)
wf.makedir(args.output_dir)

# Rows of the results-page layout, extended as follow-up jobs are set up
layouts = []

# Resolve the input files given on the command line, using absolute paths
tmpltbank_file, injection_file, injection_xml_file, insp_segs = (
    resolve_url_to_file(os.path.abspath(path))
    for path in (
        args.bank_file,
        args.injection_file,
        args.injection_xml_file,
        args.inspiral_segments,
    )
)

# Per detector: the resolved trigger file, plus the segments of data read in
# and of data analysed by the inspiral jobs
single_triggers = []
insp_data_seglists = {}
insp_analysed_seglists = {}
for ifo in args.single_detector_triggers:
    single_triggers.append(
        resolve_url_to_file(
            os.path.abspath(args.single_detector_triggers[ifo]),
            attrs={'ifos': ifo}
        )
    )
    for seglists, seg_name in (
        (insp_data_seglists, args.inspiral_data_read_name),
        (insp_analysed_seglists, args.inspiral_data_analyzed_name),
    ):
        seglists[ifo] = select_segments_by_definer(
            args.inspiral_segments,
            segment_name=seg_name,
            ifo=ifo
        )
        # NOTE: make_singles_timefreq needs a coalesced set of segments. If
        #       these lists are reused to build command-line options for
        #       other codes, check whether that code wants them coalesced.
        seglists[ifo].coalesce()

# The injection results file is left open: `inj_def` is a group inside it
# which is read from again further down the script
f = HFile(args.injection_file, 'r')
inj_def = f['injections']
missed = f['missed/after_vetoes'][:]
if args.ifar_threshold is not None:
    # Also follow up found injections whose IFAR is below the threshold
    found_group = f['found_after_vetoes']
    try:
        # The inclusive IFAR is preferred, but may not be present
        ifars = found_group['ifar'][:]
    except KeyError:
        ifars = found_group['ifar_exc'][:]
        logging.warning('Inclusive IFAR not found, using exclusive')
    below_threshold = ifars < args.ifar_threshold
    missed = numpy.append(
        missed,
        found_group['injection_index'][below_threshold]
    )

# Times of the injections to follow up, sorted so they can be searched with
# numpy.searchsorted when picking out the triggers close to them
missed_inj_times = inj_def['tc'][:][missed]
missed_inj_times.sort()

# Note: the Earth diameter in light seconds is added to the window so that
# the different arrival times of the injection at each IFO are allowed for
safe_window = args.nearby_triggers_window + 0.0425

def nearby_missedinj(endtime, snr):
    """
    Flag the triggers lying within `safe_window` of a missed injection.

    Reads the module-level sorted array `missed_inj_times` and the
    module-level `safe_window`.

    Parameters
    ----------
    endtime: numpy array
        Trigger times to compare with the missed injection times

    snr: numpy array
        SNR of the triggers. Not used here; it is only accepted because of
        the design of the HFile.select method

    Returns
    -------
    boolean array
        True where a trigger is close to at least one missed injection
    """
    window_starts = missed_inj_times - safe_window
    window_ends = missed_inj_times + safe_window
    # Number of windows starting before each trigger, and number of windows
    # already ended before it: the counts differ only inside some window
    n_started = numpy.searchsorted(window_starts, endtime)
    n_ended = numpy.searchsorted(window_ends, endtime)
    return n_started != n_ended

# Per detector: file indices, SNRs and end times of the triggers which are
# near to _any_ missed injection
trigger_idx = {}
trigger_snrs = {}
trigger_times = {}
for trig in single_triggers:
    ifo = trig.ifo
    with HFile(trig.lfn, 'r') as trig_f:
        selected = trig_f.select(
            nearby_missedinj,
            f'{ifo}/end_time',
            f'{ifo}/snr',
        )
    # The data come back in the order the datasets were requested
    trigger_idx[ifo], (trigger_times[ifo], trigger_snrs[ifo]) = selected

# Follow up the configured number of injections, or all of them if fewer
num_events = min(
    int(workflow.cp.get_opt_tags(
        'workflow-injection_minifollowups',
        'num-events',
        ''
    )),
    len(missed)
)

# Put the injections most likely to be detected first
missed = sort_injections(args, inj_def, missed)

# Tag and description of each single-template follow-up which uses the exact
# injection parameters as the template
inj_template_runs = (
    ('INJ_PARAMS',
     'injection parameters as template, '
     'here the injection is made as normal'),
    ('INJ_PARAMS_INVERTED',
     'injection parameters as template, '
     'here the injection is made inverted'),
    ('INJ_PARAMS_NOINJ',
     'injection parameters, here no '
     'injection was actually performed'),
)

# loop over sorted missed injections to be followed up
found_inj_idxes = f['found_after_vetoes/injection_index'][:]
for num_event in range(num_events):
    files = wf.FileList([])

    injection_index = missed[num_event]
    time = inj_def['tc'][injection_index]
    lon = inj_def['ra'][injection_index]
    lat = inj_def['dec'][injection_index]

    ifo_times = ''
    inj_params = {}
    for val in ['mass1', 'mass2', 'spin1z', 'spin2z', 'tc']:
        inj_params[val] = inj_def[val][injection_index]
    for single in single_triggers:
        ifo = single.ifo
        det = Detector(ifo)
        ifo_time = time + det.time_delay_from_earth_center(lon, lat, time)
        for seg in insp_analysed_seglists[ifo]:
            if ifo_time in seg:
                break
        else:
            # -1.0 marks a detector with no analysed data at the arrival time
            ifo_time = -1.0

        ifo_times += ' %s:%s ' % (ifo, ifo_time)
        inj_params[ifo + '_end_time'] = ifo_time
    all_times = [inj_params[sngl.ifo + '_end_time'] for sngl in single_triggers]
    inj_params['mean_time'] = coinc.mean_if_greater_than_zero(all_times)[0]

    layouts += [(mini.make_inj_info(workflow, injection_file, injection_index, num_event,
                               args.output_dir, tags=args.tags + [str(num_event)])[0],)]
    if injection_index in found_inj_idxes:
        # Only injections added through --ifar-threshold have a found event
        trig_id = numpy.where(found_inj_idxes == injection_index)[0][0]
        layouts += [(mini.make_coinc_info
                     (workflow, single_triggers, tmpltbank_file,
                      injection_file, args.output_dir, trig_id=trig_id,
                      file_substring='found_after_vetoes',
                      title="Details of closest event",
                      tags=args.tags + [str(num_event)])[0],)]

    for sngl in single_triggers:
        # Find the triggers close to _this_ injection at this IFO
        ifo = sngl.ifo
        trig_tdiff = abs(inj_params[ifo + '_end_time'] - trigger_times[ifo])
        nearby = trig_tdiff < args.nearby_triggers_window
        if not nearby.any():
            # If there are no triggers in the defined window,
            # do not print any info
            continue
        # Find the loudest SNR in this window
        loudest = numpy.argmax(trigger_snrs[ifo][nearby])
        # Convert to the index in the trigger file
        nearby_trigger_idx = trigger_idx[ifo][nearby][loudest]
        # Make the info snippet
        sngl_info = mini.make_sngl_ifo(workflow, sngl, tmpltbank_file,
            nearby_trigger_idx, args.output_dir, ifo,
            title=f"Parameters of loudest SNR nearby trigger in {ifo}",
            tags=args.tags + [str(num_event)])[0]
        layouts += [(sngl_info,)]

    files += mini.make_trigger_timeseries(workflow, single_triggers,
                              ifo_times, args.output_dir,
                              tags=args.tags + [str(num_event)])

    for single in single_triggers:
        checkedtime = time
        if (inj_params[single.ifo + '_end_time'] == -1.0):
            # No analysed data at the arrival time in this detector: use the
            # mean of the valid detector times instead
            checkedtime = inj_params['mean_time']
        for seg in insp_analysed_seglists[single.ifo]:
            if checkedtime in seg:
                files += mini.make_singles_timefreq(workflow, single, tmpltbank_file,
                                checkedtime, args.output_dir,
                                data_segments=insp_data_seglists[single.ifo],
                                tags=args.tags + [str(num_event)])
                files += mini.make_qscan_plot\
                    (workflow, single.ifo, checkedtime, args.output_dir,
                     data_segments=insp_data_seglists[single.ifo],
                     injection_file=injection_xml_file,
                     tags=args.tags + [str(num_event)])
                break
        else:
            logging.info(
                'Trigger time %s is not valid in %s, skipping singles plots',
                checkedtime, single.ifo
            )

    # Single-template plots using the injection parameters as the template;
    # the calls differ only in their tag and description
    for inj_tag, inj_desc in inj_template_runs:
        _, inj_tmplt_plot = mini.make_single_template_plots(
            workflow, insp_segs,
            args.inspiral_data_read_name,
            args.inspiral_data_analyzed_name, inj_params,
            args.output_dir, inj_file=injection_xml_file,
            tags=args.tags + [inj_tag, str(num_event)],
            params_str=inj_desc,
            use_exact_inj_params=True)
        files += inj_tmplt_plot

    for curr_ifo in args.single_detector_triggers:
        # Finding loudest template in this detector near to the injection:
        # First, find triggers close to the missed injection
        single_fname = args.single_detector_triggers[curr_ifo]
        # Close the trigger file as soon as the selection is done; this runs
        # once per event and detector, so handles must not be left open
        with HFile(single_fname, 'r') as sngl_f:
            idx, _ = sngl_f.select(
                lambda t: abs(t - inj_params['tc']) < args.inj_window,
                f'{curr_ifo}/end_time',
                return_data=False,
            )

        if len(idx) == 0:
            # No events in the injection window for this detector
            continue

        hd_sngl = SingleDetTriggers(
            single_fname,
            curr_ifo,
            bank_file=args.bank_file,
            premask=idx
        )
        # Next, find the loudest within this set of triggers
        # Use SNR here or NewSNR, or other??
        loudest_idx = hd_sngl.snr.argmax()
        hd_sngl.apply_mask([loudest_idx])

        # What are the parameters of this trigger?
        curr_params = copy.deepcopy(inj_params)
        curr_params['mass1'] = hd_sngl.mass1[0]
        curr_params['mass2'] = hd_sngl.mass2[0]
        curr_params['spin1z'] = hd_sngl.spin1z[0]
        curr_params['spin2z'] = hd_sngl.spin2z[0]
        curr_params['f_lower'] = hd_sngl.f_lower[0]
        # don't require precessing template info if not present
        try:
            curr_params['spin1x'] = hd_sngl.spin1x[0]
            curr_params['spin2x'] = hd_sngl.spin2x[0]
            curr_params['spin1y'] = hd_sngl.spin1y[0]
            curr_params['spin2y'] = hd_sngl.spin2y[0]
            curr_params['inclination'] = hd_sngl.inclination[0]
        except KeyError:
            pass
        try:
            # Only present for precessing search
            curr_params['u_vals'] = hd_sngl.u_vals[0]
        except Exception:
            # Best effort: u_vals is optional. A bare except would also
            # swallow KeyboardInterrupt and SystemExit, so catch Exception.
            pass

        curr_tags = ['TMPLT_PARAMS_%s' % (curr_ifo,), str(num_event)]
        _, loudest_plot = mini.make_single_template_plots(workflow, insp_segs,
                                args.inspiral_data_read_name,
                                args.inspiral_data_analyzed_name, curr_params,
                                args.output_dir, inj_file=injection_xml_file,
                                tags=args.tags + curr_tags,
                                params_str='loudest template in %s' % curr_ifo )
        files += loudest_plot

    layouts += list(layout.grouper(files, 2))

# Nothing below reads the injection results file, so release it
f.close()

workflow.save()
layout.two_column_layout(args.output_dir, layouts)
