#!/usr/bin/python3.13 -s
""" Bin triggers by their dq value and calculate trigger rates in each bin
"""
import logging
import argparse

import numpy as np
import h5py as h5

from ligo.segments import segmentlist

import pycbc
from pycbc.events import stat as pystat
from pycbc.events.veto import (select_segments_by_definer,
                               start_end_to_segments,
                               segments_to_start_end)
from pycbc.types.optparse import MultiDetOptionAction
from pycbc.io.hdf import SingleDetTriggers

# Command-line interface. The module docstring doubles as the --help
# description.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)

# Input files. --flag-name is later split on ':' into a detector name and a
# flag name, so it has to be given as "<ifo>:<flag>".
parser.add_argument("--template-bins-file", required=True)
parser.add_argument("--trig-file", required=True)
parser.add_argument("--flag-file", required=True)
parser.add_argument("--flag-name", required=True)
parser.add_argument("--analysis-segment-file", required=True)
parser.add_argument("--analysis-segment-name", required=True)

# Optional per-detector window around each gate; the value for a detector is
# a comma separated "before,after" pair that is parsed further down.
parser.add_argument("--gating-windows", nargs='+',
                    action=MultiDetOptionAction,
                    help="Seconds to reweight before and after the central "
                         "time of each gate. Given as detector-values pairs, "
                         "e.g. H1:-1,2.5 L1:-1,2.5 V1:0,0")
parser.add_argument("--stat-threshold", type=float, default=1.,
                    help="Only consider triggers with --sngl-ranking value "
                    "above this threshold")
parser.add_argument("--output-file", required=True)

# Presumably this is what provides --sngl-ranking (read below as
# args.sngl_ranking); no other option in this script defines it.
pystat.insert_statistic_option_group(
    parser, default_ranking_statistic='single_ranking_only')
args = parser.parse_args()
pycbc.init_logging(args.verbose)

logging.info('Start')

# --flag-name is given as "<ifo>:<flag>"; the detector part selects which
# group is read from each of the input files below.
ifo, flag_name = args.flag_name.split(':')

# Central times of the gates for this detector. This is always a numpy array
# (possibly empty) so that the gating windows can be added to it
# element-wise later on, even when the trigger file holds no gating
# information: a plain list would raise TypeError on `gate_times + float`.
gate_times = np.array([], dtype=np.float64)
if args.gating_windows:
    with h5.File(args.trig_file, 'r') as trig_file:
        logging.info('Getting gated times')
        try:
            gating_types = trig_file[f'{ifo}/gating'].keys()
            times_by_type = [trig_file[f'{ifo}/gating/{gt}/time'][:]
                             for gt in gating_types]
        except KeyError:
            logging.warning('No gating found in trigger file')
        else:
            # np.concatenate rejects an empty sequence, so only merge when
            # at least one gating type was present.
            if times_by_type:
                gate_times = np.unique(np.concatenate(times_by_type))

# Load this detector's triggers, filtered on the single-detector ranking
# with the requested threshold.
trigs = SingleDetTriggers(args.trig_file, ifo,
                          filter_rank=args.sngl_ranking,
                          filter_threshold=args.stat_threshold)

# Pull out the per-trigger quantities used further down.
# NOTE(review): `stat` is not referenced again in the visible script.
tmplt_ids = trigs.template_id
trig_times = trigs.end_time
stat = trigs.get_ranking(args.sngl_ranking)

# Map every template bin name of this detector to the template ids stored
# in its 'tids' dataset.
with h5.File(args.template_bins_file, 'r') as bins_file:
    bin_tids_dict = {
        bin_name: bin_grp['tids'][:]
        for bin_name, bin_grp in bins_file[ifo].items()
    }

# Segments that were analysed for this detector, and their total duration
analysis_segs = select_segments_by_definer(
    args.analysis_segment_file,
    segment_name=args.analysis_segment_name,
    ifo=ifo,
)
livetime = abs(analysis_segs)

# Segments during which the data-quality flag is set
flag_segs = select_segments_by_definer(
    args.flag_file,
    segment_name=flag_name,
    ifo=ifo,
)

# Build a window around every gate time; stays empty when no windows were
# requested or when both window edges are zero.
gating_segs = segmentlist([])
if args.gating_windows:
    gating_windows = args.gating_windows[ifo].split(',')
    gate_before = float(gating_windows[0])
    gate_after = float(gating_windows[1])
    if gate_before > 0 or gate_after < 0:
        raise ValueError("Gating window values must be negative "
                         "before gates and positive after gates.")
    if gate_before != 0 or gate_after != 0:
        window_starts = gate_times + gate_before
        window_ends = gate_times + gate_after
        gating_segs = start_end_to_segments(window_starts,
                                            window_ends).coalesce()

# Restrict everything to analysed time, then carve it into three mutually
# exclusive dq states: 2 = gated, 1 = flagged but not gated, 0 = the rest.
gating_segs = gating_segs & analysis_segs
flag_segs = flag_segs & analysis_segs

dq_state_segs_dict = {
    2: gating_segs,
    1: flag_segs - gating_segs,
    0: analysis_segs - flag_segs - gating_segs,
}


def dq_state_at_time(t):
    """Return the dq state whose segments contain time ``t``.

    The states of the module-level ``dq_state_segs_dict`` are checked in
    the dictionary's iteration order and the first one containing ``t`` is
    returned. ``None`` is returned when no state contains ``t``.
    """
    return next(
        (state for state, segs in dq_state_segs_dict.items() if t in segs),
        None,
    )


# Compute the per-bin dq rates and write them, together with the dq state
# segments, to the output file.
with h5.File(args.output_file, 'w') as f:
    ifo_grp = f.create_group(ifo)
    all_bin_grp = ifo_grp.create_group('bins')
    all_dq_grp = ifo_grp.create_group('dq_segments')

    # One group per template bin
    for bin_name, bin_tids in bin_tids_dict.items():
        bin_grp = all_bin_grp.create_group(bin_name)
        bin_grp['tids'] = bin_tids

        # dq state of every trigger that comes from a template in this bin
        trig_times_bin = trig_times[np.isin(tmplt_ids, bin_tids)]
        trig_states = np.array(list(map(dq_state_at_time, trig_times_bin)))

        # For each dq state: the fraction of this bin's triggers found in
        # the state divided by the fraction of analysed time spent in it.
        # NOTE(review): a bin without triggers, or a state with zero
        # livetime, looks like it gives a non-finite rate here - confirm
        # that readers of 'dq_rates' cope with that.
        dq_rates = np.zeros(3, dtype=np.float64)
        for state, segs in dq_state_segs_dict.items():
            trig_fraction = np.mean(trig_states == state)
            time_fraction = abs(segs) / livetime
            dq_rates[state] = trig_fraction / time_fraction
        bin_grp['dq_rates'] = dq_rates
        bin_grp['num_triggers'] = len(trig_times_bin)

    # Store the segments and livetime of every dq state
    for dq_state, segs in dq_state_segs_dict.items():
        dq_grp = all_dq_grp.create_group(f'dq_state_{dq_state}')
        starts, ends = segments_to_start_end(segs)
        dq_grp['segment_starts'] = starts
        dq_grp['segment_ends'] = ends
        dq_grp['livetime'] = abs(segs)

    # File-level metadata describing how the triggers were selected
    f.attrs['stat'] = f'{ifo}-dq_stat_info'
    f.attrs['sngl_ranking'] = args.sngl_ranking
    f.attrs['sngl_ranking_threshold'] = args.stat_threshold

logging.info('Done!')
