#!/usr/bin/python3.13 -s
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot null statistic or coincident SNR vs coherent SNR for a PyGRB run.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import numpy
from matplotlib import rc
import matplotlib.pyplot as plt
import pycbc.version
from pycbc import init_logging
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

plt.switch_backend('Agg')

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_null_stats"


# =============================================================================
# Functions
# =============================================================================
# Function to load trigger data
def load_data(input_file, ifos, vetoes, opts, injections=False, slide_id=None):
    """Collect the coherent SNR and the chosen y-axis statistic from a file.

    Parameters
    ----------
    input_file : str or None
        Path handed to ``ppu.load_triggers``; a false value skips loading.
    ifos, vetoes
        Forwarded unchanged to ``ppu.load_triggers``.
    opts : argparse.Namespace
        ``y_variable`` and ``newsnr_threshold`` are read, plus ``ifo`` when
        ``y_variable`` is 'single'.
    injections : bool
        Only selects the wording of the log messages: the same loader is
        used for triggers and injections.
    slide_id
        Forwarded to ``ppu.load_triggers``.

    Returns
    -------
    dict
        Keys 'coherent' and ``opts.y_variable``.  Both values are None when
        no input file is given; the second one also stays None when
        ``opts.y_variable`` is not 'null', 'single' or 'coincident'.
    """

    stat_name = opts.y_variable
    data = {'coherent': None, stat_name: None}

    # Nothing to read: hand back the placeholders
    if not input_file:
        return data

    if injections:
        loading_msg = "Loading injections..."
        count_msg = "%d injections found."
    else:
        loading_msg = "Loading triggers..."
        count_msg = "%d triggers found."

    logging.info(loading_msg)
    # Injections go through the trigger loader too; this will eventually
    # become ppu.load_injections()
    loaded = ppu.load_triggers(input_file, ifos, vetoes,
                               rw_snr_threshold=opts.newsnr_threshold,
                               slide_id=slide_id)

    # Coherent SNR is always used
    data['coherent'] = loaded['network/coherent_snr']

    # The other SNR may vary: each entry is evaluated lazily so that only
    # the options relevant to the selected statistic are accessed
    extractors = {
        'null': lambda: loaded['network/%s_snr' % stat_name][:],
        'single': lambda: loaded[opts.ifo + '/snr'][:],
        'coincident': lambda: ppu.get_coinc_snr(loaded),
    }
    if stat_name in extractors:
        data[stat_name] = extractors[stat_name]()

    # Count surviving points
    logging.info(count_msg, len(loaded['network/reweighted_snr']))

    return data


# Function that produces the contours to be plotted
def calculate_contours(opts, new_snrs=None):
    """Generate the null SNR contour to plot.

    Parameters
    ----------
    opts : argparse.Namespace
        ``null_snr_threshold``, ``null_grad_thresh`` and ``null_grad_val``
        are read.
    new_snrs : list, optional
        Accepted for backward compatibility only.  It has never influenced
        the returned values; it is no longer modified in place.

    Returns
    -------
    null_cont : numpy.ndarray
        Contour value at each entry of ``snr_vals``: constant at
        ``null_snr_threshold`` up to ``null_grad_thresh``, then increasing
        linearly with slope ``null_grad_val``.
    snr_vals : numpy.ndarray
        SNR grid: steps of 0.1 from 4 up to (excluding) 30, then unit steps
        from 30 to 499.
    """

    # Get SNR values for contours
    snr_vals = numpy.concatenate((numpy.arange(4, 30, 0.1),
                                  numpy.arange(30, 500, 1)))

    # Determine contour consistently with null_snr in coherent.py
    null_thresh = opts.null_snr_threshold
    null_grad_snr = opts.null_grad_thresh
    null_grad_val = opts.null_grad_val
    null_cont = numpy.where(
        snr_vals > null_grad_snr,
        null_thresh + (snr_vals - null_grad_snr) * null_grad_val,
        null_thresh)

    return null_cont, snr_vals


# =============================================================================
# Main script starts here
# =============================================================================
# Start from the common PyGRB plotting parser and add the options that are
# specific to this script
parser = ppu.pygrb_initialize_plot_parser(description=__doc__)
parser.add_argument("-t", "--trig-file", action="store",
                    default=None, required=True,
                    help="The location of the trigger file")
parser.add_argument("--found-missed-file",
                    help="The hdf injection results file", required=False)
