#!/usr/bin/python3.13 -s

# Copyright (C) 2023 Veronica Villa-Ortega
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

"""
Plot histograms of triggers around gated times.
"""

import numpy
import logging
import argparse
from matplotlib import use
use('Agg')
import matplotlib.pyplot as plt
from pycbc.io.hdf import HFile

from pycbc import add_common_pycbc_options, init_logging
from pycbc.events import ranking

# Command-line interface for the script.
parser = argparse.ArgumentParser(description=__doc__)
add_common_pycbc_options(parser)
# required=True: the main loop iterates over this list, so leaving it out
# would otherwise only surface later as a TypeError on None.
parser.add_argument('--single-trigger-files', nargs='+', required=True,
                    help='HDF format single detector merged trigger file(s) path. '
                         'All triggers are combined in the histogram')
parser.add_argument('--ifo-tag', default='',
                    help='String to label plots, eg H1, H1L1')
# 'detchar' is translated to the 'file' gating group name further down.
parser.add_argument('--gating-type', choices=['auto', 'detchar'],
                    help='Type of gate to analyze.')
parser.add_argument('--window', default=5, type=float,
                    help='Time in seconds around the gate to plot (default 5s).')
parser.add_argument('--snr-cut-type', choices=['snr', 'newsnr'],
                    help='Type of snr threshold to apply.')
parser.add_argument('--snr-cut-vals', nargs='+', type=float,
                    help='Threshold snr/newsnr value(s) for histogram(s).')
# Parsed as raw strings here; the GT/LT keyword and the numeric value are
# separated and validated after parsing.
parser.add_argument('--template-duration', nargs='+',
                    help='Triggers with a template duration greater than '
                         '(GT) or less than (LT) some time value in seconds. '
                         'Usage: --template-duration GT/LT time value')
parser.add_argument('--bin-duration', type=float, default=0.25,
                    help='Time in seconds represented by each bin in the histogram '
                         '(default 0.25s).')
parser.add_argument('--log-y', action='store_true',
                    help='Set y-axis scale logarithmic.')
# When omitted, a file name is derived from the other options.
parser.add_argument('--output-file',
                    help='Plot file name (optional).')
args = parser.parse_args()

# NOTE(review): `verbose` is not declared above; presumably it is added by
# add_common_pycbc_options -- confirm.
init_logging(args.verbose)

# Translate the command-line gating type into the name of the gating group
# looked up in the trigger files ('detchar' gates are stored under 'file').
if not args.gating_type:
    raise ValueError('A gating type argument must be provided.')

if args.gating_type == 'detchar':
    gating_type = 'file'
elif args.gating_type == 'auto':
    gating_type = 'auto'

# Work out the list of SNR thresholds, one histogram per threshold.
if not args.snr_cut_type:
    # No SNR cut requested: a single placeholder threshold gives one histogram
    snr_cut_vals = numpy.array([0])
elif args.snr_cut_vals:
    # Ascending order, lowest threshold first
    snr_cut_vals = sorted(args.snr_cut_vals)
else:
    raise ValueError('At least one SNR cut value must be provided!')

# Split --template-duration into its comparison keyword (stored as
# args.duration_ltgt) and its numeric cut in seconds (args.duration_cut).
# Explicit exceptions are raised instead of using assert, which is stripped
# when Python runs with -O and would let an invalid keyword through.
if args.template_duration:
    if len(args.template_duration) < 2:
        raise ValueError('--template-duration requires two values: '
                         'GT or LT followed by a time in seconds.')
    args.duration_ltgt = args.template_duration[0]
    if args.duration_ltgt not in ('LT', 'GT'):
        raise ValueError('--template-duration must start with GT or LT, '
                         'got %r.' % args.duration_ltgt)
    # float() itself raises ValueError for a non-numeric time value
    args.duration_cut = float(args.template_duration[1])

# Binning along trigger time. Despite its name, this value is handed to
# numpy.linspace as the number of bin *edges* spanning [-window, +window].
nbins = int(1 + 2 * args.window / args.bin_duration)

# One pooled list of gate-relative trigger times per SNR threshold, filled
# while looping over the input files.
ifo_trigger_times = [[] for _ in snr_cut_vals]
# Running count of unique gate times found across all files
n_gates = 0

