#!/usr/bin/python3.13 -s
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Produces signal consistency plots of the form network power/bank/auto or
single detector chi-square vs single IFO/coherent/reweighted/null/coinc SNR.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import numpy
from matplotlib import pyplot as plt
from matplotlib import rc
import pycbc.version
from pycbc import init_logging
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

# Render with the non-interactive Agg backend, so that no display is needed
plt.switch_backend('Agg')
# Default font size for the text in the figures
rc('font', size=14)

# Script metadata; version and date are read from pycbc.version
__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_chisq_veto"


# =============================================================================
# Functions
# =============================================================================
# Function to load trigger data: includes applying cut in reweighted SNR
def load_data(input_file, ifos, vetoes, opts, injections=False, slide_id=None):
    """Collect the SNR and chi-square samples needed for the plot.

    Parameters
    ----------
    input_file : str or None
        Trigger or injection file to read. When it is empty or None nothing
        is loaded and every entry of the returned dictionary is None.
    ifos, vetoes :
        Forwarded unchanged to `ppu.load_triggers`.
    opts :
        Options object. `snr_type` and `y_variable` are always read;
        `newsnr_threshold` is read unless the network chi-square is plotted;
        `ifo` is read whenever a single-detector quantity is requested.
    injections : bool
        Only selects the wording of the log messages: injection files are
        currently read with the trigger loader as well.
    slide_id :
        Forwarded unchanged to `ppu.load_triggers`.

    Returns
    -------
    dict
        Keyed by `opts.snr_type` (SNR values), `opts.y_variable` (chi-square
        divided by its degrees of freedom, with exact zeros replaced by
        0.005) and 'dof' (the degrees of freedom used in the division).
    """

    x_key = opts.snr_type
    y_key = opts.y_variable

    # Start from an empty result so that the same keys are always present
    data = {x_key: None, y_key: None, 'dof': None}

    # Network power chi-square plots must show all the data, so that the
    # impact of the reweighted SNR cut is visible: no cut is applied then
    if y_key == 'network':
        rw_snr_threshold = 0.
    else:
        rw_snr_threshold = opts.newsnr_threshold

    # Nothing to read: hand back the empty result
    if not input_file:
        return data

    if injections:
        logging.info("Loading injections...")
    else:
        logging.info("Loading triggers...")
    # Injections go through the trigger loader too; this will eventually
    # become a dedicated load_injections call
    loaded = ppu.load_triggers(input_file, ifos, vetoes,
                               rw_snr_threshold=rw_snr_threshold,
                               slide_id=slide_id)

    # Count surviving points
    num_loaded = len(loaded['network/reweighted_snr'])

    # SNR values: single detector, coincident, or one of the network SNRs
    if x_key == 'single':
        data[x_key] = loaded[opts.ifo + '/snr'][:]
    elif x_key == 'coincident':
        data[x_key] = ppu.get_coinc_snr(loaded)
    elif x_key in ('coherent', 'null', 'reweighted'):
        data[x_key] = loaded['network/%s_snr' % x_key][:]

    # Tags to find vetoes in HDF files
    veto_tags = {'power': 'chisq',
                 'bank': 'bank_chisq',
                 'auto': 'auto_chisq',
                 'network': 'my_network_chisq'}

    if y_key == 'network':
        # This chi-square is already normalized
        chisq_key = 'network/' + veto_tags[y_key]
        data['dof'] = 1.
    else:
        # Single-detector chi-squares are stored along with their own dof
        ifo_prefix = opts.ifo + '/'
        chisq_key = ifo_prefix + veto_tags[y_key]
        data['dof'] = loaded[ifo_prefix + veto_tags[y_key] + '_dof'][:]

    # Normalize, then replace exact zeros with 0.005
    normalized = loaded[chisq_key][:]/data['dof']
    numpy.putmask(normalized, normalized == 0, 0.005)
    data[y_key] = normalized

    label = "injections" if injections else "triggers"
    logging.info("{0} {1} found.".format(num_loaded, label))

    return data


# Function to calculate chi-square weight for the reweighted SNR
def new_snr_chisq(snr, new_snr, chisq_index=4.0, chisq_nhigh=3.0):
    """Chi-square value at which `snr` is reweighted down to `new_snr`.

    The ratio (snr/new_snr)**chisq_index is computed first. If it does not
    exceed 1 the fixed value 1E-20 is returned; otherwise the result is
    (2*ratio - 1)**(chisq_nhigh/chisq_index).

    Only scalar inputs are supported, since the ratio is used in a plain
    comparison.
    """

    ratio_power = (snr/new_snr)**chisq_index

    # The exponent is evaluated only in the branch that needs it
    return 1E-20 if ratio_power <= 1 \
        else (2*ratio_power - 1)**(chisq_nhigh/chisq_index)


