#!/usr/bin/python3.13 -s
import argparse, logging, h5py, numpy as np
from ligo.segments import infinity
from numpy.random import seed, shuffle

import pycbc
from pycbc.events import veto, coinc, stat
import pycbc.conversions as conv
from pycbc import io
from pycbc.events import cuts, trigger_fits as trfits
from pycbc.events.veto import indices_outside_times
from pycbc.types.optparse import MultiDetOptionAction
from pycbc import init_logging

# Build the command-line interface: common PyCBC options, input files,
# veto definitions, clustering, output and decimation settings, followed by
# the option groups contributed by the statistic and cuts modules.
parser = argparse.ArgumentParser()
pycbc.add_common_pycbc_options(parser)
# Basic file input options
# nargs=1 stores a one-element list. The option is required because the
# script indexes args.trigger_files unconditionally after parsing.
parser.add_argument("--trigger-files", type=str, nargs=1, required=True,
                    help="File containing single-detector triggers")
parser.add_argument("--template-bank", required=True,
                    help="Template bank file in HDF format")
parser.add_argument("--template-fraction-range", default="0/1",
                    help="Optional, analyze only part of template bank. Format"
                         " PART/NUM_PARTS")
parser.add_argument("--randomize-template-order", action="store_true",
                    help="Random shuffle templates with fixed seed "
                         "before selecting range to analyze")
# Options to define the vetoes
# nargs='*' combined with action='append' produces a list of lists, to allow
# multiple invocations and multiple args per invocation
parser.add_argument("--veto-files", nargs='*', action='append', default=[],
                    help="Optional veto file. Triggers within veto segments "
                         "contained in the file are ignored")
parser.add_argument("--segment-name", nargs='*', action='append', default=[],
                    help="Optional, name of veto segment in veto file")
parser.add_argument("--gating-veto-windows", nargs='+',
                    action=MultiDetOptionAction,
                    help="Seconds to be vetoed before and after the central time "
                         "of each gate. Given as detector-values pairs, e.g. "
                         "H1:-1,2.5 L1:-1,2.5 V1:0,0")
# Clustering, output and decimation options
parser.add_argument('--cluster-window', type=float,
                    help='Window (seconds) during which to keep the trigger '
                         'with the loudest statistic value. '
                         'Default=do not cluster')
# Required so that a missing output path is reported at start-up rather than
# after all the statistic values have been computed
parser.add_argument("--output-file", required=True,
                    help="File to store the candidate triggers")
parser.add_argument("--loudest-keep-values",
                    default='[6:1]',
                    help="Apply successive multiplicative levels of"
                         " decimation to coincs with stat value below the"
                         " given thresholds. Supply as a comma-separated list"
                         " of threshold:decimation value pairs surrounded by"
                         " square brackets (no spaces!). Decimation values must"
                         " be positive integers."
                         " Ex. [15:5,10:30,5:30,0:30]."
                         " Default: no decimation")
stat.insert_statistic_option_group(parser)
cuts.insert_cuts_option_group(parser)
args = parser.parse_args()

# --trigger-files is declared with nargs=1, so unwrap the single path
trigger_file = args.trigger_files[0]

# Veto files and segment names only make sense together: reject the case
# where exactly one of the two options was given
if bool(args.veto_files) != bool(args.segment_name):
    raise RuntimeError('--veto-files and --segment-name are mutually required')

# Both options are still lists of lists here (one inner list per invocation),
# so this compares how many times each option was given on the command line
if len(args.veto_files) != len(args.segment_name):
    raise RuntimeError('--segment-name options are required for each --veto-files')

# Flatten the per-invocation lists into single flat lists
args.segment_name = sum(args.segment_name, [])
args.veto_files = sum(args.veto_files, [])

init_logging(args.verbose)

trigger_cut_dict, template_cut_dict = cuts.ingest_cuts_option_group(args)

logging.info('Opening trigger file: %s', trigger_file)
trigf = io.HFile(trigger_file, 'r')
# The first top-level key of the trigger file is used as the detector label
ifo = list(trigf.keys())[0]

# Set up to only load triggers from the templates of interest

