#!/usr/bin/python3.13 -s

# Copyright 2016 Thomas Dent, Alex Nitz, Gareth Cabourn Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.


import sys, argparse, logging, numpy
from scipy.stats import norm

import pycbc
from pycbc.events import triggers
from pycbc.io import HFile
from pycbc import init_logging

def dist(i1, i2, parvals, smoothing_width):
    """
    Scaled Euclidean distance between the templates at i1 and at i2.

    Every entry of parvals is indexed with i1 and with i2 (each may be a
    single index or a set of indices). The squared difference in each
    parameter is divided by the square of the matching entry of
    smoothing_width, i.e. the metric is 1 / (smoothing width^2), and the
    square root of the summed terms is returned.
    """
    scaled_sq_diffs = (
        (values[i2] - values[i1]) ** 2.0 / width ** 2.0
        for values, width in zip(parvals, smoothing_width)
    )
    return sum(scaled_sq_diffs) ** 0.5


def smooth_templates(nabove, invalphan, ntotal, template_idx,
                     weights=None):
    """
    Average the fit quantities over a chosen set of templates.

    The maximum likelihood fit for 1/alpha is linear in the trigger
    statistic values, so (n_above / alpha) is averaged over the selected
    templates (optionally weighted); the smoothed alpha is then the
    smoothed n_above divided by that average.

    Parameters
    ----------
    nabove: ndarray
        Counts of triggers above threshold for all templates
    invalphan: ndarray
        n_above / alpha values for all templates
    ntotal: ndarray
        Counts of all triggers in each template, above and below threshold
    template_idx: ndarray of ints
        Indices of the templates that enter the average

    Optional Parameters
    -------------------
    weights: ndarray
        Weight for each of the templates selected by template_idx. If not
        given, all selected templates are weighted equally.

    Returns
    -------
    tuple: 3 floats
        First float: the smoothed count above threshold value
        Second float: the smoothed fit coefficient (alpha) value
        Third float: the smoothed total count in template value

    """
    if weights is None:
        # Equal weighting of every selected template
        weights = numpy.ones_like(template_idx)

    def _selected_average(values):
        # Weighted mean of `values` restricted to the selected templates
        return numpy.average(values[template_idx], weights=weights)

    smoothed_nabove = _selected_average(nabove)
    smoothed_ntotal = _selected_average(ntotal)
    smoothed_alpha = smoothed_nabove / _selected_average(invalphan)

    return smoothed_nabove, smoothed_alpha, smoothed_ntotal


def smooth_tophat(nabove, invalphan, ntotal, dists):
    """
    Tophat smoothing: average, with equal weights, over every template
    whose entry in dists is less than one.
    """
    inside_window = dists < 1.
    selected = numpy.flatnonzero(inside_window)
    return smooth_templates(nabove, invalphan, ntotal, selected)


# This is the default number of triggers required for n_closest smoothing
_default_total_trigs = 500


def smooth_n_closest(nabove, invalphan, ntotal, dists,
                     total_trigs=_default_total_trigs):
    """
    Smooth templates according to the closest N templates
    No weighting is applied

    Templates are added in order of increasing distance until the
    accumulated count of triggers above threshold reaches total_trigs.
    If all templates together hold fewer than total_trigs triggers,
    every template is used.

    Parameters
    ----------
    nabove, invalphan, ntotal: ndarray
        As described in the smooth_templates docstring
    dists: ndarray
        Distance of each template from the template of interest
    total_trigs: number or numeric string
        Number of above-threshold triggers to gather. Strings are
        accepted because values given on the command line are not
        necessarily converted before being passed in.

    Returns
    -------
    tuple: 3 floats
        The output of smooth_templates for the selected templates
    """
    total_trigs = float(total_trigs)
    dist_sort = numpy.argsort(dists)
    # Running count of triggers above threshold, starting at the closest
    # template
    ntcs = nabove[dist_sort].cumsum()
    # searchsorted returns len(ntcs) if the target is never reached, so
    # clamp to the number of templates available to keep the index used
    # in the debug message below in range
    templates_required = min(numpy.searchsorted(ntcs, total_trigs) + 1,
                             len(ntcs))
    logging.debug("%d template(s) required to obtain %d(>%d) triggers",
                  templates_required, ntcs[templates_required - 1], total_trigs)
    idx_to_smooth = dist_sort[:templates_required]
    return smooth_templates(nabove, invalphan, ntotal, idx_to_smooth)


