#!/usr/bin/python3.13 -s
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot single IFO/coherent/reweighted/null SNR timeseries for a PyGRB run.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import numpy
import matplotlib.pyplot as plt
from matplotlib import rc
import pycbc.version
from pycbc import init_logging
from pycbc.events.coherent import reweightedsnr_cut
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

# Select the Agg matplotlib backend and set the default font size used by
# every figure this script produces
plt.switch_backend('Agg')
rc('font', size=14)

# Script metadata: version string and date are read from pycbc.version
__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_snr_timeseries"


# =============================================================================
# Functions
# =============================================================================
# Load trigger data
def load_data(input_file, ifos, vetoes, rw_snr_threshold=None,
              injections=False, slide_id=None):
    """Read triggers or injections from a file via ppu.load_triggers.

    Parameters
    ----------
    input_file : path to the file to read; a falsy value (e.g. None) means
        there is nothing to load.
    ifos, vetoes : forwarded unchanged to ppu.load_triggers.
    rw_snr_threshold : forwarded to ppu.load_triggers as the keyword of the
        same name.
    injections : only selects which progress message is logged.
    slide_id : forwarded to ppu.load_triggers as the keyword of the same
        name.

    Returns
    -------
    Whatever ppu.load_triggers returns, or None when input_file is falsy.
    """

    # Without a file there is nothing to read
    if not input_file:
        return None

    # Injections are read with the trigger loader for the time being (this
    # will eventually become load_injections), so only the message differs
    logging.info("Loading injections..." if injections
                 else "Loading triggers...")

    return ppu.load_triggers(input_file, ifos, vetoes,
                             rw_snr_threshold=rw_snr_threshold,
                             slide_id=slide_id)


# Find start and end times of trigger/injection data relative to a given time
def get_start_end_times(data_time, central_time):
    """Return the padded time span of data_time, measured from central_time.

    The smallest and largest entries of data_time are truncated to integers
    and expressed relative to central_time; the resulting interval is then
    widened on each side by 5% of its length.

    Returns
    -------
    (start, end) : tuple with the padded lower and upper bounds.
    """

    # Integer-truncated extremes of the data, relative to the central time
    earliest = int(min(data_time)) - central_time
    latest = int(max(data_time)) - central_time

    # Margin added on both sides: 5% of the unpadded span
    margin = (latest - earliest)*0.05

    return earliest - margin, latest + margin


# Reset times so that t=0 corresponds to the given trigger time
def reset_times(data_time, trig_time):
    """Shift the entries of data_time so that trig_time becomes t=0.

    Returns a new list holding each entry of data_time minus trig_time, in
    the original order; the input is not modified.
    """

    return [stamp - trig_time for stamp in data_time]


# =============================================================================
# Main script starts here
# =============================================================================
# Build the command line interface: the shared PyGRB plotting options plus
# the options specific to this script
parser = ppu.pygrb_initialize_plot_parser(description=__doc__)
parser.add_argument("-t", "--trig-file", action="store",
                    default=None, required=True,
                    help="The location of the trigger file")
parser.add_argument("--found-missed-file",
                    help="The hdf injection results file", required=False)
parser.add_argument("--trigger-time", type=float, default=0,
                    help="External GPS time.  Used to center the plot.")
parser.add_argument("-y", "--y-variable", default=None,
                    choices=['coherent', 'single', 'reweighted', 'null'],
                    help="Quantity to plot on the vertical axis.")
ppu.pygrb_add_bestnr_cut_opt(parser)
ppu.pygrb_add_slide_opts(parser)
opts = parser.parse_args()
ppu.slide_opts_helper(opts)

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
trig_file = os.path.abspath(opts.trig_file)
inj_file = \
    os.path.abspath(opts.found_missed_file) if opts.found_missed_file else None