def parse_template_range(num_templates, rangestr):
    """Convert a PART/NUM_PARTS string into a half-open template index range.

    The bank of ``num_templates`` templates is divided into ``NUM_PARTS``
    contiguous pieces of (nearly) equal size, and the bounds of piece
    ``PART`` (counted from zero) are returned.

    Parameters
    ----------
    num_templates : int
        Total number of templates in the bank.
    rangestr : str
        Range specification of the form ``"PART/NUM_PARTS"``.

    Returns
    -------
    tmin, tmax : int
        First index of the piece and one past its last index.

    Raises
    ------
    ValueError
        If the fields are not integers, NUM_PARTS is not positive, or PART
        does not satisfy ``0 <= PART < NUM_PARTS``.
    IndexError
        If ``rangestr`` contains no ``/`` separator.
    """
    fields = rangestr.split('/')
    part = int(fields[0])
    pieces = int(fields[1])
    if pieces < 1 or not 0 <= part < pieces:
        raise ValueError(
            "Invalid template fraction range '{}': expected PART/NUM_PARTS "
            "with 0 <= PART < NUM_PARTS".format(rangestr)
        )
    # Exact integer arithmetic: the floating-point form
    # int(num_templates / float(pieces) * k) can round just below an integer
    # boundary (e.g. 1 / 49.0 * 49 == 0.9999999999999999) and so silently
    # drop the last template of the bank.
    tmin = int(num_templates) * part // pieces
    tmax = int(num_templates) * (part + 1) // pieces
    return tmin, tmax

# Only the bank size is needed here; the context manager closes the file
# handle again instead of leaving it open for the rest of the run
with io.HFile(args.template_bank, "r") as bank_file:
    num_templates = bank_file['template_hash'].size
tmin, tmax = parse_template_range(num_templates, args.template_fraction_range)
logging.info('Analyzing template %s - %s', tmin, tmax-1)

if args.randomize_template_order:
    # Fixed seed, so the shuffled order is reproducible from run to run
    seed(0)
    template_ids = np.arange(0, num_templates)
    shuffle(template_ids)
    template_ids = template_ids[tmin:tmax]
else:
    # np.arange keeps an integer dtype even when the range is empty
    template_ids = np.arange(tmin, tmax)

# Remember how many templates were selected before any template cuts
original_bank_len = len(template_ids)

# Reader used to step through the triggers template by template; the veto
# segment names/files and gating windows are handed over at construction
from pycbc.io.hdf import ReadByTemplate
trigs = ReadByTemplate(
    trigger_file,
    args.template_bank,
    args.segment_name,
    args.veto_files,
    args.gating_veto_windows,
)
logging.info("%d triggers in file", trigf[ifo + '/snr'].size)

# Per-event accumulators, extended once per template in the main loop
stat_all, trigger_ids_all = [], []
template_ids_all, trigger_times_all = [], []

# Ranking statistic instance configured from the command-line options
rank_method = stat.get_statistic_from_opts(args, [ifo])

# Apply cuts to templates
template_ids = cuts.apply_template_cuts(
    trigs.bank,
    template_cut_dict,
    statistic=rank_method,
    ifos=[ifo],
    template_ids=template_ids,
)

logging.info(
    "%d out of %d templates kept after applying template cuts",
    len(template_ids),
    original_bank_len,
)

if args.cluster_window is not None:
    logging.info(
        'Clustering events over %s s window within each template',
        args.cluster_window,
    )

# Turn each KWARG:VALUE string from --statistic-keywords into an entry of
# the keyword dictionary later passed to the ranking statistic
extra_kwargs = {}
for kwarg_spec in args.statistic_keywords:
    try:
        # Exactly one ':' is expected; anything else fails the unpacking
        kwarg_name, kwarg_value = kwarg_spec.split(':')
    except ValueError:
        raise ValueError(
            "--statistic-keywords must take input in the "
            "form KWARG1:VALUE1 KWARG2:VALUE2 KWARG3:VALUE3 ... "
            "Received {}".format(args.statistic_keywords)
        )
    extra_kwargs[kwarg_name] = kwarg_value

# Parse --loudest-keep-values, of the form "[thresh:factor,thresh:factor,...]",
# into two parallel arrays:
#   threshes - statistic thresholds, sorted into descending order
#   factors  - cumulative decimation factor applying below each threshold
loudest_keep_vals = args.loudest_keep_values.strip('[]').split(',')

dec_levels = []
for decstr in loudest_keep_vals:
    if not decstr:
        # An empty specification ("[]") means no decimation
        continue
    thresh, factor = decstr.split(':')
    # Convert once, so that integer-valued strings such as "5.0" are accepted
    factor = float(factor)
    if factor % 1:
        raise RuntimeError("Non-integer decimation is not supported")
    if factor < 1:
        raise RuntimeError("Negative or zero decimation does not make sense")
    if factor == 1:
        # A factor of one adds no decimation at this level
        continue
    dec_levels.append((float(thresh), factor))

# Sort the threshes into descending order
# - allows decimation factors to be given in any order
# The sort must happen BEFORE the factors are accumulated: the levels are
# applied successively from the highest threshold downwards, so each
# cumulative factor is the product of the factors of all higher (or equal)
# thresholds, whatever order they were given in on the command line.
dec_levels.sort(key=lambda level: level[0], reverse=True)
threshes = np.array([level[0] for level in dec_levels], dtype=float)
factors = np.cumprod([level[1] for level in dec_levels], dtype=float)
# Total decimation factor applied below the lowest threshold
tot_fac = factors[-1] if factors.size else 1


