#!/usr/bin/python3.13 -s
"""
The program combines output files generated
by pycbc_sngls_findtrigs to generate a mapping between SNR and FAP/FAR, along
with producing the combined foreground and background triggers.
"""

import argparse, itertools
import logging, numpy, copy
from pycbc.events import veto, coinc
from pycbc.events import triggers, trigger_fits as trstats
from pycbc.events import significance
import pycbc.pnutils, pycbc.io
import sys
import pycbc.conversions as conv

class fw(object):
    """Thin write-mode wrapper around an HDF5 file.

    Assigning to a name that is not yet in the file creates a compressed
    dataset; assigning to an existing name overwrites its values in place.
    Lookups and the ``attrs`` mapping are forwarded to the wrapped file.
    """

    def __init__(self, name):
        # Open (and truncate) the output file, exposing its attributes
        handle = pycbc.io.HFile(name, 'w')
        self.f = handle
        self.attrs = handle.attrs

    def __setitem__(self, name, data):
        if name in self.f:
            # Dataset already exists: overwrite its contents in place
            self.f[name][:] = data
            return
        # First write under this name: create a gzip-compressed, shuffled
        # dataset whose maximum shape is the shape of the initial data
        self.f.create_dataset(name, data=data, maxshape=data.shape,
                              compression="gzip", compression_opts=9,
                              shuffle=True)

    def __getitem__(self, *args):
        # Delegate item access straight to the wrapped file object
        return self.f.__getitem__(*args)

# Build the command line interface. The options are added in a fixed order
# so related options stay next to each other in the --help output.
parser = argparse.ArgumentParser(description=__doc__)
# General required options
pycbc.add_common_pycbc_options(parser)
# The input files, the ifo and the output file are all used unconditionally
# further down, so mark them as required to get a clear usage error instead
# of a late TypeError when one is missing.
parser.add_argument('--sngls-files', nargs='+', required=True,
                    help='List of files containing trigger and statistic '
                         'information.')
parser.add_argument('--ifos', nargs=1, required=True,
                    help='The ifo whose triggers are in the input files '
                         '(exactly one value is accepted)')
parser.add_argument('--cluster-window', type=float, default=10,
                    help='Length of time window in seconds to cluster coinc '
                         'events [default=10s]')
parser.add_argument('--veto-window', type=float, default=.1,
                    help='Time around each zerolag trigger to window out '
                         '[default=.1s]')
significance.insert_significance_option_group(parser)
parser.add_argument('--hierarchical-removal-ifar-threshold',
                    type=float, default=0.5,
                    help="Threshold to hierarchically remove foreground "
                         "triggers with IFAR (years) above this value "
                         "[default=0.5yr]")
parser.add_argument('--hierarchical-removal-window', type=float, default=1.0,
                    help='Time around each louder trigger to window out in '
                         'the hierarchical removal procedure '
                         '[default=1.0s]')
parser.add_argument('--max-hierarchical-removal', type=int, default=0,
                    help='Maximum number of hierarchical removals to carry '
                         'out. Choose -1 for unlimited hierarchical removal '
                         'until no foreground triggers are louder than the '
                         '(inclusive/exclusive) background. Choose 0 to do '
                         'no hierarchical removals, choose 1 to do at most '
                         '1 hierarchical removal and so on. If given, must '
                         'also provide --hierarchical-removal-against to '
                         'indicate which background to remove triggers '
                         'against. [default=0]')
parser.add_argument('--hierarchical-removal-against', type=str,
                    default='none', choices=['none', 'inclusive', 'exclusive'],
                    help='If doing hierarchical removal, remove foreground '
                         'triggers that are louder than either the "inclusive"'
                         ' (little-dogs-in) background, or the "exclusive" '
                         '(little-dogs-out) background. [default="none"]')
parser.add_argument('--output-file', required=True,
                    help='Path of the HDF file to write the results to')
args = parser.parse_args()

# Hand the parsed options to the significance module's own checker
significance.check_significance_options(args, parser)

