#!/usr/bin/python3.13 -s

# Copyright 2020 Gareth S. Cabourn Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

"""Combine PyCBC Live single-detector trigger fitting parameters from several
different files."""

import argparse
import logging
import numpy as np

from ligo.segments import segment, segmentlist

import pycbc
from pycbc.io.hdf import HFile

# Command-line interface.  The common PyCBC options are added by the library
# helper (presumably including --verbose, which is read below).
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--trfits-files", nargs="+", required=True,
                    help="Files containing daily trigger fits")
parser.add_argument("--conservative-percentile", type=int, default=95,
                    help="What percentile to use for the conservative "
                         "combined fit. Integer in range 50-99. Default=95")
parser.add_argument("--output", required=True,
                    help="Output file for combined fit parameters")
parser.add_argument("--ifos", required=True, nargs="+",
                    help="list of ifos to collect info for")

args = parser.parse_args()

pycbc.init_logging(args.verbose)

# Assert some sensible limits on the arguments: the conservative percentile
# must lie in the inclusive range 50-99.
if not 50 <= args.conservative_percentile <= 99:
    parser.error("--conservative-percentile must be between 50 and 99, "
                 "otherwise it is either not a percentile, or not "
                 "conservative.")

logging.info("%d input files", len(args.trfits_files))

# Fit results are only combined when they share one configuration.  The
# reference configuration is taken from the input file with the largest
# 'fit_end_gps_time' attribute (the first such file if there is a tie).

logging.info("Determining the most recent configuration parameters")

latest_date = None
for fit_path in args.trfits_files:
    with HFile(fit_path, 'r') as fit_f:
        end_time = fit_f.attrs['fit_end_gps_time']
        # Only a file that ends strictly later than the current candidate
        # replaces the reference configuration
        if latest_date is not None and not end_time > latest_date:
            continue

        latest_date = end_time
        bl = fit_f['bins_lower'][:]
        bu = fit_f['bins_upper'][:]
        sngl_rank = fit_f.attrs['sngl_ranking']

        # Per-detector fit settings, keyed by ifo name
        fit_thresh = {}
        fit_func = {}
        for ifo in args.ifos:
            ifo_attrs = fit_f[ifo].attrs
            fit_thresh[ifo] = ifo_attrs['fit_threshold']
            fit_func[ifo] = ifo_attrs['fit_function']

# Now go back through the fit files and read the actual information.  Skip the
# files that have fit parameters inconsistent with what we found earlier.

logging.info("Reading individual fit results")

# Per-ifo lists with one entry per accepted file, in the order files are read
live_times = {ifo: [] for ifo in args.ifos}
trigger_file_starts = []
trigger_file_ends = []
counts_all = {ifo: [] for ifo in args.ifos}
alphas_all = {ifo: [] for ifo in args.ifos}

# Union of the time spans of the files accepted so far, used to warn when
# fit files overlap one another
seg_all = segmentlist([])
for f in args.trfits_files:
    with HFile(f, 'r') as fits_f:
        # Check that the file uses the same setup as the most recent file, to
        # make sure all coefficients are comparable.  Only ifos which have a
        # group in this file can be compared: indexing a missing ifo here
        # would raise a KeyError before the placeholder handling below is
        # ever reached.
        present_ifos = [ifo for ifo in args.ifos if ifo in fits_f]
        new_fit_func = {ifo: fits_f[ifo].attrs['fit_function']
                        for ifo in present_ifos}
        new_fit_thresh = {ifo: fits_f[ifo].attrs['fit_threshold']
                          for ifo in present_ifos}
        ref_fit_func = {ifo: fit_func[ifo] for ifo in present_ifos}
        ref_fit_thresh = {ifo: fit_thresh[ifo] for ifo in present_ifos}
        # np.array_equal is False for arrays of different shape, so bin
        # arrays of a different length count as a configuration change
        # instead of breaking the element-wise comparison
        same_conf = (fits_f.attrs['sngl_ranking'] == sngl_rank
                     and new_fit_thresh == ref_fit_thresh
                     and new_fit_func == ref_fit_func
                     and np.array_equal(fits_f['bins_lower'][:], bl)
                     and np.array_equal(fits_f['bins_upper'][:], bu))
        if not same_conf:
            logging.warning(
                "Found a change in the fit configuration, skipping %s",
                f
            )
            continue

        # Get the start/end times of the trigger_fits file
        fit_start = fits_f.attrs['fit_start_gps_time']
        fit_end = fits_f.attrs['fit_end_gps_time']
        trigger_file_starts.append(fit_start)
        trigger_file_ends.append(fit_end)
        seg = segmentlist([segment(fit_start, fit_end)])
        if seg.intersects(seg_all):
            logging.warning(
                "File %s fit time %d-%d overlaps with other file(s)",
                f,
                fit_start,
                fit_end,
            )
        seg_all |= seg

        # Read the fitting parameters
        for ifo in args.ifos:
            if ifo not in fits_f:
                # No results for this ifo in this file: store zero live time
                # and -1 placeholders so every per-ifo list keeps one entry
                # per accepted file
                live_times[ifo].append(0)
                counts_all[ifo].append(-1 * np.ones_like(bl))
                alphas_all[ifo].append(-1 * np.ones_like(bl))
                continue

            ffi = fits_f[ifo]
            # Read the coefficients once and reuse them for the nan check
            fit_coeff = ffi['fit_coeff'][:]
            live_times[ifo].append(ffi.attrs['live_time'])
            counts_all[ifo].append(ffi['counts'][:])
            alphas_all[ifo].append(fit_coeff)
            if np.isnan(fit_coeff).any():
                logging.warning("nan in %s, %s", f, ifo)
                logging.warning(fit_coeff)

