#!/usr/bin/python3.13 -s

# Copyright 2020 Gareth S. Cabourn Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

"""Fit a background model to single-detector triggers from PyCBC Live.

See https://arxiv.org/abs/2008.07494 for a description of the method."""

import os
import sys
import argparse
import logging
import numpy as np

import pycbc
from pycbc.bin_utils import IrregularBins
from pycbc.events import cuts, trigger_fits as trstats, stat
from pycbc.io import DictArray, HFile
from pycbc.io import live as liveio
from pycbc.live import significance_fits as sngls_io
from pycbc.events import ranking
from pycbc.events.coinc import cluster_over_time
from pycbc.types import MultiDetOptionAction

# Build the command-line interface. Options come from three places:
# the common PyCBC options, the option groups supplied by the live
# trigger I/O / significance-fit / statistic / cuts modules, and the
# fit-specific options defined directly here.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--ifos", nargs="+", required=True,
                    help="Which ifo(s) are we fitting the triggers for? "
                         "Required")
liveio.add_live_trigger_selection_options(parser)
sngls_io.add_live_significance_trigger_pruning_options(parser)
sngls_io.add_live_significance_duration_bin_options(parser)
parser.add_argument("--fit-function", default="exponential",
                    action=MultiDetOptionAction,
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit. "
                         "Choose from exponential, rayleigh or power. "
                         "Default: exponential")
# The help text previously ran "triggers." and "Default 5." together and
# showed a malformed example; both are corrected here.
parser.add_argument("--fit-threshold", type=float, default=5,
                    action=MultiDetOptionAction,
                    help="Lower threshold used in fitting the triggers. "
                         "Default 5. Can be supplied as a single value, "
                         "or as a set of IFO:value pairs, e.g. H1:0 L1:-1")
parser.add_argument("--cluster", action='store_true',
                    help="Only use maximum of the --sngl-ranking value "
                         "from each file.")
parser.add_argument("--output", required=True,
                    help="File in which to save the output trigger fit "
                         "parameters.")
# This script ranks single-detector triggers, so default to the
# single-ranking-only statistic
stat.insert_statistic_option_group(
    parser,
    default_ranking_statistic='single_ranking_only'
)
cuts.insert_cuts_option_group(parser)

# Parse the command line and start logging
args = parser.parse_args()
pycbc.init_logging(args.verbose)

# Validate the pruning and duration-bin option groups before doing any work
sngls_io.verify_live_significance_trigger_pruning_options(args, parser)
sngls_io.verify_live_significance_duration_bin_options(args, parser)

# Extra keyword arguments to hand to the ranking statistic
stat_kwargs = stat.parse_statistic_keywords_opt(args.statistic_keywords)

duration_bin_edges = sngls_io.duration_bins_from_cli(args)
formatted_edges = ', '.join(f'{edge:.3e}' for edge in duration_bin_edges)
logging.info("Duration bin edges: %s", formatted_edges)

logging.info("Finding files")
files = liveio.find_trigger_files_from_cli(args)
logging.info("%s files found", len(files))

logging.info("Setting up the cut dictionaries")
# Restrict templates to the range spanned by the duration bins.
# These won't be needed if using args.duration_from_bank,
# but might be if using manually selected bin edges
shortest_duration = min(duration_bin_edges)
longest_duration = max(duration_bin_edges)
args.template_cuts = args.template_cuts or []
args.template_cuts.extend([
    f"template_duration:{shortest_duration}:lower",
    f"template_duration:{longest_duration}:upper_inc",
])

# Restrict triggers to the requested GPS time range
args.trigger_cuts = args.trigger_cuts or []
args.trigger_cuts.extend([
    f"end_time:{args.gps_start_time}:lower_inc",
    f"end_time:{args.gps_end_time}:upper_inc",
])

# Convert the cut strings (user-supplied plus those added above) into
# the dictionaries used when applying the cuts
trigger_cut_dict, template_cut_dict = cuts.ingest_cuts_option_group(args)

logging.info("Setting up duration bins")
tbins = IrregularBins(duration_bin_edges)

# Keep a running live time per ifo so that the fits can be used in rate
# estimation. The live time is not immediately obvious, so it is
# approximated by accumulating a duration for every 'valid' file.
live_time = dict.fromkeys(args.ifos, 0)

logging.info("Getting events which meet criteria")

# Events which meet the immediately gettable criteria, keyed by ifo.
# An ifo only gets an entry once it has at least one surviving trigger.
events = {}

# One ranking statistic object per ifo
rank_method = {}
for ifo in args.ifos:
    rank_method[ifo] = stat.get_statistic_from_opts(args, [ifo])

