#!/usr/bin/python3.13 -s
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot single IFO SNR vs coherent SNR for a PyGRB run.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import collections
import operator
from matplotlib import pyplot as plt
from matplotlib import rc
import numpy
import scipy

import pycbc.version
from pycbc import init_logging
from pycbc.detector import Detector
from pycbc.results import save_fig_with_metadata
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

plt.switch_backend('Agg')
rc('font', size=14)

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_coh_ifosnr"


# =============================================================================
# Functions
# =============================================================================
# Function to load trigger data
def load_data(input_file, ifos, vetoes, opts, injections=False, slide_id=None):
    """Load coherent/single-IFO SNRs and sensitivity summaries from a file.

    Parameters
    ----------
    input_file : str or None
        File handed to ``ppu.load_triggers``.  If falsy, nothing is read and
        the returned dictionary only holds ``None`` placeholders.
    ifos : list of str
        Interferometers for which per-detector entries are created.
    vetoes
        Forwarded unchanged to ``ppu.load_triggers``.
    opts
        Parsed command line options; ``opts.newsnr_threshold`` and
        ``opts.ifo`` are read here.
    injections : bool, optional
        Currently only selects the log message: injections are still read
        with ``ppu.load_triggers``.
    slide_id
        Forwarded unchanged to ``ppu.load_triggers``.

    Returns
    -------
    dict
        ``'coherent'``: the ``network/coherent_snr`` values;
        ``'single'``: per-IFO SNRs reordered to follow the network triggers;
        ``'f_resp_mean'``: per-IFO mean of the values returned by
        ``ppu.get_antenna_responses``;
        ``'sigma_mean'``: per-IFO mean of sigmasq divided by the
        response-weighted sum of sigmasq over all IFOs (0 if there are no
        triggers);
        ``'sigma_max'``/``'sigma_min'``: extrema of that same ratio for
        ``opts.ifo`` only (0 if there are no triggers).
    """

    # Initialize the dictionary
    data = {
        'coherent': None,
        'single': {ifo: None for ifo in ifos},
        'f_resp_mean': {ifo: None for ifo in ifos},
        'sigma_mean': {ifo: None for ifo in ifos},
        'sigma_max': None,
        'sigma_min': None,
    }

    # Nothing to read: hand back the placeholders
    if not input_file:
        return data

    # TODO: This will eventually become load_injections for injections;
    # until then both cases are read with load_triggers
    if injections:
        logging.info("Loading injections...")
    else:
        logging.info("Loading triggers...")
    trigs = ppu.load_triggers(
        input_file,
        ifos,
        vetoes,
        rw_snr_threshold=opts.newsnr_threshold,
        slide_id=slide_id
    )

    # Load SNR data
    data['coherent'] = trigs['network/coherent_snr']
    num_trigs = len(data['coherent'])

    # Get single ifo SNR and sigma data, ordered as the network triggers
    network_ids = trigs['network/event_id']
    sigma = {}
    for ifo in ifos:
        # Looked up once per IFO rather than once per network trigger
        ifo_ids = trigs['%s/event_id' % ifo]
        # The explicit integer dtype keeps this usable as an index array
        # when there are no triggers: an empty array would otherwise
        # default to float64, which NumPy refuses as an index
        sorted_ifo_ids = numpy.array(
            [numpy.where(ifo_ids == net_id)[0][0] for net_id in network_ids],
            dtype=numpy.intp,
        )
        data['single'][ifo] = trigs['%s/snr' % ifo][:][sorted_ifo_ids]
        sigma[ifo] = trigs['%s/sigmasq' % ifo][:][sorted_ifo_ids]

    # Get parameters necessary for antenna responses
    ra = trigs['network/ra'][:]
    dec = trigs['network/dec'][:]
    time = trigs['network/end_time_gc'][:]

    # Get antenna response based parameters and accumulate sigma_tot
    sigma_tot = numpy.zeros(num_trigs)
    for ifo in ifos:
        antenna = Detector(ifo)
        ifo_f_resp = ppu.get_antenna_responses(antenna, ra, dec, time)
        data['f_resp_mean'][ifo] = ifo_f_resp.mean()
        sigma_tot += sigma[ifo] * ifo_f_resp

    # Summarise sigma normalised by sigma_tot.  Without triggers there is
    # nothing to take the extrema of, so every IFO falls back to 0
    for ifo in ifos:
        if num_trigs:
            sigma_norm = sigma[ifo] / sigma_tot
            norm_mean = sigma_norm.mean()
            norm_max = sigma_norm.max()
            norm_min = sigma_norm.min()
        else:
            norm_mean = norm_max = norm_min = 0
        data['sigma_mean'][ifo] = norm_mean
        # Extrema are only kept for the IFO being plotted
        if ifo == opts.ifo:
            data['sigma_max'] = norm_max
            data['sigma_min'] = norm_min

    logging.info("%d triggers found.", num_trigs)

    return data