# Set up the date array: files are ordered by their start time, and each
# file's end time is stored as an offset from the earliest start time

trigger_file_starts = np.array(trigger_file_starts)
trigger_file_ends = np.array(trigger_file_ends)
ad_order = np.argsort(trigger_file_starts)
start_time_n = trigger_file_starts[ad_order[0]]
ad = trigger_file_ends[ad_order] - start_time_n

# Swap the time and bin dimensions for counts and alphas, and allocate the
# per-bin output arrays: fit coefficients start at zero, counts at infinity
counts_bin = {}
alphas_bin = {}
alphas_out = {}
counts_out = {}
cons_alphas_out = {}
cons_counts_out = {}
for ifo in args.ifos:
    counts_bin[ifo] = list(zip(*counts_all[ifo]))
    alphas_bin[ifo] = list(zip(*alphas_all[ifo]))

    n_alpha_bins = len(alphas_bin[ifo])
    n_count_bins = len(counts_bin[ifo])

    alphas_out[ifo] = np.zeros(n_alpha_bins)
    counts_out[ifo] = np.full(n_count_bins, np.inf)
    cons_alphas_out[ifo] = np.zeros(n_alpha_bins)
    cons_counts_out[ifo] = np.full(n_alpha_bins, np.inf)

logging.info("Writing results")

# Open the output through a context manager so that the file is closed even
# if an error is raised part-way through writing
with HFile(args.output, 'w') as fout:
    fout.attrs['conservative_percentile'] = args.conservative_percentile
    fout.attrs['ifos'] = args.ifos
    # Bin edges are all the lower edges followed by the final upper edge
    fout['bins_edges'] = list(bl) + [bu[-1]]
    # Add the reference time back on to the stored end-time offsets
    fout['fits_dates'] = ad + start_time_n

    for ifo in args.ifos:
        logging.info(ifo)
        fout_ifo = fout.create_group(ifo)
        fout_ifo.attrs['fit_threshold'] = fit_thresh[ifo]
        fout_ifo.attrs['fit_function'] = fit_func[ifo]
        l_times = np.array(live_times[ifo])
        total_time = l_times.sum()
        fout_ifo.attrs['live_time'] = total_time

        # Live times in date order; computed once as it is the same for
        # every bin
        l_times_sorted = l_times[ad_order]

        fout_ifo['separate_fits/live_times'] = l_times_sorted
        fout_ifo['separate_fits/start_time'] = trigger_file_starts[ad_order]
        fout_ifo['separate_fits/end_time'] = trigger_file_ends[ad_order]

        for counter, (a, c) in enumerate(zip(alphas_bin[ifo],
                                             counts_bin[ifo])):
            # Sort alpha and counts by date
            a = np.array(a)[ad_order]
            c = np.array(c)[ad_order]

            # Ignore anything with the 'invalid' sentinel values, i.e.
            # entries whose count is not positive
            valid = c > 0

            fout_ifo[f'separate_fits/bin_{counter:d}/fit_coeff'] = a
            fout_ifo[f'separate_fits/bin_{counter:d}/counts'] = c

            if not valid.any():
                # Nothing usable in this bin: mark every output as nan
                cons_alphas_out[ifo][counter] = np.nan
                alphas_out[ifo][counter] = np.nan
                cons_counts_out[ifo][counter] = np.nan
                counts_out[ifo][counter] = np.nan
                continue

            # The mean alpha is the mean count divided by the mean of
            # count / alpha over the valid entries
            invalphan = c[valid] / a[valid]
            mean_alpha = c[valid].mean() / invalphan.mean()
            # The conservative alpha is taken from the opposite end of the
            # distribution to the conservative rate below
            cons_alpha = np.percentile(a[valid],
                                       100 - args.conservative_percentile)
            cons_alphas_out[ifo][counter] = cons_alpha
            alphas_out[ifo][counter] = mean_alpha

            # To get the count values, we need to convert to rates and back
            # again
            r = c[valid] / l_times_sorted[valid]
            cons_rate = np.percentile(r, args.conservative_percentile)
            cons_counts_out[ifo][counter] = cons_rate * total_time
            counts_out[ifo][counter] = np.mean(r) * total_time

        # Output the mean average values
        fout_ifo['mean/fit_coeff'] = alphas_out[ifo]
        fout_ifo['mean/counts'] = counts_out[ifo]

        # Output the conservative values
        fout_ifo['conservative/fit_coeff'] = cons_alphas_out[ifo]
        fout_ifo['conservative/counts'] = cons_counts_out[ifo]

        # For the fixed version, every bin gets a count of 1 and a fit
        # coefficient of 0
        fout_ifo['fixed/counts'] = [1] * len(counts_out[ifo])
        fout_ifo['fixed/fit_coeff'] = [0] * len(alphas_out[ifo])

        # Take some averages for plotting and summary values, ignoring the
        # bins that were marked as nan
        overall_invalphan = counts_out[ifo] / alphas_out[ifo]
        overall_meanalpha = np.nanmean(counts_out[ifo]) \
            / np.nanmean(overall_invalphan)

        # Add some useful info to the output file
        fout_ifo.attrs['mean_alpha'] = overall_meanalpha
        fout_ifo.attrs['total_counts'] = np.nansum(counts_out[ifo])

logging.info('Done')
