#!/usr/bin/python3.13 -s

# Copyright 2020 Gareth S. Cabourn Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

"""Plot the time evolution of fit parameters of PyCBC Live triggers.
"""

import argparse
import logging
import numpy as np
import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot as plt

from lal import gpstime

import pycbc
from pycbc.io.hdf import HFile


# Command line interface: which combined fits file to read, how to name the
# output plots, which detectors to plot and a few styling choices.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--combined-fits-file", required=True,
                    help="File containing information on the combined trigger "
                         "fits.")
default_plot_format = "{ifo}-TRIGGER-FITS-{type}.png"
# The plots are saved with '{type}' set to 'fit_coeffs' or 'counts'. The same
# wording is used for both the help text and the validation error so that the
# two cannot drift apart.
plot_format_requirement = (
    "must contain '{ifo}' and '{type}' to indicate ifo and 'fit_coeffs' or "
    "'counts' in filename."
)
parser.add_argument("--output-plot-name-format",
                    default=default_plot_format,
                    help="Format to save plots, " + plot_format_requirement +
                         " Default: " + default_plot_format)
parser.add_argument("--ifos", nargs="+",
                    help="List of ifos to plot. If not given, use all "
                         "in the combined file.")
parser.add_argument("--colormap", default="rainbow_r", choices=plt.colormaps(),
                    help="Colormap to use for choosing the colours of the "
                         "duration bin lines. Default=rainbow_r")
# NOTE(review): this flag is parsed, but nothing in the visible script reads
# args.log_colormap - confirm whether it should still affect the bin colours.
parser.add_argument("--log-colormap", action='store_true',
                    help="Use log spacing for choosing colormap values "
                         "based on duration bins.")
parser.add_argument(
    '--log-counts',
    action='store_true',
    help="Plot the trigger rate above threshold on a log scale."
)
args = parser.parse_args()

# Both placeholders are needed, otherwise the per-ifo fit coefficient and
# count plots would be written to the same file name.
if not all(field in args.output_plot_name_format
           for field in ('{ifo}', '{type}')):
    parser.error("--output-plot-name-format " + plot_format_requirement)

pycbc.init_logging(args.verbose)

# Containers for the values read from the combined fits file.
# Every dictionary is keyed by ifo.

# Combined ('mean' and 'conservative') fits, plus the 'live_time' attribute
mean_count, mean_alpha = {}, {}
cons_count, cons_alpha = {}, {}
live_total = {}

# Contents of the 'separate_fits' group
separate_alphas, separate_counts = {}, {}
separate_starts, separate_ends = {}, {}
separate_times = {}

logging.info("Loading Data")
with HFile(args.combined_fits_file, 'r') as cff:
    # Fall back to the ifos recorded in the file if none were requested
    ifos = args.ifos or cff.attrs['ifos']
    bins_edges = cff['bins_edges'][:]
    conservative_percentile = cff.attrs['conservative_percentile']
    # N edges delimit N - 1 bins
    n_bins = len(bins_edges) - 1
    for ifo in ifos:
        logging.info(ifo)
        ifo_group = cff[ifo]
        live_total[ifo] = ifo_group.attrs['live_time']

        # The 'mean' and 'conservative' groups share the same layout
        for fit_name, count_store, alpha_store in (
            ('mean', mean_count, mean_alpha),
            ('conservative', cons_count, cons_alpha),
        ):
            fit_group = ifo_group[fit_name]
            count_store[ifo] = fit_group['counts'][:]
            alpha_store[ifo] = fit_group['fit_coeff'][:]

        separate_data = ifo_group['separate_fits']
        separate_starts[ifo] = separate_data['start_time'][:]
        separate_ends[ifo] = separate_data['end_time'][:]
        separate_times[ifo] = separate_data['live_times'][:]

        # One 'bin_<i>' group per bin; stack them so that the first axis
        # of each array is the bin index
        bin_names = [f'bin_{i}' for i in range(n_bins)]
        separate_alphas[ifo] = np.array(
            [separate_data[name]['fit_coeff'][:] for name in bin_names]
        )
        separate_counts[ifo] = np.array(
            [separate_data[name]['counts'][:] for name in bin_names]
        )

# Lower and upper edge of each bin
bin_starts, bin_ends = bins_edges[:-1], bins_edges[1:]

# Build the x ticks: one per day, placed at midnight. The time of day is
# discarded from each tick, so the ticks may not line up exactly with the
# data points.
min_start = min(separate_starts[ifo].min() for ifo in ifos)
max_end = max(separate_ends[ifo].max() for ifo in ifos)

# Step through the data span one day at a time, starting from the first
# datapoint and stopping once a full day past the last one is reached,
# keeping only the calendar date of each step.
tick_dates = []
t = min_start
while t < max_end + 86400:
    tick_dates.append(gpstime.gps_to_utc(t).date())
    t += 86400

# Labels are the dates; positions are the GPS seconds of those dates
xtix_labels = [day.strftime("%Y-%m-%d") for day in tick_dates]
xtix = [gpstime.utc_to_gps(day).gpsSeconds for day in tick_dates]