# Function that produces the contours to be plotted
def calculate_contours(trig_data, opts, new_snrs=None):
    """Generate the contours of constant reweighted SNR for the veto plots.

    Parameters
    ----------
    trig_data :
        Not used by this function; kept so existing calls keep working.
    opts :
        Options object; `newsnr_threshold`, `chisq_index` and `chisq_nhigh`
        are read from it.
    new_snrs : sequence of numbers, optional
        Reweighted SNR values to draw a contour for. A default set is used
        when None. The sequence passed in is copied and never modified.

    Returns
    -------
    contours : numpy.ndarray
        float64 array with one row per contour and one column per entry of
        `snr_vals`, holding the chi-square value returned by
        `new_snr_chisq`.
    snr_vals : numpy.ndarray
        SNR samples: 1 to 30 in steps of 0.1, then 30 to 500 in steps of 1
        (upper ends excluded).
    cont_value : int
        Row of `contours` for `opts.newsnr_threshold`; -1 (the last row)
        when the threshold had to be appended to the list of values.
    colors : list of str
        One line style per contour: "k-" for the threshold, "y-" for
        integer values, "y--" otherwise.
    """

    if new_snrs is None:
        new_snrs = [5.5, 6, 6.5, 7, 8, 9, 10, 11]
    else:
        # Work on a copy: appending the threshold below must not alter the
        # caller's sequence (this also lets tuples be passed in)
        new_snrs = list(new_snrs)

    # Add the new SNR threshold contour to the list if necessary
    # and keep track of where it is
    try:
        cont_value = new_snrs.index(opts.newsnr_threshold)
    except ValueError:
        new_snrs.append(opts.newsnr_threshold)
        cont_value = -1

    # SNR values for contours: fine sampling at low SNR, coarse above 30
    snr_vals = numpy.concatenate((numpy.arange(1, 30, 0.1),
                                  numpy.arange(30, 500, 1)))

    # One row per reweighted SNR value, one column per SNR sample
    contours = numpy.array(
        [[new_snr_chisq(snr, new_snr, opts.chisq_index, opts.chisq_nhigh)
          for snr in snr_vals]
         for new_snr in new_snrs],
        dtype=numpy.float64)

    # Colors and styles of the contours
    colors = ["k-" if snr == opts.newsnr_threshold else
              "y-" if snr == int(snr) else
              "y--" for snr in new_snrs]

    return contours, snr_vals, cont_value, colors


# =============================================================================
# Main script starts here
# =============================================================================
# Build the command line parser: common PyGRB plotting options come from
# ppu, the ones below are specific to this script
parser = ppu.pygrb_initialize_plot_parser(description=__doc__)
parser.add_argument("-t", "--trig-file", action="store",
                    default=None, required=True,
                    help="The location of the trigger file")
parser.add_argument("--found-missed-file",
                    help="The hdf injection results file", required=False)
parser.add_argument("-z", "--zoom-in", default=False, action="store_true",
                    help="Output file a zoomed in version of the plot.")
parser.add_argument("-y", "--y-variable", required=True,
                    choices=['network', 'bank', 'auto', 'power'],
                    help="Quantity to plot on the vertical axis.")
parser.add_argument("--snr-type", default='coherent',
                    choices=['coherent', 'coincident', 'null', 'reweighted',
                             'single'], help="SNR value to plot on x-axis.")
ppu.pygrb_add_bestnr_cut_opt(parser)
ppu.pygrb_add_bestnr_opts(parser)
ppu.pygrb_add_slide_opts(parser)
opts = parser.parse_args()
ppu.slide_opts_helper(opts)

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
trig_file = os.path.abspath(opts.trig_file)
found_missed_file = os.path.abspath(opts.found_missed_file) \
    if opts.found_missed_file else None
zoom_in = opts.zoom_in
veto_type = opts.y_variable
ifo = opts.ifo
snr_type = opts.snr_type
# A detector must be named whenever a single IFO quantity ends up on either
# axis: the single IFO SNR, or any chi-square other than the network one
# (those are read from the datasets of the selected detector)
if ifo is None:
    if snr_type == 'single':
        err_msg = "--ifo must be given to plot single IFO SNR veto"
        parser.error(err_msg)
    if veto_type != 'network':
        err_msg = "--ifo must be given to plot single IFO %s chi-square" \
            % veto_type
        parser.error(err_msg)