# Plot lines representing deviations based on non-central chi-square
def plot_deviation(percentile, snr_grid, y, ax, style):
    """Draw a non-central chi-square percentile curve on the given axes.

    The curve is the square root of the ``percentile`` percent point of a
    non-central chi-square distribution with 2 degrees of freedom and
    non-centrality ``y * y``, plotted against ``snr_grid`` with the
    matplotlib format string ``style``.  Nothing is returned.
    """

    # ncx2: non-central chi-squared; ppf: percent point function.
    # Plotting the ppf output directly against snr_grid "saturates", so
    # only its sorted distinct values are kept and a linear interpolant,
    # built on the lower half of them, is extended over the whole grid.
    distinct_vals = numpy.unique(
        scipy.stats.ncx2.ppf(percentile, 2, y * y) ** 0.5
    )
    n_knots = len(distinct_vals) // 2

    # The knots pair the first n_knots grid points with the first n_knots
    # distinct values
    interpolant = scipy.interpolate.interp1d(
        snr_grid[:n_knots],
        distinct_vals[:n_knots],
        fill_value="extrapolate",
        bounds_error=False,
    )
    ax.plot(snr_grid, interpolant(snr_grid), style)


# =============================================================================
# Main script starts here
# =============================================================================
# Build the command line parser on top of the common PyGRB plot options
parser = ppu.pygrb_initialize_plot_parser(description=__doc__)
parser.add_argument(
    "-t",
    "--trig-file",
    action="store",
    default=None,
    required=True,
    help="The location of the trigger file",
)
parser.add_argument(
    "--found-missed-file",
    help="The hdf injection results file",
    required=False,
)
parser.add_argument(
    "-z",
    "--zoom-in",
    default=False,
    action="store_true",
    help="Output file a zoomed in version of the plot.",
)
ppu.pygrb_add_bestnr_cut_opt(parser)
ppu.pygrb_add_slide_opts(parser)
opts = parser.parse_args()
ppu.slide_opts_helper(opts)

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
trig_file = os.path.abspath(opts.trig_file)
# The injection file is optional: without it only triggers are plotted
found_file = (
    os.path.abspath(opts.found_missed_file) if opts.found_missed_file else None
)
zoom_in = opts.zoom_in
ifo = opts.ifo
if ifo is None:
    err_msg = "Please specify an interferometer"
    parser.error(err_msg)

# Default title and caption, used unless provided on the command line
if opts.plot_title is None:
    opts.plot_title = '%s SNR vs Coherent SNR' % ifo
if opts.plot_caption is None:
    opts.plot_caption = "Blue crosses: background triggers.  "
    if found_file:
        opts.plot_caption += "Red crosses: injections triggers.  "
    opts.plot_caption = (
        opts.plot_caption
        + "Black line: veto line.  "
        + "Gray shaded region: vetoed area - The cut is "
        + "applied only to the two most sensitive detectors, "
        + "which can vary with template parameters and sky location. "
        + "Green lines: the expected SNR for optimally "
        + "oriented injections (mean, min, and max).  "
        + "Magenta lines: 2 sigma errors.  "
        + "Cyan lines: 3 sigma errors."
    )

logging.info("Imported and ready to go.")

# Set output directories (exist_ok avoids a check-then-create race)
outdir = os.path.split(os.path.abspath(opts.output_file))[0]
os.makedirs(outdir, exist_ok=True)

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(
    trig_file, opts.veto_files, opts.veto_category
)

