#!/usr/bin/python3.11
""" Bin triggers by their dq value and calculate trigger rates in each bin
"""
import logging
import argparse
import pycbc
import pycbc.events
from pycbc.events import stat as pystat
from pycbc.types.timeseries import load_timeseries
import numpy as np
import h5py as h5
from pycbc.io.hdf import HFile, SingleDetTriggers
from pycbc.version import git_verbose_msg as version

# Command-line interface. --sngl-ranking (read below as args.sngl_ranking) is
# not declared here; it is presumably added by
# pystat.insert_statistic_option_group -- confirm against pycbc.events.stat.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--version', action='version', version=version)
parser.add_argument('--verbose', action="store_true")
parser.add_argument("--ifo", type=str, required=True,
                    help="Name of the detector; triggers are read from the "
                         "group of this name in --trig-file")
parser.add_argument("--trig-file", required=True,
                    help="HDF file containing the single-detector triggers")
parser.add_argument("--stat-threshold", type=float, default=1.,
                    help="Only consider triggers with --sngl-ranking value "
                    "above this threshold")
parser.add_argument("--f-lower", type=float, default=15.,
                    help='Enforce a uniform low frequency cutoff to '
                         'calculate template duration over the bank')
parser.add_argument("--dq-file", required=True, nargs='+',
                    help="One or more files containing the data quality "
                         "timeseries")
parser.add_argument("--dq-channel", required=False, type=str,
                    help='name of channel to read in from the provided '
                         'dq file. Required if file is .hdf5 format')
parser.add_argument('--bank-file', help='hdf format template bank file',
                    required=True)
parser.add_argument('--background-bins', nargs='+',
                    help='list of background bin format strings')
parser.add_argument('--n-time-bins', type=int, default=200,
                    help='Number of time bins to use')
parser.add_argument("--output-file", required=True,
                    help="HDF file to which the trigger rates and template "
                         "bin assignments are written")
parser.add_argument("--prune-number", type=int, default=0,
                    help="Number of loudest events to remove from each "
                         "split histogram, default 0")
parser.add_argument("--prune-window", type=float, default=0.1,
                    help="Time (s) to remove all triggers around a trigger "
                         "which is loudest in each split, default 0.1s")

pystat.insert_statistic_option_group(parser,
        default_ranking_statistic='single_ranking_only')
args = parser.parse_args()

# The percentile binning below divides by the number of bins and needs at
# least one bin, so reject bad values here with a clear usage error.
if args.n_time_bins < 1:
    parser.error("--n-time-bins must be a positive integer")

pycbc.init_logging(args.verbose)

logging.info('Start')

ifo = args.ifo

# Build a boolean pre-mask over the trigger file so that only triggers
# passing an SNR-based cut are loaded. SNR is always greater than or equal
# to sngl_ranking, except in the psdvar case, where it could increase, so
# there the SNR is rescaled by the PSD variation before comparing.
snr_key = f'{ifo}/snr'
psdvar_key = f'{ifo}/psd_var_val'

with HFile(args.trig_file, 'r') as trig_file:
    n_triggers_orig = trig_file[snr_key].size
    logging.info("Trigger file has %d triggers", n_triggers_orig)
    logging.info('Generating trigger mask')

    if psdvar_key in trig_file:
        def passes_precut(snr, psdvar):
            return snr / psdvar ** 0.5 >= args.stat_threshold
        cut_columns = (snr_key, psdvar_key)
    else:
        # psd_var_val may not have been calculated
        def passes_precut(snr):
            return snr >= args.stat_threshold
        cut_columns = (snr_key,)

    # select returns the passing indices followed by the selected columns;
    # only the indices are needed here.
    idx, *_ = trig_file.select(passes_precut, *cut_columns,
                               return_indices=True)

    data_mask = np.zeros(n_triggers_orig, dtype=bool)
    data_mask[idx] = True