logging.info("Plotting fits information")

# For every ifo, make two figures: the fit coefficients and the trigger rate
# above threshold, each plotted against the start times of the separate
# fits with one line per duration bin. Horizontal lines mark the mean and
# conservative values of each bin.

# The colormap and number of bins do not depend on the ifo, look them up once
cmap = plt.get_cmap(args.colormap)
n_duration_bins = len(bin_starts)

for ifo in ifos:
    logging.info(ifo)

    # Some things for the plots
    fig_alpha, ax_alpha = plt.subplots(1, figsize=(12, 7.5))
    fig_count, ax_count = plt.subplots(1, figsize=(12, 7.5))
    alpha_lines = []
    count_lines = []

    # These are shared by all duration bins of this ifo
    starts = separate_starts[ifo]
    l_times = separate_times[ifo]

    for i, (bl, bu) in enumerate(zip(bin_starts, bin_ends)):
        alphas = separate_alphas[ifo][i]
        counts = separate_counts[ifo][i]

        # Replace non-positive and non-finite fit coefficients with
        # infinity, so that they are not plotted rather than plotted as
        # zero. This modifies the loaded array in place.
        valid = np.logical_and(alphas > 0, np.isfinite(alphas))
        alphas[np.logical_not(valid)] = np.inf

        if not valid.any():
            # Nothing to draw for this bin, on either figure
            logging.warning("No valid fit coefficients for %s", ifo)
            continue

        # Zero live time gives a division by zero; silence the numpy
        # warnings, the affected points are dealt with before plotting
        with np.errstate(divide='ignore', invalid='ignore'):
            rate = counts / l_times

        ma = mean_alpha[ifo][i]
        ca = cons_alpha[ifo][i]
        mr = mean_count[ifo][i] / live_total[ifo]
        cr = cons_count[ifo][i] / live_total[ifo]

        # Colours are spaced evenly in bin index along the colormap
        bin_colour = cmap(i / n_duration_bins)
        bin_label = f"duration {bl:.2f}-{bu:.2f}"
        alpha_lines += ax_alpha.plot(starts, alphas, c=bin_colour,
                                     label=bin_label, marker='.',
                                     markersize=10)
        alpha_lines.append(ax_alpha.axhline(ma,
                                            label=f"total fit = {ma:.2f}",
                                            c=bin_colour, linestyle='--'))
        alpha_lab = f"{conservative_percentile:d}th %ile = {ca:.2f}"
        alpha_lines.append(ax_alpha.axhline(ca,
                                            c=bin_colour, linestyle=':',
                                            label=alpha_lab))

        # Invalid (non-positive) rates become NaN so that they are left
        # out of the plot
        rate[rate <= 0] = np.nan

        count_lines += ax_count.plot(starts, rate, c=bin_colour,
                                     label=bin_label, marker='.',
                                     markersize=10)

        # Very small rates would print as 0.000 in fixed-point notation
        rate_fmt = ".3e" if mr < 1e-3 else ".3f"
        mlab = f"mean = {mr:{rate_fmt}}"
        clab = f"{conservative_percentile:d}th %ile = {cr:{rate_fmt}}"

        count_lines.append(ax_count.axhline(mr,
                                            c=bin_colour, linestyle='--',
                                            label=mlab))
        count_lines.append(ax_count.axhline(cr,
                                            c=bin_colour, linestyle=':',
                                            label=clab))

    # Legends sit above the axes so that they do not cover the data
    alpha_labels = [line.get_label() for line in alpha_lines]
    ax_alpha.legend(alpha_lines, alpha_labels, loc='lower center',
                    ncol=5, bbox_to_anchor=(0.5, 1.01))
    ax_alpha.set_ylabel('Fit coefficient')

    count_labels = [line.get_label() for line in count_lines]
    if args.log_counts:
        ax_count.semilogy()
    ax_count.legend(count_lines, count_labels, loc='lower center',
                    ncol=5, bbox_to_anchor=(0.5, 1.01))
    ax_count.set_ylabel('Rate of triggers above fit threshold [s$^{-1}$]')

    for ax in [ax_count, ax_alpha]:
        ax.set_xticks(xtix)
        ax.set_xticklabels(xtix_labels, rotation=90)
        # Add 1/4 day padding either side of the lines
        ax.set_xlim(starts.min() - 21600, starts.max() + 21600)
        ax.grid(zorder=-30)

    fig_count.tight_layout()
    fig_count.savefig(
        args.output_plot_name_format.format(ifo=ifo, type='counts')
    )
    fig_alpha.tight_layout()
    fig_alpha.savefig(
        args.output_plot_name_format.format(ifo=ifo, type='fit_coeffs')
    )

    # pyplot keeps every figure it creates alive until it is closed
    # explicitly; release both now that they are saved so that figures do
    # not accumulate over the ifo loop
    plt.close(fig_count)
    plt.close(fig_alpha)

logging.info("Done")