def smooth_distance_weighted(nabove, invalphan, ntotal, dists):
    """
    Gaussian-weighted smoothing: templates with dists below three are
    averaged with weights given by the unit-width normal pdf evaluated at
    their distance; templates further away are ignored.
    """
    nearby = numpy.flatnonzero(dists < 3.)
    gaussian_weights = norm.pdf(dists[nearby])
    return smooth_templates(nabove, invalphan, ntotal, nearby,
                            weights=gaussian_weights)

# Maps each --smoothing-method name to the function implementing it.
# The keys are also used as the allowed choices of that option.
_smooth_dist_func = {
    'smooth_tophat': smooth_tophat,
    'n_closest': smooth_n_closest,
    'distance_weighted': smooth_distance_weighted
}


def smooth(nabove, invalphan, ntotal, dists, smoothing_method, **kwargs):
    """
    Dispatch to the smoothing function registered under smoothing_method

    nabove, invalphan, ntotal have the meaning given in the
    smooth_templates docstring

    dists holds the distance of each template from the template of
    interest

    Any extra keyword arguments are forwarded unchanged to the chosen
    smoothing function
    """
    smoother = _smooth_dist_func[smoothing_method]
    return smoother(nabove, invalphan, ntotal, dists, **kwargs)

# For each smoothing method, the number of smoothing widths around the
# current template within which distances are calculated; templates
# further away than this along the sorting dimension are cut before the
# distance calculation.
# n_closest has no limit (infinite cut), as it must reach enough
# templates to contain n triggers, which we cannot know beforehand

_smooth_cut = {
    'smooth_tophat': 1,
    'n_closest': numpy.inf,
    'distance_weighted': 3,
}


def report_percentage(i, length):
    """
    Convenience function - report how far through the loop we are.
    A message is logged at INFO level whenever the completed percentage
    moves into a new ten-percent band compared with the previous index.

    Parameters
    ----------
    i: integer
        index being looped through
    length : integer
        number of loops we will go through in total. If this is not
        positive there is nothing to measure progress against and the
        function returns without logging.
    """
    if length <= 0:
        # Avoid dividing by zero (e.g. a loop over a single item whose
        # final index is zero)
        return
    pc = int(numpy.floor(i / length * 100))
    pc_last = int(numpy.floor((i - 1) / length * 100))
    # Compare ten-percent bands rather than testing for exact multiples
    # of ten, so short loops whose percentage jumps by more than one per
    # step are still reported
    if pc // 10 > pc_last // 10:
        logging.info(f"Template {i} out of {length} ({pc:.0f}%)")

# Command-line interface. --fit-param, --log-param and --smoothing-width
# are parallel lists: entry i of each describes the same smoothing
# dimension.
parser = argparse.ArgumentParser(usage="",
    description="Smooth (regress) the dependence of coefficients describing "
                "single-ifo background trigger distributions on a template "
                "parameter, to suppress random noise in the resulting "
                "background model.")

pycbc.add_common_pycbc_options(parser)
parser.add_argument("--template-fit-file", required=True, nargs='+',
                    help="hdf5 file(s) containing fit coefficients for each "
                         "individual template. Can smooth over multiple "
                         "files provided they correspond to the same bank "
                         "and fitting settings. Required")
parser.add_argument("--bank-file", required=True,
                    help="hdf file containing template parameters. Required")
parser.add_argument("--output", required=True,
                    help="Location for output file containing smoothed fit "
                         "coefficients. Required")
# The help text has always described this option as required; enforce it
# so that omitting it gives a usage error instead of a TypeError later on
parser.add_argument("--fit-param", nargs='+', required=True,
                    help="Parameter(s) over which to regress the background "
                         "fit coefficients. Required. Either read from "
                         "template fit file or choose from mchirp, mtotal, "
                         "chi_eff, eta, tau_0, tau_3, template_duration, "
                         "a frequency cutoff in pnutils or a frequency function "
                         "in LALSimulation. To regress the background over "
                         "multiple parameters, provide them as a list.")