# Hierarchical removal needs to know which background (inclusive or
# exclusive) foreground triggers are removed against; conversely, naming a
# background is contradictory when no removals were requested.
if args.max_hierarchical_removal != 0:
    is_bkg_inc = args.hierarchical_removal_against == 'inclusive'
    is_bkg_exc = args.hierarchical_removal_against == 'exclusive'
    if not is_bkg_inc and not is_bkg_exc:
        parser.error("--max-hierarchical-removal requires a choice of which "
                     "background to remove foreground triggers against, "
                     "inclusive or exclusive. Use with --help for more "
                     "information.")
elif args.hierarchical_removal_against != 'none':
    parser.error("User Error: 0 maximum hierarchical removals chosen but "
                 "option for --hierarchical-removal-against was given. "
                 "These are conflicting options. Use with --help for more "
                 "information.")


pycbc.init_logging(args.verbose)

# Load the single-detector triggers for the one requested ifo
logging.info("Loading triggers")
ifo = args.ifos[0]
logging.info("IFO input: %s", ifo)
all_trigs = pycbc.io.MultiifoStatmapData(files=args.sngls_files,
                                         ifos=[ifo])
# Explicit check rather than ``assert`` so that the validation is not
# stripped when Python runs with optimisation enabled (-O)
time_key = ifo + '/time'
if time_key not in all_trigs.data:
    raise ValueError("Input trigger files do not contain a '%s' dataset"
                     % time_key)

logging.info("We have %s triggers", len(all_trigs.stat))
logging.info("Clustering triggers")
all_trigs = all_trigs.cluster(args.cluster_window)

fg_time = float(all_trigs.attrs['foreground_time'])

logging.info("Dumping foreground triggers")
f = fw(args.output_file)
f.attrs['num_of_ifos'] = 1
f.attrs['ifos'] = ifo

f.attrs['timeslide_interval'] = all_trigs.attrs['timeslide_interval']

# Copy over the segment info
for key in all_trigs.seg.keys():
    f['segments/%s/start' % key] = all_trigs.seg[key]['start'][:]
    f['segments/%s/end' % key] = all_trigs.seg[key]['end'][:]

# A single placeholder entry is written for the foreground veto segments
f['segments/foreground_veto/start'] = numpy.array([0])
f['segments/foreground_veto/end'] = numpy.array([0])
# The same clustered triggers are stored as both foreground and background
for k in all_trigs.data:
    f['foreground/' + k] = all_trigs.data[k]
    f['background/' + k] = all_trigs.data[k]

logging.info("Estimating FAN from background statistic values")
# The same clustered triggers provide the ranking statistic of both the
# foreground and the (inclusive) background
back_stat = all_trigs.stat
fore_stat = back_stat
bkg_dec_facs = all_trigs.decimation_factor

significance_dict = significance.digest_significance_options([ifo], args)

# False alarm rates of the background and foreground triggers measured
# against the inclusive background
bg_far, fg_far = significance.get_far(back_stat, fore_stat, bkg_dec_facs,
                                      fg_time, **significance_dict[ifo])

fg_far = significance.apply_far_limit(fg_far, significance_dict, combo=ifo)
bg_far = significance.apply_far_limit(bg_far, significance_dict, combo=ifo)

bg_ifar = 1. / bg_far
fg_ifar = 1. / fg_far

f['background/ifar'] = conv.sec_to_year(bg_ifar)

f.attrs['background_time'] = fg_time
f.attrs['foreground_time'] = fg_time

# Build the exclusive background: drop every trigger whose inclusive IFAR
# exceeds the livetime. Boolean-mask indexing yields new arrays, so the
# inclusive arrays above are left untouched.
fg_time_exc = fg_time
to_keep = bg_ifar <= fg_time_exc

n_removed = bg_ifar.size - to_keep.sum()
logging.info("Removing %s event(s) from exclusive background",
             n_removed)

