#!/usr/bin/python3.13 -s

# Copyright 2019 Gareth S. Cabourn Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.


import sys
import argparse, logging

from matplotlib import use
use('Agg')
from matplotlib import pyplot as plt
import numpy as np

from pycbc import events, bin_utils, results
from pycbc.io import SingleDetTriggers, HFile
from pycbc.events import triggers as trigs
from pycbc.events import trigger_fits as trstats
from pycbc.events import stat as pystat
from pycbc.types.optparse import MultiDetOptionAction
from pycbc.tmpltbank import bank_conversions
import pycbc.version

# Command-line interface: input files, binning / splitting configuration,
# veto options, fit settings and pruning settings.
parser = argparse.ArgumentParser(usage="",
    description="Plot histograms of triggers split over various parameters")
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--trigger-file", required=True,
                    help="Input hdf5 file containing single triggers. "
                         "Required")
parser.add_argument("--bank-file", default=None, required=True,
                    help="hdf file containing template parameters. Required")
parser.add_argument('--output-file', required=True,
                    help="Output image file. Required")
parser.add_argument('--bin-param', default='template_duration',
                    choices=bank_conversions.conversion_options,
                    help="Parameter for binning within plots, default "
                         "'template_duration'")
parser.add_argument('--bin-spacing', choices=['linear', 'log'], default='log',
                    help="How to space bin-param bin edges. "
                         "Choices=[linear, log], default log")
parser.add_argument('--num-bins', type=int, default=6,
                    help="Number of bins over which to split bin-param, default 6")
parser.add_argument('--max-bin-param', type=float, default=None,
                    help="Maximum allowed value of bin-param")
parser.add_argument('--split-param-one', default='eta',
                    choices=bank_conversions.conversion_options,
                    help="Parameter for splitting plot grid in y-direction, "
                         "default 'eta'")
parser.add_argument('--split-param-two', default='chi_eff',
                    choices=bank_conversions.conversion_options,
                    help="Parameter for splitting plot grid in x-direction, "
                         "default 'chi_eff'")
parser.add_argument('--split-one-nbins', default=3, type=int,
                    help="Number of split plots over split-param-one, default 3")
parser.add_argument('--split-two-nbins', default=4, type=int,
                    help="Number of split plots over split-param-two, default 4")
parser.add_argument('--split-one-spacing', choices=['linear', 'log'], default='linear',
                    help="How to space split-param-one bin edges. "
                         "Choices=[linear, log], default=linear")
parser.add_argument('--split-two-spacing', choices=['linear', 'log'], default='linear',
                    help="How to space split-param-two bin edges. "
                         "Choices=[linear, log], default=linear")
# The detector name is used to build every dataset path read from the trigger
# file, so it must always be supplied.
parser.add_argument('--ifo', required=True, help="Detector. Required")
parser.add_argument('--plot-max-x', type=float, default=None,
                    help="Maximum stat value to plot, if not given "
                         "1.05 * largest stat value will be used")
parser.add_argument("--veto-file",
                    help="File(s) in .xml format with veto segments to apply "
                         "to triggers before fitting")
parser.add_argument("--veto-segment-name",
                    help="Name(s) of veto segments to apply. Optional, if not "
                         "given all triggers for the given ifo will be used")
parser.add_argument("--gating-veto-windows", nargs='+',
                    action=MultiDetOptionAction,
                    help="Seconds to be vetoed before and after the central time "
                         "of each gate. Given as detector-values pairs, e.g. "
                         "H1:-1,2.5 L1:-1,2.5 V1:0,0")
parser.add_argument("--stat-fit-threshold", type=float, required=True,
                    help="Only fit triggers with statistic value above this "
                         "threshold. Required")
parser.add_argument("--plot-lower-stat-limit", type=float, required=True,
                    help="Plot triggers down to this value. Setting this too "
                         "low will incur huge memory usage in a full search. "
                         "To avoid this, choose 5.5 or larger.")
# The fit function name is passed to the fitting routines and used in the
# legend and caption, so it must always be supplied.
parser.add_argument("--fit-function", required=True,
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit. "
                         "Required")
