#!/usr/bin/python3.11

# Copyright 2020 Gareth S. Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

"""Combine PyCBC Live single-detector trigger fitting parameters from several
different files."""

import argparse
import logging
import numpy as np
import h5py
import pycbc


# Command-line interface: the daily fit files to combine, the detectors to
# collect, the percentile used for the conservative fit, and the output path.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--verbose",
    action="store_true",
    help="Print extra debugging information",
)
parser.add_argument(
    "--trfits-files",
    nargs="+",
    required=True,
    help="Files containing daily trigger fits",
)
parser.add_argument(
    "--conservative-percentile",
    type=int,
    default=95,
    help="What percentile to use for the conservative "
         "combined fit. Integer in range 50-99. Default=95",
)
parser.add_argument(
    "--output",
    required=True,
    help="Output file for combined fit parameters",
)
parser.add_argument(
    "--ifos",
    required=True,
    nargs="+",
    help="list of ifos to collect info for",
)

args = parser.parse_args()

# Hand the --verbose flag to PyCBC's logging setup
pycbc.init_logging(args.verbose)

# Assert some sensible limits on the arguments: parser.error prints the usage
# message and exits, so nothing below runs with an out-of-range percentile

if not 50 <= args.conservative_percentile <= 99:
    parser.error("--conservative-percentile must be between 50 and 99, "
                 "otherwise it is either not a percentile, or not "
                 "conservative.")

logging.info("%d input files", len(args.trfits_files))

# Fit results can only be combined when they were produced with one common
# configuration.  Scan every input file, keep track of the one with the
# largest 'analysis_date' attribute, and remember its bin edges and fit
# settings as the reference configuration.

logging.info("Determining the most recent configuration parameters")

latest_date = None
for fname in args.trfits_files:
    with h5py.File(fname, 'r') as candidate_f:
        candidate_date = candidate_f.attrs['analysis_date']
        is_newer = latest_date is None or candidate_date > latest_date
        if not is_newer:
            # An equally recent or older file never replaces the reference
            continue
        latest_date = candidate_date
        # Reference bin edges
        bl = candidate_f['bins_lower'][:]
        bu = candidate_f['bins_upper'][:]
        # Reference fit settings
        sngl_rank = candidate_f.attrs['sngl_ranking']
        fit_thresh = candidate_f.attrs['fit_threshold']
        fit_func = candidate_f.attrs['fit_function']

# Now go back through the fit files and read the actual information.  Skip the
# files that have fit parameters inconsistent with what we found earlier.

logging.info("Reading individual fit results")

# One entry is appended per accepted file, in input order.  The per-detector
# dictionaries and the start/end lists therefore stay aligned index by index.
live_times = {ifo: [] for ifo in args.ifos}
trigger_file_starts = []
trigger_file_ends = []
counts_all = {ifo: [] for ifo in args.ifos}
alphas_all = {ifo: [] for ifo in args.ifos}

for f in args.trfits_files:
    with h5py.File(f, 'r') as fits_f:
        # Check that the file uses the same setup as the most recent file, to
        # make sure all coefficients are comparable.  np.array_equal returns
        # a plain False when the number of bins differs, so such a file is
        # skipped instead of failing on a shape-mismatched comparison.
        same_conf = (fits_f.attrs['sngl_ranking'] == sngl_rank
                     and fits_f.attrs['fit_threshold'] == fit_thresh
                     and fits_f.attrs['fit_function'] == fit_func
                     and np.array_equal(fits_f['bins_lower'][:], bl)
                     and np.array_equal(fits_f['bins_upper'][:], bu))
        if not same_conf:
            logging.warning(
                "Found a change in the fit configuration, skipping %s",
                f
            )
            continue

        # We now determine the (approximate) start/end times of the
        # trigger_fits file via the time of the first/last triggers in it.
        # Ideally this would be recorded exactly in the file.
        # If none of the requested detectors is in the file, the initial
        # values (inf and 0) are what get recorded.
        gps_last = 0
        gps_first = np.inf
        for ifo in args.ifos:
            if ifo not in fits_f:
                continue
            trig_times = fits_f[ifo]['triggers']['end_time'][:]
            gps_last = max(gps_last, trig_times.max())
            gps_first = min(gps_first, trig_times.min())
        trigger_file_starts.append(gps_first)
        trigger_file_ends.append(gps_last)

        # Read the fitting parameters
        for ifo in args.ifos:
            if ifo not in fits_f:
                # Detector absent from this file: record zero live time and
                # -1 placeholders so every list keeps one entry per file.
                # NOTE(review): these placeholders are not filtered out
                # before the averages and percentiles are computed further
                # down -- confirm that this is intended.
                live_times[ifo].append(0)
                counts_all[ifo].append(-1 * np.ones_like(bl))
                alphas_all[ifo].append(-1 * np.ones_like(bl))
            else:
                ffi = fits_f[ifo]
                # Read the coefficients from disk once and reuse them
                fit_coeff = ffi['fit_coeff'][:]
                live_times[ifo].append(ffi.attrs['live_time'])
                counts_all[ifo].append(ffi['counts'][:])
                alphas_all[ifo].append(fit_coeff)
                if np.isnan(fit_coeff).any():
                    # The values are still kept; this only reports them
                    logging.warning("nan in %s, %s", f, ifo)
                    logging.warning("%s", fit_coeff)