snr_type = opts.y_variable
ifo = opts.ifo
# --y-variable has no default, and the HDF key for the vertical axis is built
# from it below: fail here with a usage message rather than with a TypeError
# after the data have been loaded
if snr_type is None:
    err_msg = "Please specify the quantity to plot with --y-variable"
    parser.error(err_msg)
if snr_type == 'single' and ifo is None:
    err_msg = "Please specify an interferometer for a single IFO plot"
    parser.error(err_msg)

logging.info("Imported and ready to go.")

# Set output directories (exist_ok avoids failing if the directory appears
# between a check and its creation)
outdir = os.path.split(os.path.abspath(opts.output_file))[0]
os.makedirs(outdir, exist_ok=True)

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(trig_file, opts.veto_files,
                                           opts.veto_category)

# Load trigger and injections data: when plotting reweighted SNR, keep all
# points to show the impact of the cut, otherwise remove points with
# reweighted SNR below threshold
if snr_type == 'reweighted':
    trig_data = load_data(trig_file, ifos, vetoes,
                          slide_id=opts.slide_id)
    trig_data['network/reweighted_snr'] = \
        reweightedsnr_cut(trig_data['network/reweighted_snr'],
                          opts.newsnr_threshold)
    inj_data = load_data(inj_file, ifos, vetoes, injections=True,
                         slide_id=0)
    # load_data returns None when no injection file was given
    if inj_data is not None:
        inj_data['network/reweighted_snr'] = \
            reweightedsnr_cut(inj_data['network/reweighted_snr'],
                              opts.newsnr_threshold)
else:
    trig_data = load_data(trig_file, ifos, vetoes,
                          rw_snr_threshold=opts.newsnr_threshold,
                          slide_id=opts.slide_id)
    inj_data = load_data(inj_file, ifos, vetoes,
                         rw_snr_threshold=opts.newsnr_threshold,
                         injections=True, slide_id=0)

# Specify HDF file keys for x quantity (time) and y quantity (SNR)
if snr_type == 'single':
    x_key = opts.ifo + '/end_time'
    y_key = opts.ifo + '/snr'
else:
    x_key = 'network/end_time_gc'
    y_key = 'network/' + snr_type + '_snr'

# Obtain times
trig_data_time = trig_data[x_key][:]
inj_data_time = inj_data[x_key][:] if inj_file else None

# Obtain SNRs
trig_data_snr = trig_data[y_key][:]
inj_data_snr = inj_data[y_key][:] if inj_file else None

# Determine the central time (t=0) for the plot
central_time = opts.trigger_time

# Determine trigger data start and end times relative to the central time
start, end = get_start_end_times(trig_data_time, central_time)

# Reset trigger and injection times
trig_data_time = reset_times(trig_data_time, central_time)
if inj_file:
    inj_data_time = reset_times(inj_data_time, central_time)

# Generate plots
logging.info("Plotting...")

# Determine what goes on the vertical axis
y_labels = {'coherent': "Coherent SNR",
            'single': "%s SNR" % ifo,
            'null': "Null SNR",
            'reweighted': "Reweighted SNR"}
y_label = y_labels[snr_type]

# Determine title and caption
if opts.plot_title is None:
    opts.plot_title = y_label + " vs Time"
if opts.plot_caption is None:
    opts.plot_caption = ("Blue crosses: background triggers.  ")
    if inj_file:
        opts.plot_caption += ("Red crosses: injections triggers.")

# SNR versus time plot for the selected SNR type
# NOTE(review): xlims is not referenced again in this script; the plotter is
# only handed opts, so it presumably applies opts.x_lims itself -- confirm
# and either pass xlims on or drop this computation
xlims = [start, end]
if opts.x_lims:
    xlims = opts.x_lims
    xlims = map(float, xlims.split(','))
plu.pygrb_plotter([trig_data_time, trig_data_snr],
                  [inj_data_time, inj_data_snr],
                  "Time since %.3f (s)" % (central_time), y_label,
                  opts, cmd=' '.join(sys.argv))