back_stat_exc = back_stat[to_keep]
bkg_exc_dec_facs = bkg_dec_facs[to_keep]
# Indices into all_trigs of the triggers kept in the exclusive background
back_exc_locs = numpy.arange(len(all_trigs.stat))[to_keep]

# False alarm rates of the exclusive background triggers and of the
# foreground triggers measured against the exclusive background
bg_far_exc, fg_far_exc = significance.get_far(back_stat_exc, fore_stat,
                                              bkg_exc_dec_facs, fg_time_exc,
                                              **significance_dict[ifo])

fg_far_exc = significance.apply_far_limit(fg_far_exc, significance_dict,
                                          combo=ifo)
bg_far_exc = significance.apply_far_limit(bg_far_exc, significance_dict,
                                          combo=ifo)

bg_ifar_exc = 1. / bg_far_exc
fg_ifar_exc = 1. / fg_far_exc

# Remove a small amount of time from the exclusive fore/background
# time to account for this removal
fg_time_exc -= n_removed * args.veto_window

for k, values in all_trigs.data.items():
    f['background_exc/%s' % k] = values[back_exc_locs]

f['background_exc/ifar'] = conv.sec_to_year(bg_ifar_exc)
f.attrs['background_time_exc'] = fg_time_exc
f.attrs['foreground_time_exc'] = fg_time_exc

logging.info("calculating foreground ifar/fap values")

# Inclusive and exclusive false alarm probabilities of the foreground
fap = 1 - numpy.exp(- fg_time / fg_ifar)
fap_exc = 1 - numpy.exp(- fg_time_exc / fg_ifar_exc)
f['foreground/ifar'] = conv.sec_to_year(fg_ifar)
f['foreground/fap'] = fap
f['foreground/ifar_exc'] = conv.sec_to_year(fg_ifar_exc)
f['foreground/fap_exc'] = fap_exc

# Carry the optional 'name' attribute over to the output file
if 'name' in all_trigs.attrs:
    f.attrs['name'] = all_trigs.attrs['name']

# Incorporate hierarchical removal for any other loud triggers
logging.info("Beginning hierarchical removal of foreground triggers")

# Step 1: Set up the bookkeeping used by the hierarchical removal loop.

# Number of hierarchical removals carried out so far.
h_iterations = 0

# Reference to the ranking statistic of the original foreground triggers,
# used in the loop to find the rows of the 'foreground' datasets whose ifar
# and fap must be updated.
orig_fore_stat = fore_stat

# Choose the foreground IFAR that is compared against the threshold: the
# exclusive one only when removals were requested against the exclusive
# background. When no removals were requested the choice is irrelevant (the
# loop breaks on its first pass), and the short-circuit ``and`` avoids
# touching is_bkg_inc, which is only defined when removals were requested.
if args.max_hierarchical_removal != 0 and not is_bkg_inc:
    ifar_louder = fg_ifar_exc
else:
    ifar_louder = fg_ifar

# Step 2 : Loop until we don't have to hierarchically remove anymore. This
#          will happen when ifar_louder has no elements that are
#          above the set threshold, or a set maximum.

# Convert threshold from years into seconds
years_per_second = conv.sec_to_year(1)
hier_ifar_thresh_s = (args.hierarchical_removal_ifar_threshold
                      / years_per_second)