parser.add_argument("-z", "--zoom-in", default=False, action="store_true",
                    help="Output file a zoomed in version of the plot.")
# The y variable is used below as a key into the axis label dictionary, so
# it must be supplied: requiring it here yields a clear usage error instead
# of a KeyError later on
parser.add_argument("-y", "--y-variable", required=True,
                    choices=['coincident', 'null'],  # TODO: overwhitened?
                    help="Quantity to plot on the vertical axis.")
ppu.pygrb_add_null_snr_opts(parser)
ppu.pygrb_add_bestnr_cut_opt(parser)
ppu.pygrb_add_slide_opts(parser)
opts = parser.parse_args()
ppu.slide_opts_helper(opts)

# Configure logging according to the requested verbosity
init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options: use absolute paths; the injection file is optional
trig_file = os.path.abspath(opts.trig_file)
if opts.found_missed_file:
    found_missed_file = os.path.abspath(opts.found_missed_file)
else:
    found_missed_file = None
zoom_in = opts.zoom_in
null_stat_type = opts.y_variable

# Prepare plot title and caption, unless the user provided them
y_labels = {'null': "Null SNR",
            'coincident': "Coincident SNR"}  # TODO: overwhitened
if opts.plot_title is None:
    opts.plot_title = y_labels[null_stat_type] + " vs Coherent SNR"
if opts.plot_caption is None:
    # Collect the sentences that apply to this plot, then glue them together
    caption_parts = ["Blue crosses: background triggers.  "]
    if found_missed_file:
        caption_parts.append("Red crosses: injections triggers.  ")
    if null_stat_type == 'coincident':
        caption_parts.append("Green line: coincident SNR = coherent SNR.")
    else:
        caption_parts.append("Black line: veto line.  ")
        caption_parts.append("Magenta line: above this triggers have ")
        caption_parts.append("reduced detection statistic.")
    opts.plot_caption = "".join(caption_parts)

logging.info("Imported and ready to go.")

# Set output directories
# exist_ok avoids a failure when another job creates the directory between
# an existence check and the creation itself
outdir = os.path.dirname(os.path.abspath(opts.output_file))
os.makedirs(outdir, exist_ok=True)

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(trig_file, opts.veto_files,
                                           opts.veto_category)

# Extract trigger data for the requested slide
trig_data = load_data(trig_file, ifos, vetoes, opts,
                      slide_id=opts.slide_id)

# Extract (or initialize) injection data: when no injection file was given
# load_data returns a dictionary of None values.  Injections are always
# loaded with slide_id=0.
inj_data = load_data(found_missed_file, ifos, vetoes, opts,
                     injections=True, slide_id=0)

# Generate plots
logging.info("Plotting...")

# Contours
snr_vals = None
cont_colors = None
shade_cont_value = None
x_max = None
# Coincident SNR plot case: we want a coinc=coh diagonal line on the plot
if null_stat_type == 'coincident':
    cont_colors = ['g-']
    x_max = plu.axis_max_value(trig_data['coherent'], inj_data['coherent'],
                               found_missed_file)
    snr_vals = [4, x_max]
    null_stat_conts = [[4, x_max]]
# Null SNR case: the contour from calculate_contours and, when zooming in,
# a second contour one unit below it
else:
    cont_colors = ['k-', 'm-']
    null_cont, snr_vals = calculate_contours(opts, new_snrs=None)
    null_stat_conts = [null_cont]
    if zoom_in:
        null_stat_conts.append(numpy.asarray(null_cont) - 1)
    shade_cont_value = 0

# Null SNR or coincident SNR vs Coherent SNR plot
# When zooming in, fall back to default axis limits unless the user set them
if not opts.x_lims and zoom_in:
    opts.x_lims = '6,30'
if not opts.y_lims and zoom_in:
    opts.y_lims = '0,30'
# Set the font size through matplotlib's rc settings
rc('font', size=14)
# Determine y-axis values of triggers and injections
y_label = y_labels[null_stat_type]
trigs = [trig_data['coherent'], trig_data[null_stat_type]]
injs = [inj_data['coherent'], inj_data[null_stat_type]]
# The command line is passed along via the cmd argument
plu.pygrb_plotter(trigs, injs, "Coherent SNR", y_label, opts,
                  snr_vals=snr_vals, conts=null_stat_conts,
                  shade_cont_value=shade_cont_value,
                  colors=cont_colors, vert_spike=True,
                  cmd=' '.join(sys.argv))