parser.add_argument("--prune-number", type=int, default=0,
                    help="Number of loudest events to remove from each split "
                         "histogram, default 0")
parser.add_argument("--prune-window", type=float, default=0.1,
                    help="Time (s) to remove all triggers around a trigger "
                         "which is loudest in each split, default 0.1s")

pystat.insert_statistic_option_group(parser,
    default_ranking_statistic='single_ranking_only')
args = parser.parse_args()

# Validate explicitly rather than with `assert`, which is stripped when Python
# runs with -O.  Written as `not (a >= b)` so that NaN input is rejected too.
if not (args.stat_fit_threshold >= args.plot_lower_stat_limit):
    parser.error("--stat-fit-threshold must be greater than or equal to "
                 "--plot-lower-stat-limit")

pycbc.init_logging(args.verbose)

# Open the input files and evaluate, for every template, each parameter that
# is used for binning or for splitting the plot grid.
logging.info('Opening trigger file: %s', args.trigger_file)
trigf = HFile(args.trigger_file, 'r')

logging.info('Opening template file: %s', args.bank_file)
bank = HFile(args.bank_file, 'r')
logging.info('Getting template bank parameters')

# Calculate params needed
usedparams = [args.split_param_two, args.split_param_one, args.bin_param]

params = {}
for par in usedparams:
    if par in params:
        # The same parameter was requested for more than one role
        continue
    if par in ['template_duration', 'duration']:
        logging.info('Reading duration from trigger file')
        # Loop over template regions, return the first duration value from
        # each region.
        duration_dset = trigf[args.ifo + '/template_duration']
        duration_refs = trigf[args.ifo + '/template_duration_template'][:]
        dur = [duration_dset[ref][0] for ref in duration_refs]
        # Store under the name that was actually requested, so that later
        # look-ups by that name succeed for both spellings.
        params[par] = np.array(dur)
    else:
        logging.info("Calculating %s from template parameters", par)
        params[par] = bank_conversions.get_bank_property(
            par, bank, np.arange(bank['mass1'].size))

bank.close()

def _make_bins(spacing, bin_input, description):
    """Build linear or logarithmic bins from a (min, max, nbins) tuple.

    Parameters
    ----------
    spacing : str
        Either 'log' or 'linear' (any value other than 'log' gives linear).
    bin_input : tuple
        (lower limit, upper limit, number of bins), passed to the bin class.
    description : str
        Name of the parameter being binned, used in the error message.

    Raises
    ------
    ValueError
        If logarithmic spacing is requested and the lower limit is not
        strictly positive.
    """
    if spacing == 'log':
        if not bin_input[0] > 0:
            raise ValueError(
                f"Cannot use logarithmic bins for {description}: "
                f"lower limit {bin_input[0]} is not positive"
            )
        return bin_utils.LogarithmicBins(*bin_input)
    return bin_utils.LinearBins(*bin_input)


logging.info('setting up %s bins', ', '.join(usedparams))
sp_one_bin_input = (params[args.split_param_one].min(),
                    params[args.split_param_one].max(), args.split_one_nbins)
sp_two_bin_input = (params[args.split_param_two].min(),
                    params[args.split_param_two].max(), args.split_two_nbins)
# Compare against None so that an explicitly given value is never ignored
if args.max_bin_param is not None:
    logging.info(
        'setting maximum %s value: %.3f',
        args.bin_param,
        args.max_bin_param
    )
    pbin_upper_lim = float(args.max_bin_param)
else:
    pbin_upper_lim = params[args.bin_param].max()

# For templates with no triggers, the duration will be read as zero
if (args.bin_param in ('template_duration', 'duration')
        and params[args.bin_param].min() == 0):
    # Accessing the 0th entry of an empty region reference will return
    # zero due to a quirk of h5py.
    logging.warning('WARNING: Some templates do not contain triggers')
    # Use the lowest nonzero template duration as lower limit for bins
    pbin_lower_lim = params[args.bin_param][params[args.bin_param] > 0].min()
else:
    pbin_lower_lim = params[args.bin_param].min()

bb_input = (pbin_lower_lim, pbin_upper_lim, args.num_bins)