# Main loop: for every surviving template, apply the trigger cuts, compute
# the single-detector ranking statistic, optionally cluster in time, and
# append the results to the per-event accumulators.
for count, tnum in enumerate(template_ids):
    # Progress message every 1000 templates
    if not count % 1000:
        logging.info(
            "Calculating statistic in template %d out of %d",
            count, len(template_ids),
        )
    # Select this template in the reader; returns its trigger ids
    all_tids = trigs.set_template(tnum)

    trigger_keep_ids = cuts.apply_trigger_cuts(
        trigs,
        trigger_cut_dict,
        statistic=rank_method,
    )
    kept_tids = all_tids[trigger_keep_ids]
    logging.debug('%s:%s', tnum, len(all_tids))
    n_cut = len(all_tids) - len(kept_tids)
    if n_cut > 0:
        logging.debug("%s triggers cut", n_cut)

    # Nothing left in this template once the cuts are applied
    if kept_tids.size == 0:
        continue

    # Stat class instance to calculate the ranking statistic
    single_info = rank_method.single(trigs)[trigger_keep_ids]
    stat_t = rank_method.rank_stat_single((ifo, single_info), **extra_kwargs)
    trigger_times = single_info['end_time']
    if args.cluster_window is not None:
        # Keep only the events selected by the time clustering
        cluster_idx = coinc.cluster_over_time(
            stat_t,
            trigger_times,
            args.cluster_window,
        )
        stat_t = stat_t[cluster_idx]
        kept_tids = kept_tids[cluster_idx]
        trigger_times = trigger_times[cluster_idx]

    trigger_ids_all.extend(kept_tids)
    # One copy of the template id per surviving trigger
    template_ids_all.extend(tnum * np.ones_like(kept_tids))
    trigger_times_all.extend(trigger_times)
    stat_all.extend(stat_t)

# Perform decimation
# Every event starts out with a decimation factor of one
dec_facs = np.ones_like(template_ids_all)
stat_all = np.array(stat_all)
template_ids_all = np.array(template_ids_all)
trigger_ids_all = np.array(trigger_ids_all)
trigger_times_all = np.array(trigger_times_all)

logging.info("%d events", dec_facs.size)
for stat_thresh, dec_factor in zip(threshes, factors):
    # Events at or above the threshold always survive this level
    is_loud = stat_all >= stat_thresh
    # Everything quieter is recorded as decimated by the current factor
    dec_facs[~is_loud] = dec_factor

    # Quiet events survive only if their trigger id is an integer multiple
    # of the factor
    survives = is_loud | (trigger_ids_all % dec_factor == 0)

    # Do the cuts on every per-event array in step
    dec_facs = dec_facs[survives]
    stat_all = stat_all[survives]
    template_ids_all = template_ids_all[survives]
    trigger_ids_all = trigger_ids_all[survives]
    trigger_times_all = trigger_times_all[survives]

    logging.info(
        "%d events after decimation at statistic %.3f with factor %d",
        dec_facs.size, stat_thresh, dec_factor,
    )

# Collect the per-event arrays under the dataset names used in the output
data = {"stat": stat_all,
        "decimation_factor": dec_facs,
        "timeslide_id": np.zeros_like(stat_all),
        "template_id": template_ids_all,
        "%s/time" % ifo : trigger_times_all,
        "%s/trigger_id" % ifo: trigger_ids_all}

logging.info("saving triggers")
# The context manager makes sure the output file is closed (and therefore
# flushed to disk) even if writing one of the datasets fails part-way
with io.HFile(args.output_file, 'w') as f:
    for key, values in data.items():
        f.create_dataset(key, data=values,
                         compression="gzip",
                         compression_opts=9,
                         shuffle=True)
    # Store segments
    f['segments/%s/start' % ifo], f['segments/%s/end' % ifo] = trigs.valid
    fg_segs = veto.start_end_to_segments(*trigs.valid)
    fg_time = abs(fg_segs)
    # Foreground and background time are both set to the same value, and
    # the single detector fills the pivot, fixed and ifos attributes
    f.attrs['foreground_time'] = fg_time
    f.attrs['background_time'] = fg_time
    f.attrs['num_of_ifos'] = 1
    f.attrs['pivot'] = ifo
    f.attrs['fixed'] = ifo
    f.attrs['ifos'] = ifo
    f.attrs['timeslide_interval'] = 0

# Do hierarchical removal
# h_iterations = 0
# if args.max_hierarchical_removal != 0:

logging.info("Done")
