#!/usr/bin/python3.13 -s

# Copyright 2023 Jam Sadiq, Praveen Kumar
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

import numpy, operator, argparse, logging
from pycbc import init_logging, add_common_pycbc_options
import pycbc.conversions as convert
from pycbc import libutils
from pycbc.events import triggers
from pycbc.io import HFile
akde = libutils.import_optional('awkde')
kf = libutils.import_optional('sklearn.model_selection')

# Command line interface. Exactly one of the --make-*-kde modes is expected;
# consistency between options is checked after parsing.
parser = argparse.ArgumentParser(description=__doc__)
add_common_pycbc_options(parser)
parser.add_argument('--signal-file', help='File with parameters of GW signals '
                    'for KDE calculation')
parser.add_argument('--template-file', required=True, help='Hdf5 file with '
                    'template masses and spins')
parser.add_argument('--injection-file', help='Hdf5 file with masses and spins')
# Adjacent string literals are concatenated as-is, so the first fragment of
# each multi-line help text must end with a space
parser.add_argument('--min-mass', type=float, default=None,
                    help='Used only on signal masses: remove all '
                         'signal events with mass2 < min_mass')
parser.add_argument('--min-snr', type=float, default=None,
                    help='Used only in injections case: remove all '
                         'injection events < min_snr')
parser.add_argument('--nfold-signal', type=int,
                    help='Number of k-folds for signal KDE cross validation')
parser.add_argument('--nfold-template', type=int,
                    help='Number of k-folds for template KDE cross validation')
parser.add_argument('--nfold-injection', type=int,
                    help='Number of k-folds for injection KDE cross validation')
parser.add_argument('--fit-param', nargs='+', required=True,
                    help='Parameters over which KDE is calculated')
# Kept as strings (not booleans): one 'True'/'False' entry per --fit-param
parser.add_argument('--log-param', nargs='+', choices=['True', 'False'],
                    required=True,
                    help="One value per --fit-param: 'True' to use the log of "
                         "the parameter in the KDE, 'False' to use the "
                         "parameter itself")
parser.add_argument('--output-file', required=True, help='Name of .hdf output')
parser.add_argument('--make-signal-kde', action='store_true',
                    help='Calculate the KDE from the events in --signal-file')
parser.add_argument('--make-template-kde', action='store_true',
                    help='Calculate the KDE from the templates in '
                         '--template-file')
parser.add_argument('--make-injection-kde', action='store_true',
                    help='Calculate the KDE from the events in '
                         '--injection-file')
parser.add_argument('--fom-plot', help='Make a FOM plot for cross-validation'
                    ' and save it as this file')
parser.add_argument('--alpha-grid', type=float, nargs="+",
                    help='Grid of choices of sensitivity parameter alpha for'
                         ' local bandwidth')
parser.add_argument('--bw-grid', type=float, nargs='+',
                    help='Grid of choices of global bandwidth')
parser.add_argument('--extra-cpt-fraction', type=float,
                    help='Fraction of the extra component in the signal density')
parser.add_argument('--temp-volume', type=float,
                    help='Volume covered by the template bank')
parser.add_argument('--seed', type=int,
                    help='Random number generator seed')
parser.add_argument('--mchirp-downsample-power', type=float,
                    help='Exponent value for the power law distribution')
parser.add_argument('--min-ratio', type=float,
                    help='Minimum ratio for template_kde relative to the maximum')
args = parser.parse_args()
init_logging(args.verbose)


# Every fitted parameter needs its own 'True'/'False' log flag. Reported via
# parser.error (not assert) so the check survives `python -O` and the user
# gets a usage message instead of a bare AssertionError.
if len(args.fit_param) != len(args.log_param):
    parser.error("--fit-param and --log-param must be given the same number "
                 "of values")