logging.info('splitting %s into bins', args.bin_param)
pbins = _make_bins(args.bin_spacing, bb_input, args.bin_param)
# Use sentinel value -1 for templates outside range
# NOTE(review): both comparisons are strict, so templates lying exactly on the
# lower or upper limit also receive the sentinel - confirm this is intended.
pind = np.array([pbins[par] if pbin_lower_lim < par < pbin_upper_lim
                 else -1 for par in params[args.bin_param]])

sp_one_bounds = _make_bins(args.split_one_spacing, sp_one_bin_input,
                           args.split_param_one)
sp_two_bounds = _make_bins(args.split_two_spacing, sp_two_bin_input,
                           args.split_param_two)

def _template_ids_per_bin(values, bounds, nbins):
    """List, for each of the nbins bins, the sorted template indices whose
    value satisfies lower < value <= upper for that bin."""
    ids_per_bin = [[] for _ in range(nbins)]
    for which, low_edge, high_edge in zip(range(nbins), bounds.lower(),
                                          bounds.upper()):
        inside = (values > low_edge) & (values <= high_edge)
        ids_per_bin[which] = np.flatnonzero(inside)
    return ids_per_bin


logging.info('assigning template ids to different splits')
# Bins are open at the lower edge and closed at the upper edge
id_in_bin1 = _template_ids_per_bin(params[args.split_param_one],
                                   sp_one_bounds, args.split_one_nbins)
id_in_bin2 = _template_ids_per_bin(params[args.split_param_two],
                                   sp_two_bounds, args.split_two_nbins)

logging.info('Getting template boundaries from trigger file')
boundaries = trigf[args.ifo + '/template_boundaries'][:]

trigf.close()

logging.info('Calculating single stat values from trigger file')
# NOTE: this rebinds `trigs`, which was imported at the top of the file as an
# alias for the pycbc.events.triggers module; from here on it refers to the
# SingleDetTriggers instance.
trigs = SingleDetTriggers(
    args.trigger_file,
    args.ifo,
    filter_rank=args.sngl_ranking,
    filter_threshold=args.plot_lower_stat_limit,
)
# This is the direct pointer to the HDF file, used later on
trigf = trigs.trigs_f

stat = trigs.get_ranking(args.sngl_ranking)
time = trigs.end_time

logging.info('Processing template boundaries')

# In the two blocks of code that follows we are trying to figure out the index
# ranges in the masked trigger lists corresponding to the "boundaries".
# We will do this by walking over the boundaries in the order they're
# stored in the file, adding in the number of triggers not removed by the
# mask every time.

# First step: "boundaries" gives the start position of each block of triggers
# (corresponding to one template) in the full trigger merge file.  The end of
# each block is the start of the block that follows it in the file, and the
# last block ends at the total number of triggers.
#
# A stable sort keeps templates that share a start position in template order,
# so every one of them except the last gets an empty range (end == start) and
# the last one gets the range up to the next distinct start.
# NOTE(review): this assumes that a template without triggers shares its start
# position with the next template that does have triggers - confirm against
# the code which writes template_boundaries.
boundary_order = np.argsort(boundaries, kind='stable')
sorted_boundary_list = boundaries[boundary_order]
n_trigs_total = trigf[args.ifo + '/end_time'].size

sorted_end_list = np.empty_like(sorted_boundary_list)
sorted_end_list[:-1] = sorted_boundary_list[1:]
sorted_end_list[-1:] = n_trigs_total

where_idx_end = np.zeros_like(boundaries)
where_idx_end[boundary_order] = sorted_end_list

# Next we need to map these start/stop indices in the full file, to the start
# stop indices in the masked list of triggers. We do this by figuring out
# how many triggers are in the masked list for each template in the order they
# are stored in the trigger merge file, and keep a running sum.
curr_count = 0
mask_start_idx = np.zeros_like(boundaries)
mask_end_idx = np.zeros_like(boundaries)
for boundary_idx in boundary_order:
    idx_start = boundaries[boundary_idx]
    idx_end = where_idx_end[boundary_idx]
    mask_start_idx[boundary_idx] = curr_count
    curr_count += np.sum(trigs.mask[idx_start:idx_end])
    mask_end_idx[boundary_idx] = curr_count