# Build the date array: the last-trigger time of every file, ordered by file
# start time and expressed as an offset from the first trigger of the
# earliest file

trigger_file_starts = np.array(trigger_file_starts)
trigger_file_ends = np.array(trigger_file_ends)

# Indices that put the files in order of increasing start time
ad_order = np.argsort(trigger_file_starts)
start_time_n = trigger_file_starts[ad_order[0]]
ad = trigger_file_ends[ad_order] - start_time_n

# Per detector: transpose counts and alphas from (file, bin) to (bin, file),
# and allocate the output arrays that the combination step fills in
counts_bin = {}
alphas_bin = {}
alphas_out = {}
counts_out = {}
cons_alphas_out = {}
cons_counts_out = {}
for ifo in args.ifos:
    counts_bin[ifo] = list(zip(*counts_all[ifo]))
    alphas_bin[ifo] = list(zip(*alphas_all[ifo]))

    n_alpha_bins = len(alphas_bin[ifo])
    n_count_bins = len(counts_bin[ifo])

    # Fit coefficients start at zero, counts start at infinity
    alphas_out[ifo] = np.zeros(n_alpha_bins)
    counts_out[ifo] = np.full(n_count_bins, np.inf)
    cons_alphas_out[ifo] = np.zeros(n_alpha_bins)
    cons_counts_out[ifo] = np.full(n_alpha_bins, np.inf)

logging.info("Writing results")

# Open the output through a context manager so that the HDF5 file is closed
# even if one of the computations or writes below raises.
with h5py.File(args.output, 'w') as fout:
    fout.attrs['fit_threshold'] = fit_thresh
    fout.attrs['conservative_percentile'] = args.conservative_percentile
    fout.attrs['ifos'] = args.ifos
    # Bin edges: every lower edge followed by the upper edge of the last bin
    fout['bins_edges'] = list(bl) + [bu[-1]]
    # Adding the reference start time back to the offsets gives the
    # date-sorted end times of the files
    fout['fits_dates'] = ad + start_time_n

    for ifo in args.ifos:
        logging.info(ifo)
        fout_ifo = fout.create_group(ifo)
        l_times = np.array(live_times[ifo])
        fout_ifo.attrs['live_time'] = l_times.sum()

        # The date-sorted live times and their total are identical for every
        # bin, so compute them once instead of inside the bin loop.
        l_times_sorted = l_times[ad_order]
        l_time_total = l_times_sorted.sum()

        fout_ifo['separate_fits/live_times'] = l_times_sorted
        fout_ifo['separate_fits/start_time'] = trigger_file_starts[ad_order]
        fout_ifo['separate_fits/end_time'] = trigger_file_ends[ad_order]

        # bu and bl are zipped in only to limit the iteration to the number
        # of bins; their values are not used inside the loop.
        per_bin = zip(alphas_bin[ifo], counts_bin[ifo], bu, bl)
        for counter, (a, c, _upper, _lower) in enumerate(per_bin):
            # Sort alpha and counts by date
            a = np.array(a)[ad_order]
            c = np.array(c)[ad_order]

            # Mean alpha is mean(c) / mean(c / a), i.e. a count-weighted
            # harmonic mean of the per-file coefficients.
            # NOTE(review): a zero fit coefficient gives a division by zero
            # here -- confirm the inputs exclude that case.
            invalphan = c / a
            mean_alpha = c.mean() / invalphan.mean()
            # The conservative alpha is taken from the low end of the
            # distribution (the 100 - percentile point)
            cons_alpha = np.percentile(a, 100 - args.conservative_percentile)
            cons_alphas_out[ifo][counter] = cons_alpha
            alphas_out[ifo][counter] = mean_alpha

            # To get the count values, we need to convert to rates and back
            # again
            r = c / l_times_sorted
            cons_rate = np.percentile(r, args.conservative_percentile)
            cons_counts_out[ifo][counter] = cons_rate * l_time_total
            counts_out[ifo][counter] = np.mean(r) * l_time_total

            fout_ifo[f'separate_fits/bin_{counter:d}/fit_coeff'] = a
            fout_ifo[f'separate_fits/bin_{counter:d}/counts'] = c

        # Output the mean average values
        fout_ifo['mean/fit_coeff'] = alphas_out[ifo]
        fout_ifo['mean/counts'] = counts_out[ifo]

        # Output the conservative values
        fout_ifo['conservative/fit_coeff'] = cons_alphas_out[ifo]
        fout_ifo['conservative/counts'] = cons_counts_out[ifo]

        # Take some averages for plotting and summary values
        overall_invalphan = counts_out[ifo] / alphas_out[ifo]
        overall_meanalpha = counts_out[ifo].mean() / overall_invalphan.mean()

        # For the fixed version, we just set the counts to 1 and the fit
        # coefficients to 0
        fout_ifo['fixed/counts'] = [1] * len(counts_out[ifo])
        fout_ifo['fixed/fit_coeff'] = [0] * len(alphas_out[ifo])

        # Add some useful info to the output file
        fout_ifo.attrs['mean_alpha'] = overall_meanalpha
        fout_ifo.attrs['total_counts'] = counts_out[ifo].sum()

logging.info('Done')
