#!/usr/bin/python3.13 -s

# Copyright 2024 Arthur Tolley
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

"""Find trigger files and combine them into a single hdf trigger merge file."""

import glob
import numpy
import argparse
import itertools
import h5py
import os
import logging
from datetime import datetime, timedelta

from ligo.segments import segmentlist, segment

import pycbc
from pycbc.io import live as liveio
from pycbc.events import cuts, veto

# Build the command line interface. The common PyCBC options, the live
# trigger-file selection options and the cuts option group are added by the
# pycbc helpers; the remaining arguments are specific to this script.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
liveio.add_live_trigger_selection_options(parser)
cuts.insert_cuts_option_group(parser)

# Detectors whose triggers are collated into the output file
parser.add_argument(
    '--ifos', nargs='+', required=True,
    help="The list of detectors to include triggers in the merged file",
)
# Destination of the merged triggers
parser.add_argument(
    '--output-file', required=True,
    help='The output file containing merged triggers.',
)
# Template bank; read later to find the number of templates
parser.add_argument(
    "--bank-file", required=True,
    help="The bank file used in the search",
)

args = parser.parse_args()

pycbc.init_logging(args.verbose)

logging.info("Finding trigger files")
trigger_files = liveio.find_trigger_files_from_cli(args)
logging.info("%s files found", len(trigger_files))

###########################
# COLLATE THE TRIGGER FILES
###########################
# A trigger file can straddle the requested GPS start or end time, so add
# lower/upper end_time cuts at those times on top of any user-supplied cuts.
args.trigger_cuts = args.trigger_cuts or []
args.trigger_cuts.extend([
    f"end_time:{args.gps_start_time}:lower_inc",
    f"end_time:{args.gps_end_time}:upper_inc",
])

trigger_cut_dict, template_cut_dict = cuts.ingest_cuts_option_group(args)

logging.info("Collating triggers to %s", args.output_file)

# Bookkeeping, keyed by detector where relevant:
#   n_triggers      - triggers found in the input files
#   n_triggers_cut  - triggers remaining after the cuts
#   segs            - time spans of the files which held triggers
file_count = 0
n_triggers = dict.fromkeys(args.ifos, 0)
n_triggers_cut = dict.fromkeys(args.ifos, 0)
segs = {ifo: segmentlist([]) for ifo in args.ifos}

# The number of templates is the size of the bank's template_hash dataset
with h5py.File(args.bank_file, 'r') as bank_file:
    n_templates = bank_file['template_hash'].size

