#!/usr/bin/python3.13 -s

# Copyright 2020 Gareth S. Cabourn Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

"""Plot histograms of PyCBC Live triggers split over various parameters, and
the corresponding fits.
"""

import argparse
import logging
import numpy as np
import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot as plt

from lal import gpstime

import pycbc
from pycbc.bin_utils import IrregularBins
from pycbc.events.trigger_fits import cum_fit as eval_cum_fit
from pycbc.io.hdf import HFile


# ---------------------------------------------------------------------------
# Command line interface
# ---------------------------------------------------------------------------
default_plot_format = "{ifo}-TRIGGER-FITS.png"
plot_format_help = (
    "Format to save plots, must contain '{ifo}' to "
    "indicate ifo in filename. Default: " + default_plot_format
)

parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument(
    "--trigger-fits-file",
    required=True,
    help="Trigger fits file to plot",
)
parser.add_argument(
    "--output-plot-name-format",
    default=default_plot_format,
    help=plot_format_help,
)
parser.add_argument(
    "--colormap",
    default="rainbow_r",
    choices=plt.colormaps(),
    help="Colormap to use for choosing the colours of the "
         "duration bin lines. Default=rainbow_r",
)
# NOTE(review): no use of args.log_colormap is visible in this script, so the
# flag currently appears to have no effect - confirm before relying on it.
parser.add_argument(
    "--log-colormap",
    action='store_true',
    help="Use log spacing for choosing colormap values "
         "based on duration bins.",
)
# Optional axis limits: --x-lim-lower, --x-lim-upper, --y-lim-lower and
# --y-lim-upper, registered in that order. Each defaults to None.
for axis in ("x", "y"):
    for article, bound in (("a", "lower"), ("an", "upper")):
        parser.add_argument(
            f"--{axis}-lim-{bound}",
            type=float,
            help=f"Add {article} {bound} limit to the {axis}-axis of the plot",
        )

args = parser.parse_args()

# One plot is written per detector, so the file name pattern has to contain
# the '{ifo}' placeholder; otherwise every plot would share one file name.
if '{ifo}' not in args.output_plot_name_format:
    parser.error(
        "--output-plot-name-format must contain '{ifo}' "
        "to indicate ifo in filename."
    )

pycbc.init_logging(args.verbose)

logging.info("Getting trigger fits file information")

# Per-detector quantities, keyed by ifo name
fit_threshold = {}
fit_function = {}
live_time = {}
alphas = {}
counts = {}
stats = {}
durations = {}

with HFile(args.trigger_fits_file, 'r') as trfit_f:
    # Detectors listed in the file attributes: one plot is made for each
    all_ifos = trfit_f.attrs['ifos'].split(',')
    # Detectors which actually have a group in the file: every top-level
    # key except the 'bins*' datasets
    ifos = [k for k in trfit_f.keys() if not k.startswith('bins')]

    # The x-axis is labelled with the single-detector ranking name for these
    # two ranking statistics, and with a generic label otherwise
    uses_sngl_ranking = trfit_f.attrs['ranking_statistic'] in (
        'quadsum', 'single_ranking_only'
    )
    x_label = (
        trfit_f.attrs['sngl_ranking'] if uses_sngl_ranking
        else "Ranking statistic"
    )

    # Time span covered by the fits, converted to strings for the plot title
    start_time = trfit_f.attrs['fit_start_gps_time']
    end_time = trfit_f.attrs['fit_end_gps_time']
    time_format = "%Y-%m-%d %H:%M:%S"
    start_str = gpstime.gps_to_utc(start_time).strftime(time_format)
    end_str = gpstime.gps_to_utc(end_time).strftime(time_format)

    for ifo in ifos:
        ifo_group = trfit_f[ifo]

        fit_threshold[ifo] = ifo_group.attrs['fit_threshold']
        fit_function[ifo] = ifo_group.attrs['fit_function']
        live_time[ifo] = ifo_group.attrs['live_time']
        alphas[ifo] = ifo_group['fit_coeff'][:]
        counts[ifo] = ifo_group['counts'][:]

        # Triggers for this detector
        # (This is ones which passed the cuts in the fitting code)
        if 'triggers' in ifo_group:
            triggers = ifo_group['triggers']
            stats[ifo] = triggers['stat'][:]
            durations[ifo] = triggers['template_duration'][:]
        else:
            # No trigger group stored: leave empty placeholders so that a
            # length check on them reports zero triggers
            stats[ifo] = {}
            durations[ifo] = {}

    # Upper and lower edges of the template duration bins
    bu = trfit_f['bins_upper'][:]
    bl = trfit_f['bins_lower'][:]

# Bin edges for the template duration bins: every lower edge followed by the
# upper edge of the final bin
duration_bin_edges = list(bl) + [bu[-1]]
tbins = IrregularBins(duration_bin_edges)

# Root logger and its level at this point, so the level can be put back after
# it is temporarily raised while figures are saved
logger = logging.getLogger()
init_level = logger.level