# store_true flags are booleans, which sum as 0/1: count the selected modes
if args.make_signal_kde + args.make_template_kde + args.make_injection_kde != 1:
    parser.error("Choose exactly one option out of --make-signal-kde, "
                 "--make-template-kde, or --make-injection-kde")

# The extra density component needs both its fraction and the bank volume
if (args.extra_cpt_fraction is None) != (args.temp_volume is None):
    parser.error("Both --extra-cpt-fraction and --temp-volume arguments "
                 "must be provided or neither one should be provided")

# Options that are optional for argparse but needed by the selected mode.
# Checking here fails before the output file is created rather than part-way
# through the calculation.
if args.make_signal_kde:
    needed = {'--signal-file': args.signal_file,
              '--nfold-signal': args.nfold_signal}
elif args.make_template_kde:
    needed = {'--nfold-template': args.nfold_template}
else:
    needed = {'--injection-file': args.injection_file,
              '--nfold-injection': args.nfold_injection}
# Both grids are iterated over by the cross-validation in every mode
needed.update({'--bw-grid': args.bw_grid, '--alpha-grid': args.alpha_grid})
missing = [opt for opt, val in needed.items() if val is None]
if missing:
    parser.error("Missing option(s) required for the selected KDE: "
                 + ", ".join(missing))


def kde_awkde(x, x_grid, alp=0.5, gl_bandwidth=None, ret_kde=False):
    """
    Fit an awkde GaussianKDE to a set of samples and evaluate it on a grid

    Parameters
    ----------
    x:
        Samples the KDE is fitted to
    x_grid:
        Points at which the fitted KDE is evaluated
    alp: float
        Sensitivity parameter alpha passed to the KDE
    gl_bandwidth: float, optional
        Global bandwidth passed to the KDE; if None, the argument is omitted
        and the awkde default is used
    ret_kde: bool
        If true, also return the fitted KDE object

    Returns
    -------
    y, or (kde, y) if ret_kde is true
        KDE values at x_grid, optionally preceded by the fitted KDE object
    """
    kde_kwargs = {'alpha': alp, 'diag_cov': True}
    # Only override the global bandwidth when one was explicitly requested
    if gl_bandwidth is not None:
        kde_kwargs['glob_bw'] = gl_bandwidth
    kde = akde.GaussianKDE(**kde_kwargs)

    kde.fit(x)
    y = kde.predict(x_grid)

    if ret_kde:
        return kde, y
    return y


def kfcv_awkde(sample, bwchoice, alphachoice, k=2):
    """
    K-fold cross validated log likelihood of an awKDE for one choice of
    global bandwidth (bwchoice) and sensitivity parameter (alphachoice).

    The samples are shuffled and split into k folds. For each split the KDE
    is fitted to the training part and evaluated at the held-out part; the
    returned figure of merit is the sum of the log KDE values over all
    held-out points of all k splits.
    """
    splitter = kf.KFold(n_splits=k, shuffle=True, random_state=None)
    # One entry per fold: summed log density at that fold's held-out samples
    fold_scores = [
        numpy.sum(numpy.log(kde_awkde(sample[fit_idx], sample[heldout_idx],
                                      alp=alphachoice,
                                      gl_bandwidth=bwchoice)))
        for fit_idx, heldout_idx in splitter.split(sample)
    ]

    # Total over the k folds
    return numpy.sum(fold_scores)