# Main loop: read each trigger file, accumulate live time, apply the
# trigger/template cuts, rank the survivors and collect them per ifo.
logging.info("Processing %d files", len(files))
for counter, filename in enumerate(files):
    # Periodic progress report, every 1000 files
    if counter and counter % 1000 == 0:
        logging.info("Processed %d/%d files", counter, len(files))
        for ifo in args.ifos:
            if ifo not in events:
                # In case of no triggers for an extended period
                logging.info("%s: No data", ifo)
            else:
                logging.info("%s: %d triggers in %.0f s", ifo,
                             events[ifo].data['snr'].size, live_time[ifo])

    # If there is an IOError with the file, don't fail, just carry on.
    # The handle opened here is the same one that is read below, so each
    # file is opened exactly once and is always closed by the `with`.
    try:
        fin = HFile(filename, 'r')
    except IOError:
        logging.warning('IOError with file %s', filename)
        continue

    # Triggers for this file
    triggers = {}
    with fin:
        # Does the file have the ifo group and snr dataset?
        for ifo in args.ifos:
            if not (ifo in fin and 'snr' in fin[ifo]):
                continue
            # Eventual FIX ME: live output files should (soon) have the live time
            # added, but for now, extract from the filename
            # Format of the filename is to have the live time as a dash,
            # followed by '.hdf' at the end of the filename
            lt = int(filename.split('-')[-1][:-4])
            live_time[ifo] += lt

            n_triggers = fin[ifo]['snr'].size
            # Skip if there are no triggers
            if not n_triggers:
                continue

            # Read trigger value datasets from file
            # Get all datasets with the same size as the trigger SNRs,
            # except for edge cases where the number of loudest, gates etc.
            # happens to be the same as the trigger count
            triggers[ifo] = {k: fin[ifo][k][:] for k in fin[ifo].keys()
                             if k not in ('loudest', 'stat', 'gates', 'psd')
                             and fin[ifo][k].size == n_triggers}

            # The stored chisq is actually reduced chisq, so hack the
            # chisq_dof dataset to use the standard conversions.
            # chisq_dof of 1.5 gives the right number (2 * 1.5 - 2 = 1)
            triggers[ifo]['chisq_dof'] = \
                1.5 * np.ones_like(triggers[ifo]['snr'])

    for ifo, trigs_ifo in triggers.items():

        # Apply the cuts to triggers
        keep_idx = cuts.apply_trigger_cuts(trigs_ifo, trigger_cut_dict)

        # triggers contains the datasets that we want to use for
        # the template cuts, so here it can be used as the template bank
        keep_idx = cuts.apply_template_cuts(trigs_ifo, template_cut_dict,
                                            template_ids=keep_idx)

        # Skip if no triggers survive the cuts
        if not keep_idx.size:
            continue

        # Keep only the surviving triggers in every dataset
        triggers_cut = {k: trigs_ifo[k][keep_idx]
                        for k in trigs_ifo.keys()}

        # Calculate the sngl_ranking values
        sds = rank_method[ifo].single(triggers_cut)
        sngls_value = rank_method[ifo].rank_stat_single(
            (ifo, sds),
            **stat_kwargs
        )

        triggers_cut['stat'] = sngls_value

        triggers_da = DictArray(data=triggers_cut)

        # If we are clustering, take the max sngl_ranking value
        if args.cluster:
            max_idx = sngls_value.argmax()
            # Make sure that the DictArray has array data, not float
            triggers_da = triggers_da.select([max_idx])

        if ifo in events:  # DictArray already exists for the ifo
            events[ifo] += triggers_da
        else:  # Set up a new dictionary entry
            events[ifo] = triggers_da

logging.info("All events processed")

# Summarise, per ifo, how many events survived and over how much live time
logging.info("Number of events which meet all criteria:")
for ifo in args.ifos:
    if ifo in events:
        logging.info("%s: %d in %.2fs",
                     ifo, len(events[ifo]), live_time[ifo])
        continue
    # This ifo never had a trigger pass the cuts
    logging.info("%s: No data", ifo)

logging.info('Sorting events into template duration bins')

# Per-ifo, per-bin containers for the fit results, the bin index of each
# event, and the times around which loud events will be pruned
n_bins = duration_bin_edges.size - 1
alphas = {}
counts = {}
times_to_prune = {}
for ifo in args.ifos:
    alphas[ifo] = np.zeros(n_bins, dtype=np.float32)
    counts[ifo] = np.zeros(n_bins, dtype=np.float32)
    times_to_prune[ifo] = []
event_bins = {}

for ifo in events:
    ifo_data = events[ifo].data

    # Look up the duration bin of every event
    event_bins[ifo] = np.array(
        [tbins[duration] for duration in ifo_data['template_duration']]
    )

    # Nothing more to do for this ifo unless pruning was requested
    if not args.prune_loudest:
        continue

    for bin_num in range(n_bins):
        inbin = event_bins[ifo] == bin_num
        binned_events = ifo_data['stat'][inbin]
        binned_event_times = ifo_data['end_time'][inbin]

        # Cluster triggers in time with the pruning window to ensure
        # that clusters are independent
        cidx = cluster_over_time(binned_events, binned_event_times,
                                 args.prune_window)

        # Keep only the clusters at/above the statistic threshold
        cidx = cidx[binned_events[cidx] >= args.prune_stat_threshold]
        clustered_times = binned_event_times[cidx]

        if args.prune_loudest > cidx.size:
            # There are fewer clusters than the number specified,
            # so prune them all
            times_to_prune[ifo].extend(clustered_times)
            continue

        # Otherwise take just the loudest clusters in this bin
        argloudest = np.argsort(binned_events[cidx])[-args.prune_loudest:]
        times_to_prune[ifo].extend(clustered_times[argloudest])