def _save_and_close(fig, filename, **savefig_kwargs):
    """Save ``fig`` to ``filename`` and then close the figure.

    The root logger is raised to WARNING for the duration of the
    ``savefig`` call (presumably to keep messages logged during saving out
    of verbose runs - confirm) and is restored to ``init_level`` afterwards,
    even if saving raises. The figure is always closed so that pyplot does
    not keep it alive for the rest of the run.

    Parameters
    ----------
    fig : matplotlib figure
        Figure to write out.
    filename : str
        Destination passed to ``fig.savefig``.
    **savefig_kwargs
        Extra keyword arguments forwarded to ``fig.savefig``.
    """
    logging.info(f"Saving {filename}")
    logger.setLevel(logging.WARNING)
    try:
        fig.savefig(filename, **savefig_kwargs)
    finally:
        logger.setLevel(init_level)
        plt.close(fig)


logging.info("Plotting fits")

# The colormap does not change between bins or detectors, so look it up once
cmap = plt.get_cmap(args.colormap)

for ifo in all_ifos:
    fig, ax = plt.subplots(1)
    oput_plot = args.output_plot_name_format.format(ifo=ifo)

    if ifo not in ifos or not len(stats[ifo]):
        # Plot a blank plot with a message to show it worked, but there
        # weren't any triggers
        ax.tick_params(labelcolor='none', top=False, bottom=False, left=False,
                       right=False)
        ax.text(
            0.5, 0.5,
            "No triggers above threshold in this detector",
            horizontalalignment='center',
            verticalalignment='center',
        )
        _save_and_close(fig, oput_plot)
        continue

    # Keep track of some maxima for use in setting the plot limits
    maxstat = np.nanmax(stats[ifo])
    max_rate = 0

    # Extend the plotted range 5% beyond the loudest trigger
    statrange = maxstat - max(np.nanmin(stats[ifo]), fit_threshold[ifo])
    plotmax = maxstat + statrange * 0.05

    plotbins = np.linspace(fit_threshold[ifo], plotmax, 400)

    logging.info("Putting events into bins")
    # Index of the duration bin each trigger falls into
    event_bin = np.array([tbins[d] for d in durations[ifo]])

    for bin_num, (lower, upper) in enumerate(
            zip(tbins.lower(), tbins.upper())):
        binlabel = f"{lower:.3g} - {upper:.3g}"

        inbin = event_bin == bin_num
        # NOTE(review): the colour position is always linear in bin index;
        # args.log_colormap is not consulted here - confirm this is intended.
        bin_prop = bin_num / len(tbins)
        bin_colour = cmap(bin_prop)

        # Skip if there are no triggers in this bin in this IFO, or if the
        # stored fit coefficient is the -1 placeholder. Empty lines are still
        # plotted so that the bin appears in the legend.
        if not inbin.any() or alphas[ifo][bin_num] == -1:
            ax.plot(
                [],
                [],
                linewidth=2,
                color=bin_colour,
                label=binlabel,
                alpha=0.6
            )
            ax.plot(
                [],
                [],
                "--",
                color=bin_colour,
                label="No triggers"
            )
            continue
        binned_sngl_stats = stats[ifo][inbin]

        # Histogram the triggers
        histcounts, edges = np.histogram(binned_sngl_stats,
                                         bins=plotbins)
        # Reverse cumulative sum: number of triggers at or above each
        # histogram bin's lower edge, divided by the live time
        cum_rate = histcounts[::-1].cumsum()[::-1] / live_time[ifo]

        max_rate = max(max_rate, cum_rate[0])

        # Evaluate the fitted cumulative function on the same grid and scale
        # it by this bin's stored count per live time
        ecf = eval_cum_fit(
            fit_function[ifo],
            plotbins,
            alphas[ifo][bin_num],
            fit_threshold[ifo]
        )
        cum_fit = counts[ifo][bin_num] / live_time[ifo] * ecf

        ax.plot(edges[:-1], cum_rate, linewidth=2,
                color=bin_colour, label=binlabel, alpha=0.6)
        ax.plot(plotbins, cum_fit, "--", color=bin_colour,
                label=r"$\alpha = $%.2f" % alphas[ifo][bin_num])

    ax.semilogy()
    ax.grid()
    # Command line limits take precedence over the data-driven defaults
    ax.set_xlim(
        fit_threshold[ifo] if args.x_lim_lower is None else args.x_lim_lower,
        plotmax if args.x_lim_upper is None else args.x_lim_upper
    )
    ax.set_ylim(
        0.5 / live_time[ifo] if args.y_lim_lower is None else args.y_lim_lower,
        1.5 * max_rate if args.y_lim_upper is None else args.y_lim_upper
    )
    ax.set_xlabel(x_label)
    ax.set_ylabel("Number of louder triggers per live time")
    title = f"{ifo} singles significance fits from\n{start_str} to {end_str}"
    ax.set_title(title)
    # Legend sits outside the axes on the right; bbox_inches="tight" below
    # makes sure it is included in the saved image
    ax.legend(loc='center left', bbox_to_anchor=(1.01, 0.5))
    _save_and_close(fig, oput_plot, bbox_inches="tight")

logging.info("Done")