logging.info("Getting %s triggers from file with pre-cut SNR > %.3f",
             idx.size, args.stat_threshold)

# Only the trigger file, detector and pre-mask are supplied; the other
# positional arguments are left as None.
trigs = SingleDetTriggers(args.trig_file, None, None, None, None, ifo,
                          premask=data_mask)

# Extract the data we actually need from the data structure:
tmplt_ids = trigs.template_id
trig_times = trigs.end_time
stat = trigs.get_ranking(args.sngl_ranking)
# Integer trigger times are used later to look up the DQ bin of each trigger
trig_times_int = trig_times.astype(np.int64)

del trigs

n_triggers = tmplt_ids.size

logging.info("Applying %s > %.3f cut", args.sngl_ranking,
             args.stat_threshold)
keep = stat >= args.stat_threshold
tmplt_ids = tmplt_ids[keep]
trig_times = trig_times[keep]
trig_times_int = trig_times_int[keep]
# The ranking values must be cut with the same mask: the pruning step indexes
# stat with masks built from tmplt_ids / trig_times, so all of these arrays
# have to stay aligned element by element.
stat = stat[keep]
logging.info("Removed %d triggers, %d remain",
             n_triggers - tmplt_ids.size, tmplt_ids.size)

# Read every DQ file and join the values and their sample times end to end,
# in the order the files were given. The leading empty arrays keep the
# result a float array, as when starting from np.array([]).
logl_chunks = [np.array([])]
time_chunks = [np.array([])]

for dq_path in args.dq_file:
    logging.info('Reading DQ file %s', dq_path)
    dq_series = load_timeseries(dq_path, group=args.dq_channel)
    logl_chunks.append(dq_series[:])
    time_chunks.append(dq_series.sample_times)
    del dq_series

dq_logl = np.concatenate(logl_chunks)
dq_times = np.concatenate(time_chunks)
del logl_chunks
del time_chunks

n_bins = args.n_time_bins
# Bin boundaries in percent: n_bins equal-width percentile bins from 0 to 100
percentiles = np.linspace(0, 100, n_bins+1)
# DQ value at the upper edge of each percentile bin (the 0th percentile is
# dropped), in non-decreasing order
dq_percentiles = np.percentile(dq_logl, percentiles)[1:]

# seconds bin tells what bin each second ends up in: the number of upper bin
# edges strictly below the sample's DQ value. The largest DQ value equals the
# last edge, so it lands in bin n_bins - 1.
# NOTE(review): assumes one DQ sample per second and no NaNs in the DQ data
# -- confirm against the DQ file producer.
seconds_bin = np.searchsorted(dq_percentiles, dq_logl,
                              side='left').astype(np.int64)
del dq_percentiles

# bin times tells how much time (number of DQ samples) ends up in each bin;
# the slice guards against any index at or beyond n_bins being counted
bin_times = np.bincount(seconds_bin,
                        minlength=n_bins)[:n_bins].astype(np.float64)
full_time = float(len(seconds_bin))
# Bins containing no time cannot have a rate computed for them
times_nz = (bin_times > 0)
del dq_logl

# create a dict to look up dq percentile (as a fraction, bin index / n_bins)
# at any time
dq_percentiles_time = dict(zip(dq_times, seconds_bin/n_bins))
del dq_times

# Work out which templates belong to which background bin. locs_dict maps
# each bin name to the template indices it contains; locs_names keeps the
# bin names in the order they are processed.
with h5.File(args.bank_file, 'r') as bank:
    if not args.background_bins:
        # No binning requested: a single bin holding every template
        locs_dict = {'all_bin': np.arange(len(bank['mass1'][:]))}
        locs_names = ['all_bin']
    else:
        logging.info('Sorting bank into bins')
        bank_params = {
            param: bank[param][:]
            for param in ('mass1', 'mass2', 'spin1z', 'spin2z')
        }
        # Same low-frequency cutoff for every template in the bank
        bank_params['f_lower'] = \
            np.ones_like(bank_params['mass1']) * args.f_lower
        locs_dict = pycbc.events.background_bin_from_string(
            args.background_bins, bank_params)
        del bank_params
        # The bin name is the part of each bin string before the first colon
        locs_names = [bin_spec.partition(':')[0]
                      for bin_spec in args.background_bins]