def optimizedparam(sampleval, bwgrid, alphagrid, nfold=2):
    """
    Grid search for the (global bandwidth, alpha) pair with the largest
    cross validated figure of merit.

    Every combination of bwgrid and alphagrid is scored with kfcv_awkde using
    nfold folds. If args.fom_plot is set, the figure of merit is also plotted
    against alpha (one line per bandwidth) and saved to that file.

    Returns the best bandwidth and the best alpha, in that order.
    """
    npoints, ndim = sampleval.shape
    scores = {
        (bw, al): kfcv_awkde(sampleval, bw, al, k=nfold)
        for bw in bwgrid
        for al in alphagrid
    }
    # First grid point (in iteration order) reaching the maximum score
    best_key = max(scores, key=scores.get)
    optbw, optalpha = best_key
    best_score = scores[best_key]

    if not args.fom_plot:
        return optbw, optalpha

    # Plotting FOM parameters
    import matplotlib.pyplot as plt
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111)
    for bw in bwgrid:
        curve = [scores[(bw, al)] for al in alphagrid]
        ax.plot(alphagrid, curve, label='{0:.3f}'.format(bw))
    best_label = r'$\alpha={0:.3f},bw={1:.3f}$'.format(optalpha, optbw)
    ax.plot(optalpha, best_score, 'ko', linewidth=10, label=best_label)
    ax.set_xlabel(r'$\alpha$', fontsize=15)
    ax.set_ylabel(r'$FOM$', fontsize=15)
    # Guess at a suitable range of FOM values to plot
    y_low = best_score - 0.5 * npoints
    y_high = best_score + 0.2 * npoints
    ax.set_ylim(y_low, y_high)
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.135), ncol=8)
    plt.savefig(args.fom_plot)
    plt.close()

    return optbw, optalpha


# Obtaining template parameters
# The context manager closes the template file once the arrays are read
with HFile(args.template_file, 'r') as temp_file:
    mass1 = temp_file['mass1'][:]
    tid = numpy.arange(len(mass1))  # Array of template ids
    mass_spin = triggers.get_mass_spin(temp_file, tid)

# Output file stays open for the rest of the script, which keeps adding
# datasets and attributes to it
f_dest = HFile(args.output_file, 'w')
f_dest.create_dataset("template_id", data=tid)
template_pars = []
for param, slog in zip(args.fit_param, args.log_param):
    pvals = triggers.get_param(param, args, *mass_spin)
    # Write the KDE param values to output file
    f_dest.create_dataset(param, data=pvals)
    # log_param entries are the strings 'True' / 'False', not booleans
    if slog == 'False':
        logging.info('Using param: %s', param)
        template_pars.append(pvals)
    elif slog == 'True':
        logging.info('Using log param: %s', param)
        template_pars.append(numpy.log(pvals))
    else:
        raise ValueError("invalid log param argument, use 'True', or 'False'")

# Copy standard data to output file
f_dest.attrs['fit_param'] = args.fit_param
f_dest.attrs['log_param'] = args.log_param
with HFile(args.template_file, "r") as f_src:
    f_src.copy(f_src["./"], f_dest["./"], "input_template_params")
# One row per template, one column per fit parameter
temp_samples = numpy.vstack((template_pars)).T


if args.make_template_kde:
    # Unless downsampling is requested the KDE is trained on every template
    kde_train_samples = temp_samples
    probabilities = None

    # Rejection sampling to reduce computational load
    if args.mchirp_downsample_power is not None:
        logging.info('Downsampling with mchirp power '
                     f'{args.mchirp_downsample_power}')
        f_dest.attrs['mchirp_downsample_power'] = args.mchirp_downsample_power
        try:
            mchirp_index = args.fit_param.index('mchirp')
        except ValueError as err:
            raise ValueError("mchirp does not exist in args.fit_param") from err

        mc_vals = template_pars[mchirp_index]
        # log_param holds the strings 'True'/'False'; a plain truthiness test
        # would treat 'False' as true and exponentiate values that were
        # never logged
        if args.log_param[mchirp_index] == 'True':
            mc_vals = numpy.exp(mc_vals)
        power_vals = mc_vals ** args.mchirp_downsample_power
        # Keep probability of each template, scaled so the maximum is 1
        probabilities = power_vals / numpy.max(power_vals)
        if args.seed is not None:
            f_dest.attrs['seed'] = args.seed
            numpy.random.seed(args.seed)
        rand_nums = numpy.random.uniform(0, 1, len(mass1))
        ind = rand_nums < probabilities
        logging.info(f'{ind.sum()} templates after downsampling')
        kde_train_samples = temp_samples[ind]
        f_dest.create_dataset("kde_train_samples", data=kde_train_samples)

    logging.info('Starting optimization of template KDE parameters')
    optbw, optalpha = optimizedparam(kde_train_samples,
                                     alphagrid=args.alpha_grid,
                                     bwgrid=args.bw_grid,
                                     nfold=args.nfold_template)
    logging.info('Bandwidth %.4f, alpha %.2f', optbw, optalpha)
    logging.info('Evaluating template KDE')
    # The KDE is always evaluated at all templates, not only the kept ones
    template_kde = kde_awkde(kde_train_samples, temp_samples, alp=optalpha,
                             gl_bandwidth=optbw)

    if probabilities is not None:
        # Compensation factor for downsampling of templates
        template_kde *= 1. / probabilities
        if args.min_ratio is not None:
            logging.info(f'Applying minimum template KDE ratio {args.min_ratio}')
            f_dest.attrs['min-kde-ratio'] = args.min_ratio
            min_val = args.min_ratio * numpy.max(template_kde)
            template_kde = numpy.maximum(template_kde, min_val)

    f_dest.create_dataset("data_kde", data=template_kde)
    f_dest.attrs['stat'] = "template-kde_file"