# Veto is intended as a single IFO quantity. Network chisq will be obsolete.
# TODO: fix vetoes

# Prepare plot title and caption
veto_labels = {'network': "Network Power",
               'bank': "Bank",
               'auto': "Auto",
               'power': "Power"}
if opts.plot_title is None:
    opts.plot_title = " %s Chi Square" % veto_labels[veto_type]
    if veto_type != 'network':
        opts.plot_title = ifo + opts.plot_title
    if snr_type == 'single':
        opts.plot_title += " vs %s SNR" % (ifo)
    else:
        opts.plot_title += " vs %s SNR" % snr_type.capitalize()
if opts.plot_caption is None:
    opts.plot_caption = ("Blue crosses: background triggers. ")
    if found_missed_file:
        opts.plot_caption += "Red crosses: injections triggers. "
    if veto_type == 'network':
        opts.plot_caption += ("Gray shaded region: area cut by the " +
                              "reweighted SNR threshold. " +
                              "Black line: reweighted SNR threshold. Yellow " +
                              "lines: contours of constant reweighted SNR.")

logging.info("Imported and ready to go.")

# Set output directory (exist_ok avoids failing if it appears meanwhile)
outdir = os.path.split(os.path.abspath(opts.output_file))[0]
os.makedirs(outdir, exist_ok=True)

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(trig_file, opts.veto_files,
                                           opts.veto_category)

# Stop with a clear message if the requested IFO is not available
if ifo and ifo not in ifos:
    err_msg = "The IFO selected with --ifo is unavailable in the data."
    raise RuntimeError(err_msg)

# Extract trigger data
trig_data = load_data(trig_file, ifos, vetoes, opts,
                      slide_id=opts.slide_id)

# Extract (or initialize) injection data: with no injection file every
# entry of inj_data is None
inj_data = load_data(found_missed_file, ifos, vetoes, opts,
                     injections=True, slide_id=0)

# Sanity checks
if trig_data[snr_type] is None and inj_data[snr_type] is None:
    err_msg = "No data to be plotted on the x-axis was found"
    raise RuntimeError(err_msg)
if trig_data[veto_type] is None and inj_data[veto_type] is None:
    err_msg = "No data to be plotted on the y-axis was found"
    raise RuntimeError(err_msg)

# Generate plots
logging.info("Plotting...")

# Determine the x-axis label: the detector name for single IFO SNR,
# otherwise the capitalized SNR type
x_label = ifo if snr_type == 'single' else snr_type.capitalize()
x_label = "%s SNR" % x_label

# Determine the minimum and maximum SNR value we are dealing with,
# leaving a 10% margin on both sides
x_min = 0.9*plu.axis_min_value(trig_data[snr_type], inj_data[snr_type],
                               found_missed_file)
x_max = 1.1*plu.axis_max_value(trig_data[snr_type], inj_data[snr_type],
                               found_missed_file)

# Determine the minimum and maximum chi-square value we are dealing with
y_min = 0.9*plu.axis_min_value(trig_data[veto_type], inj_data[veto_type],
                               found_missed_file)
y_max = 1.1*plu.axis_max_value(trig_data[veto_type], inj_data[veto_type],
                               found_missed_file)

# Determine the y-axis label
y_label = "Network power chi-square" if veto_type == 'network' \
    else "%s Single %s chi-square" % (ifo, veto_labels[veto_type].lower())

# Determine contours for plots
conts = None
snr_vals = None
cont_value = None
colors = None
# Enable contours of constant reweighted SNR as a function of coherent SNR
if snr_type == 'coherent':
    conts, snr_vals, cont_value, colors = calculate_contours(trig_data,
                                                             opts,
                                                             new_snrs=None)
# The cut in reweighted SNR involves only the network power chi-square
if veto_type != 'network':
    cont_value = None

# Produce the veto vs. SNR plot
# NOTE(review): the y-axis range is only set together with the x-axis one,
# so a y range requested without an x range would be replaced here --
# confirm this is intended
if not opts.x_lims:
    if zoom_in:
        opts.x_lims = str(x_min)+',50'
        opts.y_lims = str(y_min)+',20000'
    else:
        opts.x_lims = str(x_min)+','+str(x_max)
        opts.y_lims = str(y_min)+','+str(10*y_max)
trigs = [trig_data[snr_type], trig_data[veto_type]]
injs = [inj_data[snr_type], inj_data[veto_type]]
plu.pygrb_plotter(trigs, injs, x_label, y_label, opts,
                  snr_vals=snr_vals, conts=conts, colors=colors,
                  shade_cont_value=cont_value, vert_spike=True,
                  cmd=' '.join(sys.argv))