# Vetoed triggers are not deleted: their statistic and time are overwritten
# with zero, which keeps the per-template index ranges computed earlier valid.
if args.veto_file:
    logging.info('Applying DQ vetoes')
    vetoed_idx, _ = events.veto.indices_within_segments(
        time,
        [args.veto_file],
        ifo=args.ifo,
        segment_name=args.veto_segment_name
    )
    # Set stat to zero for triggers being vetoed: given that the fit threshold
    # is >0 these will not be fitted or plotted.  Avoids complications from
    # changing the number of triggers, ie changes of template boundary.
    stat[vetoed_idx] = 0.
    time[vetoed_idx] = 0.
    logging.info(
        '%d out of %d trigs removed after vetoing with %s from %s',
        vetoed_idx.size,
        stat.size,
        args.veto_segment_name,
        args.veto_file
    )

if args.gating_veto_windows:
    logging.info('Applying veto to triggers near gates')
    # The window for this detector is given as "before,after" in seconds
    window_fields = args.gating_veto_windows[args.ifo].split(',')
    gveto_before = float(window_fields[0])
    gveto_after = float(window_fields[1])
    if gveto_before > 0 or gveto_after < 0:
        raise ValueError("Gating veto window values must be negative before "
                         "gates and positive after gates.")
    # A window of exactly (0, 0) means there is nothing to veto
    if gveto_before != 0 or gveto_after != 0:
        autogate_times = np.unique(trigf[args.ifo + '/gating/auto/time'][:])
        detgate_times = (
            trigf[args.ifo + '/gating/file/time'][:]
            if args.ifo + '/gating/file' in trigf
            else []
        )
        gate_times = np.concatenate((autogate_times, detgate_times))
        window_starts = gate_times + gveto_before
        window_ends = gate_times + gveto_after
        gveto_remove = events.veto.indices_within_times(
            time,
            window_starts,
            window_ends
        )
        stat[gveto_remove] = 0.
        time[gveto_remove] = 0.
        logging.info(
            '%d out of %d trigs removed after vetoing triggers near gates',
            gveto_remove.size,
            stat.size
        )

# Pruning: in every (split-param-one, split-param-two) cell, repeatedly find
# the loudest remaining trigger and zero out all triggers (in any template)
# within --prune-window seconds of it, --prune-number times.
if not args.prune_number:
    logging.info('Not performing any pruning')
else:
    logging.info('Applying pruning around loudest triggers in each split')
    for x in range(args.split_one_nbins):
        id_bin1 = id_in_bin1[x]
        for y in range(args.split_two_nbins):
            id_bin2 = id_in_bin2[y]
            # Finding ids of templates in split
            id_in_both = np.intersect1d(id_bin1, id_bin2)
            if len(id_in_both) == 0:
                continue
            # getting triggers that are in these templates; concatenate
            # returns copies, so these can be modified independently of the
            # full stat / time arrays
            vals_inbin = np.concatenate(
                [stat[mask_start_idx[idx]:mask_end_idx[idx]]
                 for idx in id_in_both]
            )
            time_inbin = np.concatenate(
                [time[mask_start_idx[idx]:mask_end_idx[idx]]
                 for idx in id_in_both]
            )

            logging.info(
                'Pruning in split %s-%i %s-%i',
                args.split_param_one, x,
                args.split_param_two, y
            )
            logging.info('Currently have %d triggers', len(vals_inbin))
            if vals_inbin.size == 0:
                # argmax is undefined on an empty array; nothing to prune
                continue

            for count_pruned in range(args.prune_number):
                # Getting loudest statistic value in split, and its time.
                # The index refers to the in-split arrays, not to the full
                # stat / time arrays.
                max_val_arg = vals_inbin.argmax()
                loudest_time = time_inbin[max_val_arg]

                remove = np.nonzero(
                    abs(loudest_time - time) < args.prune_window
                )[0]
                # Remove from inbin triggers as well in case there
                # are more pruning iterations
                remove_inbin = np.nonzero(
                    abs(loudest_time - time_inbin) < args.prune_window
                )[0]
                logging.info(
                    'Prune %d: removing %d triggers around %.2f, %d in this split',
                    count_pruned,
                    remove.size,
                    loudest_time,
                    remove_inbin.size
                )
                # Set pruned triggers' stat values to zero, as above for vetoes
                vals_inbin[remove_inbin] = 0.
                time_inbin[remove_inbin] = 0.
                stat[remove] = 0.
                time[remove] = 0.