# Recorded for every mode, since all KDEs are evaluated at the templates
f_dest.attrs['template-file'] = args.template_file


def signal_kde_extra_cpt(signal_kde, frac=None, volume=None):
    """
    Mix a constant-density component into a signal KDE.

    The extra component has density 1 / volume and carries weight frac; the
    input signal_kde keeps the remaining weight 1 - frac.

    Returns the weighted sum of the two components.
    """
    # Density of the additional, constant component
    uniform_density = 1 / volume
    # Signal density down-weighted to make room for the extra component
    weighted_signal = (1 - frac) * signal_kde
    return weighted_signal + frac * uniform_density


# Obtaining signal parameters
if args.make_signal_kde:
    signal_pars = []
    # Comma separated text file with a header row naming the columns
    signal_file = numpy.genfromtxt(args.signal_file, dtype=float,
                                   delimiter=',', names=True)
    f_dest.attrs['signal-file'] = args.signal_file
    mass2_sgnl = signal_file['mass2']
    N_original = len(mass2_sgnl)
    # Compare against None so the cut is applied whenever the option is given
    if args.min_mass is not None:
        idx = mass2_sgnl > args.min_mass
        mass2_sgnl = mass2_sgnl[idx]
        logging.info('%i triggers out of %i with MASS2 > %s',
                     len(mass2_sgnl), N_original, str(args.min_mass))
    else:
        idx = numpy.full(N_original, True)
    mass1_sgnl = signal_file['mass1'][idx]
    # Explicit check instead of assert, which is stripped under `python -O`
    if not numpy.all(mass1_sgnl - mass2_sgnl > 0):
        raise ValueError("Signal file must have mass1 > mass2 for every "
                         "event used")

    for param, slog in zip(args.fit_param, args.log_param):
        # Each fit parameter is read directly from a column of the file
        pvals = signal_file[param][idx]
        # log_param entries are the strings 'True' / 'False', not booleans
        if slog == 'False':
            logging.info('Using param: %s', param)
            signal_pars.append(pvals)
        elif slog == 'True':
            logging.info('Using log param: %s', param)
            signal_pars.append(numpy.log(pvals))
        else:
            raise ValueError("invalid log param argument, use 'True', or "
                             "'False'")

    # One row per signal event, one column per fit parameter
    signal_samples = numpy.vstack((signal_pars)).T
    f_dest.create_dataset("kde_train_samples", data=signal_samples)
    logging.info('Starting optimization of signal KDE parameters')
    optbw, optalpha = optimizedparam(signal_samples, bwgrid=args.bw_grid,
                                     alphagrid=args.alpha_grid,
                                     nfold=args.nfold_signal)
    logging.info('Bandwidth %.4f, alpha %.2f', optbw, optalpha)
    logging.info('Evaluating signal KDE')
    # Fitted to the signal events, evaluated at the template points
    signal_kde = kde_awkde(signal_samples, temp_samples,
                           alp=optalpha, gl_bandwidth=optbw)
    if args.extra_cpt_fraction is not None and args.temp_volume is not None:
        f_dest.attrs.update({'extra-fraction-used': args.extra_cpt_fraction,
                             'volume': args.temp_volume})
        modified_kde = signal_kde_extra_cpt(signal_kde,
                                            frac=args.extra_cpt_fraction,
                                            volume=args.temp_volume)
    else:
        modified_kde = signal_kde
    f_dest.create_dataset("data_kde", data=modified_kde)
    f_dest.attrs['stat'] = "signal-kde_file"