# All per-IFO lookups below are keyed on the IFOs of the trigger file, so
# fail with a clear message now rather than with a KeyError after loading
if ifo not in ifos:
    parser.error(
        "Interferometer %s not available in %s" % (ifo, trig_file)
    )

# Extract trigger data
trig_data = load_data(trig_file, ifos, vetoes, opts, slide_id=opts.slide_id)

# Extract (or initialize) injection data
inj_data = load_data(found_file, ifos, vetoes, opts, injections=True, slide_id=0)

# Generate plots
logging.info("Plotting...")

# Order the IFOs by sensitivity (most sensitive first), using the product
# of the mean antenna response and the mean normalised sigma
ifo_senstvty = {
    i_ifo: trig_data['f_resp_mean'][i_ifo] * trig_data['sigma_mean'][i_ifo]
    for i_ifo in ifos
}
ifo_senstvty = collections.OrderedDict(
    sorted(ifo_senstvty.items(), key=operator.itemgetter(1), reverse=True)
)
loudness_labels = ['first', 'second', 'third']

# Determine the maximum coherent SNR value we are dealing with
x_max = plu.axis_max_value(
    trig_data['coherent'], inj_data['coherent'], found_file
)
# The reference curves always extend to a coherent SNR of at least 50
max_snr = x_max
if x_max < 50.0:
    max_snr = 50.0

# Determine the maximum single IFO SNR value we are dealing with
y_max = plu.axis_max_value(
    trig_data['single'][ifo], inj_data['single'][ifo], found_file
)

# Setup the plots
x_label = "Coherent SNR"
y_label = "%s sngl SNR" % ifo
fig = plt.figure()
ax = fig.gca()
# Plot trigger data
ax.plot(trig_data['coherent'], trig_data['single'][ifo], 'bx')
ax.grid()
# Plot injection data
if found_file:
    ax.plot(inj_data['coherent'], inj_data['single'][ifo], 'r+')
# Sigma-mean, min, max
y_data = {
    'mean': trig_data['sigma_mean'][ifo],
    'min': trig_data['sigma_min'],
    'max': trig_data['sigma_max'],
}
# Calculate: coherent SNR grid * sqrt(response * sigma-mean, min, max)
snr_grid = numpy.arange(0.01, max_snr, 1)
y_data = {
    key: snr_grid * (trig_data['f_resp_mean'][ifo] * value) ** 0.5
    for key, value in y_data.items()
}
for key in y_data:
    ax.plot(snr_grid, y_data[key], 'g-')
# 2 sigma (0.9545)
plot_deviation(0.02275, snr_grid, y_data['min'], ax, 'm-')
plot_deviation(1 - 0.02275, snr_grid, y_data['max'], ax, 'm-')
# 3 sigma (0.9973)
plot_deviation(0.00135, snr_grid, y_data['min'], ax, 'c-')
plot_deviation(1 - 0.00135, snr_grid, y_data['max'], ax, 'c-')
# Non-zoomed plot: veto line at a single IFO SNR of 4
ax.plot([0, max_snr], [4, 4], 'k-')
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_xlim([0, 1.01 * x_max])
ax.set_ylim([0, 1.20 * y_max])
# Veto applies to the two most sensitive IFOs, so shade them
loudness_index = list(ifo_senstvty.keys()).index(ifo)
if loudness_index < 2:
    # Rectangle between the veto line and the bottom of the plot
    limy = ax.get_ylim()[0]
    polyx = [0, max_snr, max_snr, 0]
    polyy = [4, 4, limy, limy]
    ax.fill(polyx, polyy, color='#dddddd')
# Networks with more than three IFOs run past the spelled-out labels, so
# fall back to a numeric ordinal instead of raising an IndexError
if loudness_index < len(loudness_labels):
    loudness_label = loudness_labels[loudness_index]
else:
    loudness_label = "%dth" % (loudness_index + 1)
opts.plot_title += " (%s average sensitivity)" % loudness_label
# Zoom in if asked to do so
if zoom_in:
    ax.set_xlim([6, 50])
    ax.set_ylim([0, 20])
# Save plot
save_fig_with_metadata(
    fig,
    opts.output_file,
    cmd=' '.join(sys.argv),
    title=opts.plot_title,
    caption=opts.plot_caption,
)
plt.close()