parser.add_argument("--approximant", default="SEOBNRv4",
                    help="Approximant for template duration. Default SEOBNRv4")
parser.add_argument("--f-lower", type=float,
                    help="Start frequency for calculating template duration.")
parser.add_argument("--min-duration", type=float, default=0.,
                    help="Fudge factor for templates with tiny or negative "
                         "values of template_duration: add to duration values"
                         " before fitting. Units seconds.")
parser.add_argument("--log-param", nargs='+',
                    help="Take the log of the fit param before smoothing. "
                         "Must be a list corresponding to fit params.")
parser.add_argument("--smoothing-width", type=float, nargs='+', required=True,
                    help="Distance in the space of fit param values (or their"
                         " logs) to smooth over. Required. Must be a list "
                         "corresponding to fit params.")
parser.add_argument("--smoothing-method", default="smooth_tophat",
                    choices = _smooth_dist_func.keys(),
                    help="Method used to smooth the fit parameters; "
                         "'smooth_tophat' (default) finds all templates within "
                         "unit distance from the template of interest "
                         "(distance normalised by --smoothing-width). "
                         "'n_closest' adds the closest templates to "
                         "the smoothing until 500 triggers are reached. "
                         "'distance_weighted' weights the closest templates "
                         "with a normal distribution of width smoothing-width "
                         "truncated at three smoothing-widths.")
parser.add_argument("--smoothing-keywords", nargs='*',
                    help="Keywords for the smoothing function, supplied "
                         "as key:value pairs, e.g. total_trigs:500 to define "
                         "the number of templates for n_closest smoothing.")
parser.add_argument("--output-fits-by-template", action='store_true',
                    help="If given, will output the input file fits to "
                         "fit_by_template group.")
args = parser.parse_args()

# --smoothing-keywords is optional; treat "not given" as an empty list
smooth_kwargs = args.smoothing_keywords or []

# Turn KEY:VALUE strings into keyword arguments for the smoothing function
kwarg_dict = {}
for inputstr in smooth_kwargs:
    try:
        key, value = inputstr.split(':')
    except ValueError:
        err_txt = "--smoothing-keywords must take input in the " \
                  "form KWARG1:VALUE1 KWARG2:VALUE2 KWARG3:VALUE3 ... " \
                  "Received {}".format(' '.join(args.smoothing_keywords))
        raise ValueError(err_txt)
    # Command-line values are strings; convert numeric ones (e.g.
    # total_trigs:500) so they can be compared and used in arithmetic.
    # Anything that is not a number is passed through as a string.
    for converter in (int, float):
        try:
            value = converter(value)
        except ValueError:
            continue
        break
    kwarg_dict[key] = value

# The three per-dimension lists must all be present and the same length.
# Report problems as usage errors rather than with assert, which is
# skipped under "python -O" and fails obscurely on a missing option.
if args.fit_param is None or args.log_param is None:
    parser.error("--fit-param and --log-param must both be given")
if not (len(args.log_param) == len(args.fit_param)
        == len(args.smoothing_width)):
    parser.error(
        "--fit-param, --log-param and --smoothing-width must have the same "
        "number of entries, got {}, {} and {}".format(
            len(args.fit_param), len(args.log_param),
            len(args.smoothing_width)
        )
    )

init_logging(args.verbose)

# Accumulators for values read from the template-level fit files
analysis_time = 0
attr_dict = {}

# These end up as n_files * n_templates arrays
tid = numpy.array([], dtype=int)
nabove = numpy.array([], dtype=int)
ntotal = numpy.array([], dtype=int)
alpha = numpy.array([], dtype=float)
# Set to None as soon as any input file lacks the median_sigma dataset;
# it is then left out of the output
median_sigma = numpy.array([], dtype=float)

logging.info("Loading input template fits")