logging.info('Setting up plotting and fitting limit values')
# Triggers with stat == 0 are ones removed by vetoes or pruning
surviving_stat = stat[np.nonzero(stat)]
if surviving_stat.size == 0:
    raise ValueError("No triggers with nonzero statistic remain, "
                     "nothing to plot")
minplot = max(surviving_stat.min(), args.plot_lower_stat_limit)
min_fit = max(minplot, args.stat_fit_threshold)
max_fit = 1.05 * stat.max()
# Compare against None so that an explicitly given value is never ignored
if args.plot_max_x is not None:
    maxplot = args.plot_max_x
else:
    maxplot = max_fit
fitrange = np.linspace(min_fit, max_fit, 100)

logging.info('Setting up plotting variables')
base_histcolors = ['r',(1.0,0.6,0),'y','g','c','b','m','k',(0.8,0.25,0),(0.25,0.8,0)]
# Repeat the palette when there are more bins than distinct colours, so that
# indexing by bin number never runs off the end of the list.
histcolors = [base_histcolors[i % len(base_histcolors)]
              for i in range(max(args.num_bins, len(base_histcolors)))]
fig, axes = plt.subplots(args.split_one_nbins, args.split_two_nbins,
                         sharex=True, sharey=True, squeeze=False,
                         figsize=(3 * (args.split_two_nbins + 1),
                                  3 * args.split_one_nbins))

# Setting up overall legend outside the split-up plots.  Zero-length dummy
# lines are drawn only to obtain legend handles with the right styles.
lines = []
labels = []
for i, lower, upper in zip(range(args.num_bins), pbins.lower(), pbins.upper()):
    binlabel = f"{lower:#.3g} - {upper:#.3g}"
    line, = axes[0,0].plot([0,0], [0,0], linewidth=2,
                     color=histcolors[i], alpha=0.6)
    lines.append(line)
    labels.append(binlabel)

line_fit, = axes[0,0].plot([0,0], [0,0], linestyle='--',
                           color='k', alpha=0.6)
lines.append(line_fit)
labels.append(args.fit_function + ' fit to counts')
fig.legend(lines, labels, labelspacing=0.2,
           loc='upper left', title=args.bin_param)

# Template indices belonging to each bin-param bin
pidx = [np.flatnonzero(pind == i) for i in range(args.num_bins)]

# For every cell of the split grid and every bin-param bin: gather the
# triggers of the matching templates, fit the chosen function above the fit
# threshold, and draw the cumulative histogram together with the fitted curve.
logging.info('Starting bin, histogram and plot loop')
maxyval = 0
for x in range(args.split_one_nbins):
    id_bin1 = id_in_bin1[x]
    for y in range(args.split_two_nbins):
        id_bin2 = id_in_bin2[y]
        id_in_both = np.intersect1d(id_bin1, id_bin2)
        logging.info(
            'Split %s-%i %s-%i',
            args.split_param_one, x,
            args.split_param_two, y,
        )
        ax = axes[x,y]
        for i, lower, upper in zip(range(args.num_bins), pbins.lower(),
                                   pbins.upper()):
            indices_all_conditions = np.intersect1d(pidx[i], id_in_both)
            logging.info('%s split %#.3g-%#.3g', args.bin_param, lower, upper)
            if len(indices_all_conditions) == 0:
                continue
            vals_inbin = np.concatenate(
                [stat[mask_start_idx[idx]:mask_end_idx[idx]]
                 for idx in indices_all_conditions]
            )
            vals_above_thresh = vals_inbin[vals_inbin >= args.stat_fit_threshold]
            if not len(vals_above_thresh):
                logging.info('No triggers above threshold')
                continue
            logging.info('%d triggers out of %d above threshold',
                         len(vals_above_thresh), len(vals_inbin))
            # sig_alpha is returned by the fit but is not currently plotted
            alpha, sig_alpha = trstats.fit_above_thresh(args.fit_function,
                                 vals_above_thresh, args.stat_fit_threshold)
            fitted_cum_counts = len(vals_above_thresh) * \
                                trstats.cum_fit(args.fit_function, fitrange,
                                                alpha, args.stat_fit_threshold)

            # Cumulative (counting from the loud end) histogram of all
            # triggers in this bin
            histcounts, edges = np.histogram(vals_inbin, bins=50)
            cum_counts = histcounts[::-1].cumsum()[::-1]
            # Track the tallest curve over every plotted bin, used later for
            # the y-axis limit.  Must be updated here, where cum_counts is
            # known to belong to a curve that is actually drawn.
            maxyval = max(maxyval, cum_counts.max())

            # Plot the lines!  The modulo guards against having more bins
            # than colours.
            bin_color = histcolors[i % len(histcolors)]
            ax.semilogy(edges[:-1], cum_counts, linewidth=2,
                     color=bin_color, alpha=0.6)
            ax.semilogy(fitrange, fitted_cum_counts, "--", color=bin_color,
                    label=f"$\\alpha = ${alpha:#.3f}")
            ax.legend(fontsize='small', framealpha=0.5)

            del vals_inbin, vals_above_thresh

        ax.grid()