# Main collation step. For every trigger file and detector: read the
# per-trigger datasets, apply the trigger and template cuts, and append the
# survivors to the matching datasets in the output file. Afterwards, write
# the analysed segments, sort each detector's triggers by template id and
# store per-template boundaries and region references.
with h5py.File(args.output_file, 'w') as destination:
    # Create the ifo groups in the trigger file, so that the 'search' group
    # can be written below even for detectors which end up with no triggers
    for ifo in args.ifos:
        if ifo not in destination:
            destination.create_group(ifo)

    for file_count, source_file in enumerate(trigger_files):
        trigger_file = os.path.basename(source_file)

        # Start time and duration are the third and fourth '-'-separated
        # fields of the file name; the final four characters (the file
        # extension) are stripped from the duration field.
        name_fields = trigger_file.split('-')
        start_time = float(name_fields[2])
        duration = float(name_fields[3][:-4])
        end_time = start_time + duration

        if file_count % 100 == 0:
            logging.info(
                "Files appended: %d/%d",
                file_count,
                len(trigger_files)
            )

        with h5py.File(source_file, 'r') as source:
            for ifo in args.ifos:
                try:
                    n_trigs_ifo = source[ifo]['snr'].size
                    n_triggers[ifo] += n_trigs_ifo
                except KeyError:
                    # No triggers in this IFO in this file
                    continue

                # Triggers were generated, so add to the segment list
                segs[ifo].append(segment(start_time, end_time))

                # Only keep the datasets which hold one entry per trigger
                triggers = {
                    k: source[ifo][k][:] for k in source[ifo].keys()
                    if k not in ('loudest', 'stat', 'gates', 'psd')
                    and source[ifo][k].size == n_trigs_ifo
                }
                # The stored chisq is actually reduced chisq, so convert back
                # to unreduced chisq using chisq_dof. Assigning through [:]
                # keeps the dtype of the stored chisq array.
                triggers['chisq'][:] = (
                    triggers['chisq'] * (2 * triggers['chisq_dof'] - 2)
                )

                # Apply the cuts to triggers
                keep_idx = cuts.apply_trigger_cuts(triggers, trigger_cut_dict)

                # triggers contains the datasets that we want to use for
                # the template cuts, so here it can be used as the template bank
                keep_idx = cuts.apply_template_cuts(
                    triggers,
                    template_cut_dict,
                    template_ids=keep_idx
                )
                # keep_idx holds the indices of the surviving triggers, so
                # test its size: any(keep_idx) would be False when the only
                # survivor is the trigger at index 0 and wrongly drop it.
                if keep_idx.size == 0:
                    # No triggers kept after cuts in this ifo for this file
                    continue

                triggers = {
                    k: values[keep_idx]
                    for k, values in triggers.items()
                }

                if ('approximant' not in triggers) or (len(triggers['approximant']) == 0):
                    continue

                # Only count the triggers once we know they will be written,
                # so the running total always equals the stored dataset
                # length used by resize() below.
                n_triggers_cut[ifo] += keep_idx.size

                # NOTE(review): a dataset which first appears in a later file
                # is created with only that file's triggers, so it would be
                # shorter than its siblings - confirm all files share the
                # same set of datasets.
                for name, dataset in triggers.items():
                    if name in destination[ifo]:
                        # Append new data to existing dataset in destination
                        if dataset.shape[0] == 0:
                            continue
                        destination[ifo][name].resize(
                            n_triggers_cut[ifo],
                            axis=0
                        )
                        destination[ifo][name][-keep_idx.size:] = dataset
                    else:
                        # First data for this dataset: make it resizable
                        destination[ifo].create_dataset(
                            name,
                            data=dataset,
                            chunks=True,
                            maxshape=(None,)
                        )

                # Carry the file-level attributes over to the merged file;
                # later files overwrite attributes of the same name.
                for attr_name, attr_value in source.attrs.items():
                    destination.attrs[attr_name] = attr_value

    for ifo in args.ifos:
        # Collect the segments, and output to the file
        search_grp = destination[ifo].create_group('search')
        segs[ifo].coalesce()
        seg_starts, seg_ends = veto.segments_to_start_end(segs[ifo])
        search_grp.create_dataset(
            'end_time',
            data=seg_ends,
        )
        search_grp.create_dataset(
            'start_time',
            data=seg_starts,
        )
        logging.info(
            "Found %d %s triggers in %s seconds",
            n_triggers[ifo],
            ifo,
            abs(segs[ifo])
        )
        if n_triggers[ifo] != n_triggers_cut[ifo]:
            logging.info(
                "  %d triggers after cuts",
                n_triggers_cut[ifo],
            )

    # Add all the template boundaries
    for ifo in args.ifos:
        logging.info("Processing %s", ifo)
        triggers = destination[ifo]
        try:
            template_ids = triggers['template_id'][:]
        except KeyError:
            logging.info("No triggers for %s, skipping", ifo)
            continue
        n_stored = len(template_ids)

        logging.info("Calculating template boundaries")
        sorted_indices = numpy.argsort(template_ids)
        sorted_template_ids = template_ids[sorted_indices]
        # Index of the first trigger of each template in the sorted arrays
        template_boundaries = numpy.searchsorted(
            sorted_template_ids,
            numpy.arange(n_templates)
        )
        triggers['template_boundaries'] = template_boundaries

        # Sort other datasets by template_id so it makes sense
        # Datasets with the same length as the number of triggers:
        #   approximant, chisq, chisq_dof, coa_phase, end_time, f_lower
        #   mass_1, mass_2, sg_chisq, sigmasq, snr, spin1z, spin2z,
        #   template_duration, template_hash, template_id
        for key in triggers.keys():
            member = triggers[key]
            # 'search' is a group and 'template_boundaries' has one entry per
            # template, so neither may be reordered, even when its length
            # happens to equal the number of triggers.
            if key == 'template_boundaries':
                continue
            if not isinstance(member, h5py.Dataset):
                continue
            if len(member) == n_stored:
                logging.info('Sorting %s by template id', key)
                member[:] = member[:][sorted_indices]

        logging.info("Setting up region references")
        # Datasets which need region references:
        region_ref_datasets = ('chisq_dof', 'chisq', 'coa_phase',
                               'end_time', 'sg_chisq', 'snr',
                               'template_duration', 'sigmasq')
        if 'psd_var_val' in triggers.keys():
            region_ref_datasets += ('psd_var_val',)

        # Each template's triggers run from its own boundary up to the next
        # template's boundary; the last template runs to the end of the data.
        # numpy.roll returns a copy, so template_boundaries is not modified.
        start_boundaries = template_boundaries
        end_boundaries = numpy.roll(start_boundaries, -1)
        end_boundaries[-1] = n_stored

        for dataset in region_ref_datasets:
            logging.info(
                "Region references for %s",
                dataset
            )
            refs = [
                triggers[dataset].regionref[start:stop]
                for start, stop in zip(start_boundaries, end_boundaries)
            ]

            logging.info("Adding to file")
            triggers.create_dataset(
                dataset + '_template',
                data=refs,
                dtype=h5py.special_dtype(ref=h5py.RegionReference)
            )

logging.info("Done!")