# Check on fit files having the same number of templates, as expected if
# they use the same bank
num_templates = None
for filename in args.template_fit_file:
    with HFile(filename, 'r') as fits:
        if num_templates is None:
            num_templates = fits['template_id'].size
        elif not num_templates == fits['template_id'].size:
            raise RuntimeError(
                "Input fit files correspond to different banks. "
                "This situation is not yet supported."
            )
        # get attributes from the template-level fit
        for k in fits.attrs.keys():
            if k == 'analysis_time':
                # For this attribute only, we want the mean; sum here and
                # divide by the number of files after the loop
                analysis_time += fits.attrs['analysis_time']
                continue
            if k not in attr_dict:
                # It's the first time we encounter this attribute
                attr_dict[k] = fits.attrs[k]
            elif k == 'ifo':
                # We don't mind if this attribute is different, however the
                # output attributes will only correspond to the first file's
                # IFO. Warn if different IFOs are being used.
                if not attr_dict[k] == fits.attrs[k]:
                    logging.warning(
                        "Fit files correspond to different IFOs: %s and %s, "
                        "only %s is being used for output file attributes",
                        attr_dict[k], fits.attrs[k], attr_dict[k],
                    )
                continue
            elif not attr_dict[k] == fits.attrs[k]:
                # Check that the files are consistent with one another
                err_msg = f"Input files are not consistent, file {filename} "
                err_msg += f"has attribute {k} of {fits.attrs[k]}, whereas "
                err_msg += f"previous files have value {attr_dict[k]}"
                raise RuntimeError(err_msg)

        # get template id and template parameter values
        tid = numpy.concatenate((tid, fits['template_id'][:]))
        nabove = numpy.concatenate((nabove, fits['count_above_thresh'][:]))
        ntotal = numpy.concatenate((ntotal, fits['count_in_template'][:]))
        alpha = numpy.concatenate((alpha, fits['fit_coeff'][:]))
        # Once an earlier file was found without median_sigma the
        # accumulator is None and must not be concatenated onto: doing so
        # raises a ValueError that the KeyError handler would not catch
        if median_sigma is not None:
            try:
                median_sigma = numpy.concatenate(
                    (median_sigma, fits['median_sigma'][:])
                )
            except KeyError:
                logging.info('Median_sigma dataset not present in input file')
                median_sigma = None

# For an exponential fit 1/alpha is linear in the trigger statistic values
# so calculating weighted sums or averages of 1/alpha is appropriate
invalpha = 1. / alpha
invalphan = invalpha * nabove

# convert the sum above into a mean
analysis_time /= len(args.template_fit_file)

if len(args.template_fit_file) > 1:
    # From the n_templates * n_files arrays, average within each template.
    # To do this, we average the n_files occurrences which have the same tid

    # The linearity of the average means that we can do this in two steps
    # without biasing the final result.
    logging.info("Averaging within the same template for multiple files")
    tid_unique = numpy.unique(tid)

    # sort into template_id order
    tidsort = tid.argsort()
    tid_sorted = tid[tidsort]

    # For each unique template id, find the half-open range [left, right)
    # of positions in the sorted arrays which hold that template id
    left = numpy.searchsorted(tid_sorted, tid_unique, side='left')
    right = numpy.searchsorted(tid_sorted, tid_unique, side='right')
    num = right - left

    # Precompute prefix sums with a leading zero, so that
    # prefix[right] - prefix[left] is the sum over every entry in
    # [left, right), including the first one
    nasum = numpy.concatenate(([0], nabove[tidsort].cumsum()))
    invsum = numpy.concatenate(([0], invalphan[tidsort].cumsum()))
    ntsum = numpy.concatenate(([0], ntotal[tidsort].cumsum()))

    tid = tid_unique
    nabove = (nasum[right] - nasum[left]) / num
    invalphan = (invsum[right] - invsum[left]) / num
    ntotal = (ntsum[right] - ntsum[left]) / num
    if median_sigma is not None:
        # Median sigma is a special one - we need to make sure that
        # we do not mess things up when nan values are given, so we
        # can't use the special cumsum fast option
        median_sigma = numpy.array([
            numpy.nanmean(median_sigma[tidsort[l:r]])
            for l, r in zip(left, right)
        ])