# Mark the fit threshold on every panel with a faint dotted vertical line
y_top = 5 * maxyval
threshold_xs = [args.stat_fit_threshold, args.stat_fit_threshold]
for panel in axes.flat:
    panel.semilogy(threshold_xs, [1, y_top], 'k', linestyle=':', alpha=0.2)
# Axes are shared, so setting the limits on one panel is sufficient
axes[0,0].set_ylim(1, y_top)
axes[0,0].set_xlim(minplot, maxplot)

# Bottom row carries the statistic name; when there is more than one column
# the top row is labelled with the split-param-two range of that column.
n_cols = args.split_two_nbins
bottom_row = args.split_one_nbins - 1
for col in range(n_cols):
    axes[bottom_row, col].set_xlabel(args.sngl_ranking, size="large")
if n_cols != 1:
    col_lows = sp_two_bounds.lower()
    col_highs = sp_two_bounds.upper()
    for col in range(n_cols):
        xrange_string = f'{col_lows[col]:#.3g} to {col_highs[col]:#.3g}'
        top_panel = axes[0, col]
        top_panel.set_xlabel(f'{args.split_param_two}: {xrange_string}',
                             size="large")
        top_panel.xaxis.set_label_position("top")

# Left column carries the split-param-one range of each row, unless there is
# only a single row.
n_rows = args.split_one_nbins
if n_rows == 1:
    axes[0, 0].set_ylabel('Cumulative number', size='large')
else:
    row_lows = sp_one_bounds.lower()
    row_highs = sp_one_bounds.upper()
    for row in range(n_rows):
        yrange_string = f'{row_lows[row]:#.3g} to {row_highs[row]:#.3g}'
        axes[row, 0].set_ylabel(
            f'{args.split_param_one}: {yrange_string}\ncumulative number',
            size="large")

# Leave a strip on the left of the figure free for the overall legend
left_margin = 1./(args.split_two_nbins+1)
fig.tight_layout(rect=(left_margin, 0, 1, 1))

# Write the figure to disk together with its title, caption and the command
# line that produced it.
logging.info('Saving to file %s', args.output_file)
fig_title = (
    "{}: {} histogram of single detector triggers split by"
    " {} and {}"
).format(args.ifo, args.sngl_ranking, args.split_param_one,
         args.split_param_two)
fig_caption = (
    r"Histogram of {} single detector {} values binned by {}, split by "
    "{} and {}, with fitted {} distribution parameterized by"
    " &alpha;"
).format(args.ifo, args.sngl_ranking, args.bin_param,
         args.split_param_one, args.split_param_two,
         args.fit_function)
command_line = " ".join(sys.argv)
results.save_fig_with_metadata(
    fig, args.output_file,
    title=fig_title,
    caption=fig_caption,
    cmd=command_line
)
logging.info('Done!')