# Each pass removes the foreground trigger with the largest IFAR (plus
# everything within the removal window around it), re-clusters, recomputes
# the significance of what is left and writes the result as level
# 'h<iteration>' of the output file.
while numpy.any(ifar_louder > hier_ifar_thresh_s):
    # If the user wants to stop doing hierarchical removals after a set
    # number of iterations then break when that happens. A maximum of -1
    # never matches, giving unlimited removals.
    if h_iterations == args.max_hierarchical_removal:
        break

    # Write foreground trigger info before hierarchical removals for
    # downstream codes.
    if h_iterations == 0:
        f['background_h%s/stat' % h_iterations] = back_stat
        f['background_h%s/ifar' % h_iterations] = conv.sec_to_year(bg_ifar)
        for k in all_trigs.data:
            f['background_h%s/' % h_iterations + k] = all_trigs.data[k]
        f['foreground_h%s/stat' % h_iterations] = fore_stat
        f['foreground_h%s/ifar' % h_iterations] = conv.sec_to_year(fg_ifar)
        f['foreground_h%s/ifar_exc' % h_iterations] = \
            conv.sec_to_year(fg_ifar_exc)
        f['foreground_h%s/fap' % h_iterations] = fap
        for k in all_trigs.data:
            f['foreground_h%s/' % h_iterations + k] = all_trigs.data[k]
    # Add the iteration number of hierarchical removals done.
    h_iterations += 1

    # Among foreground triggers, find the one with the largest ifar
    # and mark it for removal.
    max_stat_idx = ifar_louder.argmax()
    loudest_stat = fore_stat[max_stat_idx]

    # Step 3: Remove that trigger from the list of zerolag triggers

    # Find the index of the loud foreground trigger to remove next. And find
    # the index in the list of original foreground triggers. Triggers are
    # matched through their ranking statistic value.
    rm_trig_idx = numpy.where(all_trigs.stat[:] == loudest_stat)[0][0]
    orig_fore_idx = numpy.where(orig_fore_stat == loudest_stat)[0][0]

    # Store any foreground trigger's information that we want to
    # hierarchically remove.
    f['foreground/ifar'][orig_fore_idx] = \
        conv.sec_to_year(fg_ifar[max_stat_idx])
    f['foreground/fap'][orig_fore_idx] = fap[max_stat_idx]

    logging.info("Removing foreground trigger that is louder than the %s "
                 "background.", args.hierarchical_removal_against)

    # Remove the foreground trigger and all of the triggers that fall
    # within the removal window around its time.
    ave_rm_time = all_trigs.data['%s/time' % ifo][rm_trig_idx]
    indices_to_rm = veto.indices_within_times(
        all_trigs.data['%s/time' % ifo],
        [ave_rm_time - args.hierarchical_removal_window],
        [ave_rm_time + args.hierarchical_removal_window])

    all_trigs = all_trigs.remove(numpy.asarray(indices_to_rm).astype(int))
    logging.info("We have %s triggers after hierarchical removal.",
                 len(all_trigs.stat))

    # Step 4: Re-cluster the triggers and calculate the inclusive ifar/fap
    logging.info("Clustering coinc triggers (inclusive of zerolag)")
    all_trigs = all_trigs.cluster(args.cluster_window)

    logging.info("%s clustered foreground triggers", len(all_trigs))
    logging.info("%s hierarchically removed foreground trigger(s)",
                 h_iterations)

    logging.info("Dumping foreground triggers")
    logging.info("Dumping background triggers (inclusive of zerolag)")
    for k in all_trigs.data:
        f['background_h%s/' % h_iterations + k] = all_trigs.data[k]

    logging.info("Calculating FAN from background statistic values")
    back_stat = fore_stat = all_trigs.stat
    bkg_dec_facs = all_trigs.decimation_factor

    bg_far, fg_far = significance.get_far(
        back_stat,
        fore_stat,
        bkg_dec_facs,
        fg_time,
        **significance_dict[ifo])

    fg_far = significance.apply_far_limit(
        fg_far,
        significance_dict,
        combo=ifo,
    )

    bg_far = significance.apply_far_limit(
        bg_far,
        significance_dict,
        combo=ifo,
    )

    bg_ifar = 1. / bg_far
    fg_ifar = 1. / fg_far

    # Exclusive background doesn't change when removing foreground triggers.
    # So we don't have to take bg_far_exc, just repopulate fg_ifar_exc. This
    # is done for both removal modes so that fg_ifar_exc always has one entry
    # per remaining foreground trigger when it is written out below.
    _, fg_far_exc = significance.get_far(
        back_stat_exc,
        fore_stat,
        bkg_exc_dec_facs,
        fg_time_exc,
        **significance_dict[ifo])

    fg_far_exc = significance.apply_far_limit(
        fg_far_exc,
        significance_dict,
        combo=ifo,
    )

    fg_ifar_exc = 1. / fg_far_exc

    # Update the ifar_louder criteria depending on whether foreground
    # triggers are being removed via inclusive or exclusive background.
    ifar_louder = fg_ifar if is_bkg_inc else fg_ifar_exc

    logging.info("Calculating ifar/fap values")
    f['background_h%s/ifar' % h_iterations] = conv.sec_to_year(bg_ifar)
    f.attrs['background_time_h%s' % h_iterations] = fg_time
    f.attrs['foreground_time_h%s' % h_iterations] = fg_time

    if len(all_trigs) > 0:
        # Write ranking statistic to file just for downstream plotting code
        f['foreground_h%s/stat' % h_iterations] = fore_stat

        # Inclusive significance of the remaining foreground triggers
        fg_ifar_yr = conv.sec_to_year(fg_ifar)
        fap = 1 - numpy.exp(- fg_time / fg_ifar)
        f['foreground_h%s/ifar' % h_iterations] = fg_ifar_yr
        f['foreground_h%s/fap' % h_iterations] = fap

        # Exclusive significance goes to its own datasets (so the inclusive
        # values above are not overwritten) and uses the exclusive livetime
        fap_exc = 1 - numpy.exp(- fg_time_exc / fg_ifar_exc)
        f['foreground_h%s/ifar_exc' % h_iterations] = \
            conv.sec_to_year(fg_ifar_exc)
        f['foreground_h%s/fap_exc' % h_iterations] = fap_exc

        # Update ifar and fap for other foreground triggers. The dataset
        # handles are looked up once rather than once per trigger.
        fg_ifar_dset = f['foreground/ifar']
        fg_fap_dset = f['foreground/fap']
        for i, stat_val in enumerate(fore_stat):
            orig_fore_idx = numpy.where(orig_fore_stat == stat_val)[0][0]
            fg_ifar_dset[orig_fore_idx] = fg_ifar_yr[i]
            fg_fap_dset[orig_fore_idx] = fap[i]

        # Save trigger ids for foreground triggers for downstream plotting
        # code. These don't change with the iterations but should be written
        # at every level.
        f['foreground_h%s/template_id' % h_iterations] = \
            all_trigs.data['template_id']
        trig_id = all_trigs.data['%s/trigger_id' % ifo]
        trig_time = all_trigs.data['%s/time' % ifo]
        f['foreground_h%s/%s/time' % (h_iterations, ifo)] = trig_time
        f['foreground_h%s/%s/trigger_id' % (h_iterations, ifo)] = trig_id
    else:
        # No foreground triggers are left: write empty datasets so every
        # level of the file has the same layout
        f['foreground_h%s/stat' % h_iterations] = numpy.array([])
        f['foreground_h%s/ifar' % h_iterations] = numpy.array([])
        f['foreground_h%s/fap' % h_iterations] = numpy.array([])
        f['foreground_h%s/ifar_exc' % h_iterations] = numpy.array([])
        f['foreground_h%s/fap_exc' % h_iterations] = numpy.array([])
        f['foreground_h%s/template_id' % h_iterations] = numpy.array([])
        f['foreground_h%s/%s/time' % (h_iterations, ifo)] = numpy.array([])
        f['foreground_h%s/%s/trigger_id' % (h_iterations, ifo)] = \
            numpy.array([])

# Write to file how many hierarchical removals were implemented.
f.attrs['hierarchical_removal_iterations'] = h_iterations

# Write whether hierarchical removals were removed against the
# inclusive background or the exclusive background. The value is stored
# as a numpy bytes string; numpy.bytes_ is used because the numpy.string_
# alias was removed in NumPy 2.0.
if h_iterations != 0:
    hrm_method = args.hierarchical_removal_against
    f.attrs['hierarchical_removal_method'] = numpy.bytes_(hrm_method)

logging.info("Done")