if args.output_fits_by_template:
    # Keep the unsmoothed per-template values for the output file.
    # With more than one input fit file these are the values averaged
    # over the same template in the different files.
    with numpy.errstate(invalid='ignore'):
        # Templates with n_above of zero give zero divided by zero and
        # hence an 'invalid' warning. That is expected and those entries
        # are overwritten just below, so silence the warning rather than
        # alarm people
        alpha = nabove / invalphan
    alpha[nabove == 0] = -100
    fbt_dict = {
        'count_above_thresh': nabove,
        'count_in_template': ntotal,
        'fit_coeff': alpha,
    }

# Number of triggers n_closest smoothing will try to gather. A value
# given through --smoothing-keywords may still be a string here, so
# convert it before comparing against the trigger count.
n_required = float(kwarg_dict.get('total_trigs', _default_total_trigs))
if args.smoothing_method == 'n_closest' and n_required > nabove.sum():
    logging.warning(
        "There are %.2f triggers above threshold, not enough to give a "
        "total of %d for smoothing",
        nabove.sum(),
        n_required
    )

logging.info('Calculating template parameter values')
# The bank file is only needed to look up the masses and spins, so close
# it again as soon as they have been read.
# NOTE(review): assumes get_mass_spin returns in-memory arrays rather than
# lazy references to the open file - confirm against pycbc.events.triggers
with HFile(args.bank_file, 'r') as bank:
    m1, m2, s1z, s2z = triggers.get_mass_spin(bank, tid)

# One entry per smoothing dimension: the parameter values (or their logs)
# and a name used in log messages
parvals = []
parnames = []

for param, slog in zip(args.fit_param, args.log_param):
    data = triggers.get_param(param, args, m1, m2, s1z, s2z)
    if slog in ['false', 'False', 'FALSE']:
        logging.info('Using param: %s', param)
        parvals.append(data)
        parnames.append(param)
    elif slog in ['true', 'True', 'TRUE']:
        logging.info('Using log param: %s', param)
        parvals.append(numpy.log(data))
        parnames.append(f"log({param})")
    else:
        raise ValueError("invalid log param argument, use 'true', or 'false'")

# Per-template smoothed outputs. The loop-based branches below append to
# these lists one template at a time; the vectorised 1D tophat branch
# replaces them with arrays.
nabove_smoothed = []
alpha_smoothed = []
ntotal_smoothed = []
rang = numpy.arange(0, len(nabove))
# Denominator for the progress reports: the index of the last template,
# but never less than one so a single template does not divide by zero
n_progress = max(len(rang) - 1, 1)

# Handle the one-dimensional case of tophat smoothing separately
# as it is easier to optimize computational performance.
if len(parvals) == 1 and args.smoothing_method == 'smooth_tophat':
    logging.info("Using efficient 1D tophat smoothing")
    sort = parvals[0].argsort()
    parvals_0 = parvals[0][sort]

    # For each template (in the original order), find the half-open range
    # [left, right) of positions in the sorted parameter array whose
    # values lie in [value - width, value + width).
    left = numpy.searchsorted(parvals_0, parvals[0] - args.smoothing_width[0])
    right = numpy.searchsorted(parvals_0, parvals[0] + args.smoothing_width[0])

    del parvals_0
    # Precompute prefix sums so we can quickly look up sums over a window.
    # The sums must be over the arrays in the same (sorted) order that
    # left and right refer to, and carry a leading zero so that
    # prefix[right] - prefix[left] includes the first entry of the window.
    nasum = numpy.concatenate(([0], nabove[sort].cumsum()))
    invsum = numpy.concatenate(([0], invalphan[sort].cumsum()))
    ntsum = numpy.concatenate(([0], ntotal[sort].cumsum()))
    # Each window contains at least the template itself for a positive
    # smoothing width, so num is at least one
    num = right - left

    logging.info("Smoothing ...")
    nabove_smoothed = (nasum[right] - nasum[left]) / num
    invmean = (invsum[right] - invsum[left]) / num
    alpha_smoothed = nabove_smoothed / invmean
    ntotal_smoothed = (ntsum[right] - ntsum[left]) / num

