#!/usr/bin/python3.11

# Copyright 2020 Gareth S. Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

import h5py, numpy as np, argparse
import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot as plt
import logging
from lal import gpstime
import pycbc



# Command-line interface. The description and help strings below are the
# text users see from --help, so they must match what the script does.
parser = argparse.ArgumentParser(
    usage="",
    description="Plot the fitting parameters stored in a combined trigger "
                "fits file")
parser.add_argument("--verbose", action="store_true",
                    help="Print extra debugging information")
parser.add_argument("--combined-fits-file", required=True,
                    help="File containing information on the combined trigger "
                         "fits.")
default_plot_format = "{ifo}-TRIGGER-FITS-{type}.png"
parser.add_argument("--output-plot-name-format",
                    default=default_plot_format,
                    help="Format to save plots, must contain '{ifo}' and "
                         "'{type}' to indicate ifo and 'fit_coeffs' or "
                         "'counts' in filename. Default: " +
                         default_plot_format)
parser.add_argument("--ifos", nargs="+",
                    help="List of ifos to plot. If not given, use all "
                         "in the combined file.")
parser.add_argument("--colormap", default="rainbow_r", choices=plt.colormaps(),
                    help="Colormap to use for choosing the colours of the "
                         "duration bin lines. Default=rainbow_r")
parser.add_argument("--log-colormap", action='store_true',
                    help="Use log spacing for choosing colormap values "
                         "based on duration bins.")
args = parser.parse_args()

# Both placeholders are substituted when the plots are saved; without them
# the output files could not be told apart, so reject the format up front
# before any data are read.
if '{ifo}' not in args.output_plot_name_format or \
        '{type}' not in args.output_plot_name_format:
    parser.error("--output-plot-name-format must contain '{ifo}' and "
                 "'{type}' to indicate ifo and 'fit_coeffs' or 'counts' "
                 "in filename.")

pycbc.init_logging(args.verbose)

# Values read from the combined fits file, keyed by ifo. The mean_* and
# cons_* arrays are indexed by duration bin.
live_total = {}
mean_count, mean_alpha = {}, {}
cons_count, cons_alpha = {}, {}

# Values read from each ifo's 'separate_fits' group, keyed by ifo.
separate_starts, separate_ends, separate_times = {}, {}, {}
separate_alphas, separate_counts = {}, {}

logging.info("Loading Data")
with h5py.File(args.combined_fits_file, 'r') as cff:
    # Ifos given on the command line take precedence over those in the file
    ifos = args.ifos or cff.attrs['ifos']
    conservative_percentile = cff.attrs['conservative_percentile']
    bins_edges = cff['bins_edges'][:]
    # N edges delimit N - 1 bins
    n_bins = len(bins_edges) - 1

    for ifo in ifos:
        logging.info(ifo)
        ifo_group = cff[ifo]
        live_total[ifo] = ifo_group.attrs['live_time']

        mean_group = ifo_group['mean']
        mean_count[ifo] = mean_group['counts'][:]
        mean_alpha[ifo] = mean_group['fit_coeff'][:]

        cons_group = ifo_group['conservative']
        cons_count[ifo] = cons_group['counts'][:]
        cons_alpha[ifo] = cons_group['fit_coeff'][:]

        separate_data = ifo_group['separate_fits']
        separate_starts[ifo] = separate_data['start_time'][:]
        separate_ends[ifo] = separate_data['end_time'][:]
        separate_times[ifo] = separate_data['live_times'][:]

        # Stack the per-bin datasets so that the first axis is the bin index
        bin_groups = [separate_data[f'bin_{i}'] for i in range(n_bins)]
        separate_alphas[ifo] = np.array([grp['fit_coeff'][:]
                                         for grp in bin_groups])
        separate_counts[ifo] = np.array([grp['counts'][:]
                                         for grp in bin_groups])

# Lower and upper edge of every bin, and the overall extent of the bins
bin_starts, bin_ends = bins_edges[:-1], bins_edges[1:]
bin_min, bin_max = min(bin_starts), max(bin_ends)

def bin_proportion(upper, lower, log_spacing=False):
    """Return where a bin's centre lies within the full range of the bins.

    The range is taken from the module-level ``bin_min`` and ``bin_max``.

    Parameters
    ----------
    upper : float
        Upper edge of the bin.
    lower : float
        Lower edge of the bin.
    log_spacing : bool, optional
        If True, the centre and the range are computed from the natural
        logarithms of the edges rather than from the edges themselves.
        Default False.

    Returns
    -------
    float
        ``(centre - range_min) / (range_max - range_min)``, which is 0 for
        a centre at ``bin_min`` and 1 for a centre at ``bin_max``.
    """
    if log_spacing:
        # Work in log space throughout: both bin edges and both range limits
        upper, lower = np.log(upper), np.log(lower)
        range_min, range_max = np.log(bin_min), np.log(bin_max)
    else:
        range_min, range_max = bin_min, bin_max

    centre = (lower + upper) / 2.
    return (centre - range_min) / (range_max - range_min)