# For every trigger file: read the gate times of the requested type, apply
# the optional template-duration and SNR cuts to the triggers, and collect
# the times of surviving triggers relative to each nearby gate.
for single_file in args.single_trigger_files:
    logging.info("Reading trigger file %s...", single_file.split('/')[-1])
    # The context manager closes the file on every exit path, including the
    # early `continue` below and any exception raised while reading.
    with HFile(single_file, "r") as file_:
        # The first top-level key of the file is used as the detector group
        ifo = str(list(file_.keys())[0])
        data = file_[ifo]

        # Skip files without the requested gates (or with no gating group
        # at all) instead of failing on the lookup.
        if 'gating' not in data or gating_type not in data['gating']:
            logging.info('Trigger file does not contain %s gates. '
                         'Skipping file...', args.gating_type)
            continue

        logging.info("Extracting gated times...")
        gate_time = numpy.unique(data[f'gating/{gating_type}/time'][:])
        n_gates += len(gate_time)
        logging.info("%i %s gates found", len(gate_time), args.gating_type)

        logging.info("Extracting trigger times...")
        times = data['end_time'][:]

        # Boolean mask of triggers passing the template-duration cut; built
        # once and reused for both the times and the SNR values.
        duration_mask = None
        if args.template_duration:
            logging.info("Cutting on template duration...")
            template_duration = data['template_duration'][:]
            if args.duration_ltgt == 'LT':
                duration_mask = template_duration <= args.duration_cut
            else:
                duration_mask = template_duration >= args.duration_cut
            del template_duration
            times = times[duration_mask]

        if args.snr_cut_type:
            logging.info("Cutting on %s...", args.snr_cut_type)
            if args.snr_cut_type == 'snr':
                snr_val = data['snr'][:]
            elif args.snr_cut_type == 'newsnr':
                # Reduced chi-squared: chisq divided by (2 * chisq_dof - 2)
                rchisq = data['chisq'][:] / (2 * data['chisq_dof'][:] - 2)
                if len(rchisq) > 0:
                    snr_val = ranking.newsnr(data['snr'][:], rchisq)
                else:
                    # No triggers: skip the newsnr call entirely
                    snr_val = numpy.array([])
                del rchisq
            if duration_mask is not None:
                # Keep snr_val aligned element-for-element with `times`
                snr_val = snr_val[duration_mask]

        logging.info("Starting loop over gates...")
        for i, thr in enumerate(snr_cut_vals):
            if args.snr_cut_type:
                # The cut is applied cumulatively: each pass trims only the
                # survivors of the previous threshold, which relies on the
                # thresholds being in ascending order.
                keep = snr_val >= thr
                times = times[keep]
                snr_val = snr_val[keep]
            for gt in gate_time:
                # Triggers strictly inside (gt - window, gt + window)
                near_gate = numpy.logical_and(times < gt + args.window,
                                              times > gt - args.window)
                # Store the times relative to the gate time
                ifo_trigger_times[i].extend(times[near_gate] - gt)

def _dash_decimal(value, text):
    """Return a file-name-safe label for a number.

    If `text` contains a decimal point, `value` is formatted with one
    decimal place and the point is replaced by a dash; otherwise `text` is
    returned unchanged.
    """
    if '.' not in text:
        return text
    return "{:.1f}".format(value).replace('.', '-')


# Use the user-supplied plot name as is, or assemble one from the options.
if args.output_file:
    out_name = args.output_file
else:
    name_parts = ['%s_%s' % (str(args.gating_type).upper(), args.ifo_tag)]

    if args.template_duration:
        # The raw command-line string decides whether a decimal is shown
        duration_label = _dash_decimal(args.duration_cut,
                                       args.template_duration[1])
        name_parts.append('TD-%s-%s' % (args.duration_ltgt, duration_label))

    if args.snr_cut_type:
        threshold_labels = [_dash_decimal(cut_val, str(cut_val))
                            for cut_val in snr_cut_vals]
        name_parts.append('%s-%s' % (args.snr_cut_type.upper(),
                                     '_'.join(threshold_labels)))

    name_parts.append(
        'WINDOW-%s' % _dash_decimal(args.window, str(args.window))
    )

    if args.log_y:
        name_parts.append('LOG')

    # Filenames via command line may have any extension; generated ones
    # are always PNG.
    out_name = '_'.join(name_parts) + '.png'

# Draw one histogram per SNR threshold, or a placeholder figure when no
# gates were found, then write the figure to out_name.
if n_gates > 0:
    # All histograms share the same bin edges, so build them once
    bin_edges = numpy.linspace(-args.window, args.window, nbins)
    for trigger_offsets, thr in zip(ifo_trigger_times, snr_cut_vals):
        if args.snr_cut_type:
            legend_text = '%s > %.1f' % (args.snr_cut_type, thr)
        else:
            legend_text = ''
        plt.hist(trigger_offsets, bins=bin_edges, label=[legend_text])

    if args.log_y:
        plt.yscale('log')
    plt.grid(True)
    plt.xlabel('Time relative to gate centre (s)')
    plt.ylabel('Number of triggers')

    # Optional title pieces: the number of input files (only when more than
    # one) and a second line describing the template-duration cut.
    n_files = len(args.single_trigger_files)
    if n_files > 1:
        chunk_text = 'IN %i CHUNKS' % n_files
    else:
        chunk_text = ''
    if args.template_duration:
        comparator = '<' if args.duration_ltgt == 'LT' else '>'
        duration_text = '\n Triggers with template duration %s %s s' % (
            comparator, str(args.duration_cut))
    else:
        duration_text = ''
    plt.title('%s - %i %s GATES %s %s' % (
        args.ifo_tag, n_gates, str(args.gating_type).upper(),
        chunk_text, duration_text))

    if args.snr_cut_type:
        plt.legend()

else:  # Nothing to plot, make an empty figure
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.text(0.5, 0.5, "Sorry, no gates to plot!",
            horizontalalignment='center', verticalalignment='center')

logging.info('Saving histogram...')
plt.savefig(out_name)
plt.close()

logging.info("Done")