elif numpy.isfinite(_smooth_cut[args.smoothing_method]):
    c = _smooth_cut[args.smoothing_method]
    cut_lengths = [s * c for s in args.smoothing_width]
    # Find the "longest" dimension in cut lengths
    sort_dim = numpy.argmax([(v.max() - v.min()) / cut
                             for v, cut in zip(parvals, cut_lengths)])
    logging.info("Sorting / Cutting on dimension %s", parnames[sort_dim])

    # Sort parvals by the sort dimension
    par_sort = numpy.argsort(parvals[sort_dim])
    parvals = [p[par_sort] for p in parvals]
    # Sort the values to be smoothed in the same way, once, rather than
    # re-sorting the full arrays for every template inside the loop
    nabove_sorted = nabove[par_sort]
    invalphan_sorted = invalphan[par_sort]
    ntotal_sorted = ntotal[par_sort]

    # For each template, find the range of nearby templates which fall within
    # the chosen window.
    lefts = numpy.searchsorted(parvals[sort_dim],
            parvals[sort_dim] - cut_lengths[sort_dim])
    rights = numpy.searchsorted(parvals[sort_dim],
            parvals[sort_dim] + cut_lengths[sort_dim])
    n_removed = len(parvals[0]) - rights + lefts
    logging.info("Cutting between %d and %d templates for each smoothing",
                 n_removed.min(), n_removed.max())
    logging.info("Smoothing ...")
    slices = [slice(l, r) for l, r in zip(lefts, rights)]
    for i in rang:
        report_percentage(i, n_progress)
        slc = slices[i]
        # Distances are only calculated for the templates inside the cut
        d = dist(i, slc, parvals, args.smoothing_width)

        smoothed_tuple = smooth(nabove_sorted[slc],
                                invalphan_sorted[slc],
                                ntotal_sorted[slc],
                                d,
                                args.smoothing_method,
                                **kwarg_dict)
        nabove_smoothed.append(smoothed_tuple[0])
        alpha_smoothed.append(smoothed_tuple[1])
        ntotal_smoothed.append(smoothed_tuple[2])

    # Undo the sorts
    unsort = numpy.argsort(par_sort)
    parvals = [p[unsort] for p in parvals]
    nabove_smoothed = numpy.array(nabove_smoothed)[unsort]
    alpha_smoothed = numpy.array(alpha_smoothed)[unsort]
    ntotal_smoothed = numpy.array(ntotal_smoothed)[unsort]

else:
    # No finite cut for this method: distances to every template are
    # needed for each template in turn
    logging.info("Smoothing ...")
    for i in rang:
        report_percentage(i, n_progress)
        d = dist(i, rang, parvals, args.smoothing_width)
        smoothed_tuple = smooth(nabove, invalphan, ntotal, d,
                                args.smoothing_method, **kwarg_dict)
        nabove_smoothed.append(smoothed_tuple[0])
        alpha_smoothed.append(smoothed_tuple[1])
        ntotal_smoothed.append(smoothed_tuple[2])

logging.info("Writing output")
# Use a context manager so the output file is closed (and its contents
# flushed) even if one of the writes below fails
with HFile(args.output, 'w') as outfile:
    outfile['template_id'] = tid
    outfile['count_above_thresh'] = nabove_smoothed
    outfile['fit_coeff'] = alpha_smoothed
    outfile['count_in_template'] = ntotal_smoothed
    if median_sigma is not None:
        outfile['median_sigma'] = median_sigma

    # Store the parameter values used for smoothing; logged parameters
    # are converted back so the stored values are the parameters themselves
    for param, vals, slog in zip(args.fit_param, parvals, args.log_param):
        if slog in ['false', 'False', 'FALSE']:
            outfile[param] = vals
        elif slog in ['true', 'True', 'TRUE']:
            outfile[param] = numpy.exp(vals)

    if args.output_fits_by_template:
        fbt_group = outfile.create_group('fit_by_template')
        for k, v in fbt_dict.items():
            fbt_group[k] = v

    # Add metadata, some is inherited from template level fit
    for k, v in attr_dict.items():
        outfile.attrs[k] = v
    if not analysis_time == 0:
        outfile.attrs['analysis_time'] = analysis_time

    # Add a magic file attribute so that coinc_findtrigs can parse it
    outfile.attrs['stat'] = attr_dict['ifo'] + '-fit_coeffs'
logging.info('Done!')