if args.make_injection_kde:
    inj_pars = []
    # Read everything needed inside the context manager so the injection
    # file is closed again before the KDE calculation starts
    with HFile(args.injection_file, 'r') as inj_file:
        f_dest.attrs['injection-file'] = args.injection_file
        events = inj_file["events"]
        snr = events["snr_net"][:]
        N_original = len(snr)
        # Compare against None so the cut is applied whenever the option is
        # given
        if args.min_snr is not None:
            idx = snr > args.min_snr
            snr = snr[idx]
            logging.info('%i triggers out of %i with SNR > %s',
                         len(snr), N_original, str(args.min_snr))
        else:
            idx = numpy.full(N_original, True)
        mass1 = events["mass1_detector"][idx]
        mass2 = events["mass2_detector"][idx]
        spin1z = events["spin1z"][idx]
        spin2z = events["spin2z"][idx]

    for param, slog in zip(args.fit_param, args.log_param):
        pvals = triggers.get_param(param, args, mass1, mass2, spin1z, spin2z)
        # log_param entries are the strings 'True' / 'False', not booleans
        if slog == 'False':
            logging.info('Using param: %s', param)
            inj_pars.append(pvals)
        elif slog == 'True':
            logging.info('Using log param: %s', param)
            inj_pars.append(numpy.log(pvals))
        else:
            raise ValueError("invalid log param argument, use 'True', or 'False'")

    # One row per injection, one column per fit parameter
    inj_samples = numpy.vstack((inj_pars)).T
    f_dest.create_dataset("kde_train_samples", data=inj_samples)
    logging.info('Starting optimization of injection KDE parameters')
    optbw, optalpha = optimizedparam(inj_samples, alphagrid=args.alpha_grid,
                                     bwgrid=args.bw_grid,
                                     nfold=args.nfold_injection)
    logging.info('Bandwidth %.4f, alpha %.2f', optbw, optalpha)
    logging.info('Evaluating injection KDE')
    # Fitted to the injections, evaluated at the template points
    injection_kde = kde_awkde(inj_samples, temp_samples, alp=optalpha,
                              gl_bandwidth=optbw)
    if args.extra_cpt_fraction is not None and args.temp_volume is not None:
        f_dest.attrs.update({'extra-fraction-used': args.extra_cpt_fraction,
                             'volume': args.temp_volume})
        modified_kde = signal_kde_extra_cpt(injection_kde,
                                            frac=args.extra_cpt_fraction,
                                            volume=args.temp_volume)
    else:
        modified_kde = injection_kde
    f_dest.create_dataset("data_kde", data=modified_kde)
    # NOTE(review): injection mode writes the same 'stat' value as signal
    # mode - presumably intentional, confirm against the consumers of this
    # file
    f_dest.attrs['stat'] = "signal-kde_file"

# Record the optimized KDE parameters of whichever mode was run, then
# finish writing the output file
f_dest.attrs.update({'alpha': optalpha, 'bandwidth': optbw})
f_dest.close()

logging.info('Done!')