# Number of pruned triggers per bin (per ifo), and the times of the
# triggers that were removed; both stay empty when pruning is disabled
n_pruned = {ifo: [] for ifo in args.ifos}
pruned_trigger_times = {}
if args.prune_loudest:
    logging.info("Pruning triggers %.2fs either side of the loudest %d "
                 "triggers in each bin if %s > %.2f", args.prune_window,
                 args.prune_loudest, args.sngl_ranking,
                 args.prune_stat_threshold)
    for ifo in events:
        times = events[ifo].data['end_time'][:]

        # Start by keeping everything, then knock out each pruning window
        keep_mask = np.ones_like(times, dtype=bool)
        for prune_time in times_to_prune[ifo]:
            keep_mask &= np.abs(times - prune_time) > args.prune_window
            # Need to make an (ever-so-small) correction to the live time
            live_time[ifo] -= 2 * args.prune_window

        # Save the pruned events for reporting
        pruned_mask = ~keep_mask
        pruned_trigger_bins = event_bins[ifo][pruned_mask]
        pruned_trigger_times[ifo] = times[pruned_mask]

        # Remove pruned events from the arrays we will fit
        events[ifo] = events[ifo].select(keep_mask)
        event_bins[ifo] = event_bins[ifo][keep_mask]

        # Report the number of pruned triggers in each bin
        for bin_num in range(n_bins):
            n_pruned_thisbin = np.count_nonzero(
                pruned_trigger_bins == bin_num
            )
            n_pruned[ifo].append(n_pruned_thisbin)
            logging.info("Pruned %d triggers from %s bin %d",
                         n_pruned_thisbin, ifo, bin_num)

# Fit the statistic distribution above threshold, one duration bin at a
# time, for every ifo that has events
for ifo in events:
    for bin_num in range(n_bins):
        inbin = event_bins[ifo] == bin_num

        if not np.any(inbin):
            # No triggers in this bin: flag both outputs with -1
            counts[ifo][bin_num] = -1
            alphas[ifo][bin_num] = -1
            continue

        threshold = args.fit_threshold[ifo]
        stat_inbin = events[ifo].data['stat'][inbin]

        # Number of triggers strictly above the fit threshold
        counts[ifo][bin_num] = np.count_nonzero(stat_inbin > threshold)

        # Only the first returned value (the fit coefficient) is stored
        fit_coeff, _ = trstats.fit_above_thresh(
            args.fit_function[ifo],
            stat_inbin,
            threshold
        )
        alphas[ifo][bin_num] = fit_coeff

# Write the fit results: one group per ifo holding the fit coefficients,
# counts and the triggers used, plus file-level datasets/attributes
# describing the bins and the configuration of this run.
logging.info("Writing results")
with HFile(args.output, 'w') as fout:
    for ifo in args.ifos:
        fout_ifo = fout.create_group(ifo)
        fout_ifo.attrs['fit_function'] = args.fit_function[ifo]
        fout_ifo.attrs['fit_threshold'] = args.fit_threshold[ifo]
        if ifo not in events:
            # There were no triggers, but we should still produce some
            # information. Use -1 as the "no fit" coefficient, matching
            # the value stored for empty bins when fitting; the previous
            # `-1 * np.zeros(n_bins)` produced an array of (negative)
            # zeros rather than -1.
            fout_ifo['fit_coeff'] = np.full(n_bins, -1.0)
            fout_ifo['counts'] = np.zeros(n_bins)
            fout_ifo.attrs['live_time'] = live_time[ifo]
            fout_ifo.attrs['pruned_times'] = []
            fout_ifo.attrs['n_pruned'] = 0
            continue

        # Save the triggers we have used for the fits
        fout_ifo_trigs = fout_ifo.create_group('triggers')
        for key in events[ifo].data:
            fout_ifo_trigs[key] = events[ifo].data[key]
        # Only present when pruning was requested
        if ifo in pruned_trigger_times:
            fout_ifo['pruned_trigger_times'] = pruned_trigger_times[ifo]

        fout_ifo['fit_coeff'] = alphas[ifo]
        fout_ifo['counts'] = counts[ifo]
        fout_ifo.attrs['live_time'] = live_time[ifo]
        fout_ifo.attrs['pruned_times'] = times_to_prune[ifo]
        fout_ifo.attrs['n_pruned'] = n_pruned[ifo]

    # Edges of the template duration bins used for the fits
    fout['bins_upper'] = tbins.upper()
    fout['bins_lower'] = tbins.lower()

    # Record how this file was made
    fout.attrs['ifos'] = ','.join(args.ifos)
    fout.attrs['fit_start_gps_time'] = args.gps_start_time
    fout.attrs['fit_end_gps_time'] = args.gps_end_time
    fout.attrs['input'] = sys.argv
    fout.attrs['cuts'] = args.template_cuts + args.trigger_cuts
    fout.attrs['sngl_ranking'] = args.sngl_ranking
    fout.attrs['ranking_statistic'] = args.ranking_statistic

logging.info("Done")