# Build the x-axis ticks, stepping one day at a time from the earliest start
# to the latest end over all ifos. The time of day is dropped from each
# tick so that it sits at midnight, which means the ticks may not line up
# exactly with the data
min_start = min(separate_starts[ifo].min() for ifo in ifos)
max_end = max(separate_ends[ifo].max() for ifo in ifos)

xtix, xtix_labels = [], []
t = min_start
while t < max_end:
    # Keep only the calendar date; the tick is placed at that date's midnight
    tick_date = gpstime.gps_to_utc(t).date()
    xtix_labels.append(tick_date.strftime("%Y-%m-%d"))
    xtix.append(gpstime.utc_to_gps(tick_date).gpsSeconds)
    t += 86400

logging.info("Plotting fits information")
# The colormap does not depend on the ifo or the bin, so look it up once
cmap = plt.get_cmap(args.colormap)

for ifo in ifos:
    logging.info(ifo)

    # One figure for the fit coefficients and one for the trigger rates
    fig_alpha, ax_alpha = plt.subplots(1, figsize=(12, 7.5))
    fig_count, ax_count = plt.subplots(1, figsize=(12, 7.5))
    # Line handles, in plotting order, used to build the legends below
    alpha_lines = []
    count_lines = []

    for i, (bl, bu) in enumerate(zip(bin_starts, bin_ends)):
        alphas = separate_alphas[ifo][i]
        counts = separate_counts[ifo][i]

        # Replace non-positive and non-finite coefficients with infinity,
        # so that they are not plotted rather than plotted as zero
        valid = np.logical_and(alphas > 0, np.isfinite(alphas))
        alphas[~valid] = np.inf

        if not valid.any():
            logging.warning("No valid fit coefficients for %s", ifo)
            continue

        # Rate for each separate fit: its count divided by its live time
        rate = counts / separate_times[ifo]

        # Values for this bin over the whole live time, drawn as
        # horizontal lines
        ma = mean_alpha[ifo][i]
        ca = cons_alpha[ifo][i]
        mr = mean_count[ifo][i] / live_total[ifo]
        cr = cons_count[ifo][i] / live_total[ifo]

        # Every line belonging to this bin shares a colour, chosen by
        # where the bin sits within the full range of bins
        bin_colour = cmap(bin_proportion(bu, bl,
                                         log_spacing=args.log_colormap))
        bin_label = f"duration {bl:.2f}-{bu:.2f}"

        alpha_lines += ax_alpha.plot(separate_starts[ifo], alphas,
                                     c=bin_colour, label=bin_label)
        alpha_lines.append(ax_alpha.axhline(ma,
                                            label=f"total fit = {ma:.2f}",
                                            c=bin_colour, linestyle='--'))
        alpha_lab = f"{conservative_percentile:d}th %ile = {ca:.2f}"
        alpha_lines.append(ax_alpha.axhline(ca,
                                            c=bin_colour, linestyle=':',
                                            label=alpha_lab))

        count_lines += ax_count.plot(separate_starts[ifo], rate,
                                     c=bin_colour, label=bin_label)
        count_lines.append(ax_count.axhline(mr,
                                            c=bin_colour, linestyle='--',
                                            label=f"mean = {mr:.3f}"))
        count_lab = f"{conservative_percentile:d}th %ile = {cr:.3f}"
        count_lines.append(ax_count.axhline(cr,
                                            c=bin_colour, linestyle=':',
                                            label=count_lab))

    # Legends sit above the axes so that they do not cover the data
    alpha_labels = [line.get_label() for line in alpha_lines]
    ax_alpha.legend(alpha_lines, alpha_labels, loc='lower center',
                    ncol=5, bbox_to_anchor=(0.5, 1.01))
    ax_alpha.set_ylabel('Fit coefficient')

    count_labels = [line.get_label() for line in count_lines]
    ax_count.legend(count_lines, count_labels, loc='lower center',
                    ncol=5, bbox_to_anchor=(0.5, 1.01))
    ax_count.set_ylabel('Rate of triggers above fit threshold [s$^{-1}$]')

    for ax in [ax_count, ax_alpha]:
        ax.set_xticks(xtix)
        ax.set_xticklabels(xtix_labels, rotation=90)
        # Add 1/4 day padding either side
        ax.set_xlim(xtix[0] - 21600, xtix[-1] + 21600)
        ax.grid(zorder=-30)

    fig_count.tight_layout()
    fig_count.savefig(
        args.output_plot_name_format.format(ifo=ifo, type='counts'))
    fig_alpha.tight_layout()
    fig_alpha.savefig(
        args.output_plot_name_format.format(ifo=ifo, type='fit_coeffs'))

    # pyplot keeps every figure it creates until it is closed explicitly;
    # close both now so that figures do not pile up across ifos
    plt.close(fig_count)
    plt.close(fig_alpha)