# Optionally remove the loudest triggers in each bin, together with every
# trigger (in any bin) within --prune-window seconds of them. Pruned triggers
# are marked by setting their ranking value to zero.
if args.prune_number > 0:
    for bin_name in locs_names:
        logging.info('Pruning bin %s', bin_name)
        bin_locs = locs_dict[bin_name]
        inbin = np.isin(tmplt_ids, bin_locs)
        trig_times_bin = trig_times[inbin]
        # Boolean indexing makes a copy, so this is updated separately from
        # stat inside the loop below
        trig_stats_bin = stat[inbin]

        if trig_stats_bin.size == 0:
            # np.argmax raises on an empty array; nothing to prune here
            logging.info('No triggers in bin %s, skipping', bin_name)
            continue

        for j in range(args.prune_number):
            max_stat_arg = np.argmax(trig_stats_bin)
            if trig_stats_bin[max_stat_arg] == 0:
                # Everything left in this bin has already been pruned;
                # carrying on would prune around an already-removed trigger
                break
            loudest_time = trig_times_bin[max_stat_arg]
            remove = np.abs(loudest_time - trig_times) < args.prune_window
            logging.info("Prune %d: pruning %d triggers", j, remove.sum())
            remove_inbin = \
                np.abs(loudest_time - trig_times_bin) < args.prune_window
            stat[remove] = 0
            trig_stats_bin[remove_inbin] = 0
    # Triggers with a non-zero ranking value survive the pruning
    keep = np.flatnonzero(stat)
    logging.info("%d triggers removed through pruning",
                 stat.size - keep.size)
    trig_times_int = trig_times_int[keep]
    tmplt_ids = tmplt_ids[keep]
    del stat
    del keep

del trig_times

# Histogram edges as fractions i / n_bins. They are computed with the same
# integer-over-n_bins division as the values stored in dq_percentiles_time,
# so each stored value is exactly equal to the left edge of its own bin.
# Edges derived by dividing percentages by 100 can differ from those values
# in the last digit and push a trigger into the neighbouring bin.
bin_edges = np.arange(n_bins + 1) / n_bins

with h5.File(args.output_file, 'w') as f:
    for bin_name in locs_names:
        bin_locs = locs_dict[bin_name]
        inbin = np.isin(tmplt_ids, bin_locs)
        trig_times_bin = trig_times_int[inbin]
        # NOTE(review): raises KeyError if a trigger's integer time has no
        # matching DQ sample time -- confirm the DQ files cover all triggers
        trig_percentile = np.array([dq_percentiles_time[t]
                                    for t in trig_times_bin])
        logging.info('Processing %d triggers in bin %s',
                     len(trig_percentile), bin_name)

        # Number of triggers falling in each DQ percentile bin
        counts, _ = np.histogram(trig_percentile, bins=bin_edges)
        counts = counts.astype(np.float64)
        # Trigger rate per DQ bin; bins containing no time keep a rate of 0
        rates = np.zeros_like(bin_times)
        rates[times_nz] = counts[times_nz]/bin_times[times_nz]
        # Normalise by the mean rate over all time, unless there are no
        # triggers at all in this template bin
        mean_rate = len(trig_percentile) / full_time
        if mean_rate > 0.:
            rates = rates / mean_rate

        logging.info('Writing rates to output file %s', args.output_file)
        grp = f.create_group(bin_name)
        grp['rates'] = rates
        grp['locs'] = locs_dict[bin_name]

    f.attrs['names'] = locs_names

logging.info('Done!')
