#!/usr/bin/python3.13 -s

# Copyright (C) 2021 Francesco Pannarale, Viviana Caceres
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot found/missed injection properties for the triggered search (PyGRB).
"""

import logging
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import os.path
import sys

import pycbc.pnutils
import pycbc.results
import pycbc.version
from pycbc import init_logging
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.io.hdf import HFile

plt.switch_backend('Agg')
matplotlib.rc("image")

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_injs_results"


# =============================================================================
# Functions
# =============================================================================
def process_var_strings(qty):
    """Add underscores to match HDF column name conventions.

    Both the compact spelling of a quantity (e.g. ``cosabsincl``) and its
    underscored spelling (e.g. ``cos_abs_incl``) are mapped to the same
    underscored name, so a string that is already normalised is returned
    unchanged.

    Parameters
    ----------
    qty : str
        Name of the quantity, with or without underscores.

    Returns
    -------
    str
        Name of the quantity with the underscores inserted.
    """

    # (compact, underscored) spellings, applied in this order
    replacements = (
        ('effsitedist', 'eff_site_dist'),
        ('effdist', 'eff_dist'),
        ('coaphase', 'coa_phase'),
        ('skyerror', 'sky_error'),
        ('cos', 'cos_'),
        ('abs', 'abs_'),
        ('endtime', 'end_time'),
        ('spin1a', 'spin1_a'),
        ('spin2a', 'spin2_a'),
    )

    for compact, underscored in replacements:
        # Collapse an existing underscored form first, so that inserting
        # the underscore never doubles it (e.g. 'cos_incl' -> 'cos__incl')
        qty = qty.replace(underscored, compact).replace(compact, underscored)

    return qty


def load_incl_data(input_file_handle, key, tag):
    """Extract inclination-related data from raw injection data.

    Parameters
    ----------
    input_file_handle : mapping
        Object indexable with ``tag + '/thetajn'``.
    key : str
        One of 'incl', 'abs_incl', 'cos_incl' or 'cos_abs_incl'.
    tag : str
        Group of the file to read from.

    Returns
    -------
    numpy.ndarray
        Cosine of the requested angle when ``key`` contains 'cos_',
        otherwise the requested angle converted with ``np.rad2deg``.

    Raises
    ------
    KeyError
        If ``key`` does not resolve to one of the angles built here.
    """

    half_pi = 0.5*np.pi

    # TODO: make sure it is clear that inclination is interpreted as theta_jn
    theta_jn = input_file_handle[tag+'/thetajn'][:]

    # Angles that can be requested; 'incl' is needed in every case
    angles = {'incl': theta_jn}

    # |incl| or cos(|incl|): fold the angle about pi/2
    if 'abs_' in key:
        angles['abs_incl'] = half_pi - abs(theta_jn - half_pi)

    # incl or abs_incl: convert to degrees
    if 'cos_' not in key:
        return np.rad2deg(angles[key])

    # cos(incl) or cos(|incl|): strip the prefix to find the angle
    return np.cos(angles[key.replace('cos_', '')])


# Function to extract mass ratio, chirp mass or total mass data from an
# injection file
def load_mass_data(input_file_handle, key, tag):
    """Extract mass ratio, chirp mass or total mass from raw injection data.

    Parameters
    ----------
    input_file_handle : mapping
        Object indexable with ``tag + '/mass1'`` and ``tag + '/mass2'``.
    key : str
        'mtotal' for the total mass, 'mchirp' for the chirp mass; any other
        value yields the mass ratio.
    tag : str
        Group of the file to read from.

    Returns
    -------
    numpy.ndarray
        The requested quantity. The mass ratio is inverted wherever it
        exceeds 1, so it never exceeds 1.
    """

    mass1 = input_file_handle[tag+'/mass1'][:]
    mass2 = input_file_handle[tag+'/mass2'][:]

    if key == 'mtotal':
        return mass1 + mass2

    if key == 'mchirp':
        # The symmetric mass ratio returned alongside is not needed
        mchirp, _ = pycbc.pnutils.mass1_mass2_to_mchirp_eta(mass1, mass2)
        return mchirp

    # Mass ratio, flipped where mass2 is the heavier component
    ratio = mass2 / mass1
    return np.where(ratio > 1, 1./ratio, ratio)


def load_sky_error_data(input_file_handle, key, tag):
    """Extract sky error data from raw injection and trigger data.

    Parameters
    ----------
    input_file_handle : mapping
        Object indexable with ``tag + '/ra'``, ``tag + '/dec'``,
        ``'network/ra'`` and ``'network/dec'`` (or ``tag + '/mass1'`` when
        ``tag`` is 'missed').
    key : str
        Unused; kept for a signature parallel to the other loaders.
    tag : str
        Group of the file to read from.

    Returns
    -------
    numpy.ndarray
        For 'missed', an array of ``None`` with one entry per element of
        ``tag + '/mass1'``. Otherwise
        ``arccos(cos(d_i - d_t) - cos(d_i) cos(d_t) (1 - cos(a_i - a_t)))``
        with (a_i, d_i) the injected and (a_t, d_t) the 'network' ra, dec.
    """
    # Missed injections are assigned null values
    if tag == 'missed':
        n_missed = len(input_file_handle[tag+'/mass1'][:])
        return np.full(n_missed, None)

    inj_ra = input_file_handle[tag+'/ra'][:]
    inj_dec = input_file_handle[tag+'/dec'][:]
    trig_ra = input_file_handle['network/ra'][:]
    trig_dec = input_file_handle['network/dec'][:]

    cos_sky_error = np.cos(inj_dec - trig_dec) - \
        np.cos(inj_dec) * np.cos(trig_dec) * (1 - np.cos(inj_ra - trig_ra))

    return np.arccos(cos_sky_error)


# These are keys in the found-missed-file under 'found' and 'missed'
# TODO: also in the found-missed injs file is 'inclination'
easy_keys = ['distance', 'mass1', 'mass2', 'polarization',
             'spin1_a', 'spin1x', 'spin1y', 'spin1z',
             'spin2_a', 'spin2x', 'spin2y', 'spin2z',
             'spin1_azimuthal', 'spin1_polar',
             'spin2_azimuthal', 'spin2_polar',
             'dec', 'ra', 'phi_ref']


# Function to extract desired data from an injection file
def load_data(input_file_handle, keys, tag):
    """Create a dictionary containing the data specified by the
    list of keys extracted from an injection file.

    Parameters
    ----------
    input_file_handle : mapping
        Object indexable with '<tag>/<dataset>' paths.
    keys : list of str
        Names of the quantities to extract (at least one).
    tag : str
        Group of the file to read from.

    Returns
    -------
    dict
        One entry per key. A key that is not supported, or whose data are
        not in the file, is logged and left as an empty array.
    """

    data_dict = {}

    for key in keys:
        # Fallback entry, kept if nothing can be read for this key
        data_dict[key] = np.array([])
        try:
            # Always read from the handle passed in, never from a global
            if key in easy_keys:
                data_dict[key] = input_file_handle[tag+'/'+key][:]
            elif key == 'end_time':
                # NOTE: the trigger time is not subtracted here
                data_dict[key] = input_file_handle[tag+'/tc'][:]
            elif key in ['q', 'mchirp', 'mtotal']:
                data_dict[key] = load_mass_data(input_file_handle, key, tag)
            elif 'incl' in key:
                data_dict[key] = load_incl_data(input_file_handle, key, tag)
            elif key == 'sky_error':
                data_dict[key] = load_sky_error_data(input_file_handle, key,
                                                     tag)
            else:
                # No loader for this key: report it rather than skip silently
                logging.info(key+' not allowed yet')
        except KeyError:
            logging.info(key+' not allowed yet')

    logging.info("%d %s injections analysed.", len(data_dict[keys[0]]), tag)

    return data_dict


# =============================================================================
# Main script starts here
# =============================================================================
parser = ppu.pygrb_initialize_plot_parser(description=__doc__)
parser.add_argument("--found-missed-file", required=True,
                    help="The hdf injection results file")
parser.add_argument("--trig-file", required=True,
                    help="The hdf offsource trigger file")
parser.add_argument("--trigger-time", required=True,
                    help="The GPS time of the external trigger")
# FIXME: not all of these work
# TODO: effective distance
# Quantities accepted on either axis: the directly readable keys plus the
# derived ones, the latter listed with and without underscores
admitted_vars = easy_keys + [
    'mtotal', 'q', 'mchirp',
    'spin1a', 'spin2a',
    'incl', 'cos_incl', 'abs_incl', 'cos_abs_incl',
    'cosincl', 'absincl', 'cosabsincl',
    'sky_error', 'skyerror', 'end_time', 'endtime',
    'latitude', 'longitude', 'coaphase', 'coa_phase',
    'eff_site_dist', 'eff_dist',
    'effsitedist', 'effdist',
]
# Both axes take the same pair of options: the quantity and the log switch
for axis_name, axis_direction in (("x", "horizontal"), ("y", "vertical")):
    parser.add_argument(
        "-%s" % axis_name, "--%s-variable" % axis_name,
        default=None, required=True, choices=admitted_vars,
        help="Quantity to plot on the %s axis. (Underscores may be "
             "omitted in specifying this option)." % axis_direction)
    parser.add_argument(
        "--%s-log" % axis_name, action="store_true",
        help="Use log %s axis" % axis_direction)
parser.add_argument("--colormap", default='cividis_r',
                    help="Type of colormap to be used for the plots.")
parser.add_argument("--far-type", default='inclusive',
                    choices=('inclusive', 'exclusive'),
                    help="Type of far to plot for the color. Choices are "
                         "'inclusive' or 'exclusive'. Default = 'inclusive'")
parser.add_argument("--missed-on-top", action='store_true',
                    help="Plot missed injections on top of found ones and "
                         "high FAR on top of low FAR")
ppu.pygrb_add_bestnr_cut_opt(parser)
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Normalise the requested quantities to the underscored naming
x_qty, y_qty = (process_var_strings(requested)
                for requested in (opts.x_variable, opts.y_variable))

# The site-specific effective distance needs to know which site to use
if 'eff_site_dist' in [x_qty, y_qty] and opts.ifo is None:
    parser.error("A value for --ifo must be provided for "
                 "site specific effective distance")

# Store options used multiple times in local variables
outfile = opts.output_file
trig_file = os.path.abspath(opts.trig_file)
f = HFile(opts.found_missed_file, 'r')
grb_time = opts.trigger_time

# Create the directory that will hold the output file, if needed
logging.info("Setting output directory.")
outdir = os.path.dirname(os.path.abspath(outfile))
if not os.path.isdir(outdir):
    os.makedirs(outdir)

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(
    trig_file, opts.veto_files, opts.veto_category)

# Load triggers. Time-slides not yet available
logging.info("Loading triggers...")
trig_data = ppu.load_triggers(
    trig_file, ifos, vetoes, rw_snr_threshold=opts.newsnr_threshold)
# Largest reweighted SNR among the loaded triggers
max_reweighted_snr = max(trig_data['network/reweighted_snr'][:])

# =======================
# Post-process injections
# =======================
# Extract the necessary data from the missed injections for the plot
missed_inj = load_data(f, [x_qty, y_qty], 'missed')

# Extract the necessary data from the found injections for the plot
found_inj = load_data(f, [x_qty, y_qty], 'found')

# Injections found surviving vetoes: taken from the dedicated group of the
# found-missed file when it exists, otherwise all found injections are used
has_vetoed_group = 'found_after_vetoes' in f.keys()
found_after_vetoes = f['found_after_vetoes'] if has_vetoed_group \
    else found_inj
# FIXME: Found but vetoed injections
# missed_after_vetoes = list(set(found) - set(found_after_vetoes))
# missed_after_vetoes = np.sort(missed_after_vetoes).astype(int)
missed_after_vetoes = {x_qty: np.array([]), y_qty: np.array([])}
# Injections found but not louder than background
# are populated further down

# Extract the detection statistic of injections found after vetoes
if len(list(f['network'].keys())) > 0:
    found_after_vetoes_stat = f['found_after_vetoes/stat'][:] \
        if has_vetoed_group else f['network/reweighted_snr'][:]

    # Separate triggers into:
    # 1) Found louder than background
    louder_mask = found_after_vetoes_stat > max_reweighted_snr
    found_louder = {}
    for key in found_after_vetoes.keys():
        found_louder[key] = found_after_vetoes[key][louder_mask]
    found_louder['reweighted_snr'] = found_after_vetoes_stat[louder_mask]

    # 2) Found quieter than background: injections found (bestnr > 0)
    #    but not louder than background (non-zero FAP)
    quieter_mask = (found_after_vetoes_stat <= max_reweighted_snr) \
        & (found_after_vetoes_stat != 0)
    found_quieter = {}
    for key in found_after_vetoes.keys():
        found_quieter[key] = found_after_vetoes[key][quieter_mask]
    found_quieter['reweighted_snr'] = found_after_vetoes_stat[quieter_mask]

    # TODO: ifar still missing
    # Extract inclusive/exclusive IFAR for injections found quieter than
    # background:
    # ifar_string = 'found_after_vetoes/ifar' \
    #     if opts.far_type == 'inclusive' else 'found_after_vetoes/ifar_exc'
    # found_quieter['ifar'] = f[ifar_string][quieter_mask]

    # 3) Missed due to vetoes
    # TODO: needs function to cherry-pick a subset of inj_data specified by
    # a mask on FAP values.
else:
    # No network data: nothing can be classified as found
    found_louder = {x_qty: np.array([]), y_qty: np.array([])}
    found_quieter = {x_qty: np.array([]), y_qty: np.array([])}
# TMP FIX
vetoed = missed_after_vetoes

# Statistics: found on top (found-missed)
# TODO: ifar still missing
# FM = np.argsort(found_quieter['ifar'])
logging.info("WARNING: IFAR not implemented yet")
logging.info("WARNING: Avoiding failure by setting it to -1 and shifting colorbar from [0,1] to [-2,0].")
if found_quieter[x_qty].size:
    # Placeholder IFAR of -1 for every injection found quieter than
    # background
    placeholder_ifar = np.full(found_quieter['reweighted_snr'].shape, -1)
    # FM and MF are used to index (reorder) the found-quieter arrays when
    # plotting, so they must be permutations of the indices. An array
    # filled with -1 would select the last injection over and over. The
    # stable sort of the constant placeholder keeps the original ordering.
    FM = np.argsort(placeholder_ifar, kind='stable')
    # Statistics: missed on top (missed-found)
    MF = FM[::-1]

# Post-processing of injections ends here

# ==========
# Make plots
# ==========

# Info for site-specific plots
sitename = {'G1': 'GEO', 'H1': 'Hanford', 'L1': 'Livingston', 'V1': 'Virgo',
            'K1': 'KAGRA', 'A1': 'India Aundha'}

# Take care of axes labels
axis_labels_dict = {'mchirp': "Chirp Mass (solar masses)",
                    'mtotal': "Total mass (solar masses)",
                    'q': "Mass ratio",
                    'distance': "Distance (Mpc)",
                    'eff_site_dist': "%s effective distance (Mpc)" % sitename.get(opts.ifo),
                    'eff_dist': "Inverse sum of effective distances (Mpc)",
                    'end_time': "Time since %s (s)" % grb_time,
                    'sky_error': "Rec. sky error (radians)",
                    'coa_phase': "Phase of complex SNR (radians)",
                    'latitude': "Latitude (degrees)",
                    'longitude': "Longitude (degrees)",
                    'incl': "Inclination (iota)",
                    'abs_incl': 'Magnitude of inclination (|iota|)',
                    'cos_incl': "cos(iota)",
                    'cos_abs_incl': "cos(|iota|)",
                    'mass1': "Mass of 1st binary component (solar masses)",
                    'mass2': "Mass of 2nd binary component (solar masses)",
                    'polarization': "Polarization phase (radians)",
                    'spin1_a': "Spin on 1st binary component",
                    'spin1x': "Spin x-component of 1st binary component",
                    'spin1y': "Spin y-component of 1st binary component",
                    'spin1z': "Spin z-component of 1st binary component",
                    'spin2_a': "Spin on 2nd binary component",
                    'spin2x': "Spin x-component of 2nd binary component",
                    'spin2y': "Spin y-component of 2nd binary component",
                    'spin2z': "Spin z-component of 2nd binary component",
                    # Remaining quantities that can be requested and read
                    # directly from the found-missed file
                    'spin1_azimuthal':
                        "Spin azimuthal angle of 1st binary component",
                    'spin1_polar':
                        "Spin polar angle of 1st binary component",
                    'spin2_azimuthal':
                        "Spin azimuthal angle of 2nd binary component",
                    'spin2_polar':
                        "Spin polar angle of 2nd binary component",
                    'ra': "Right ascension (radians)",
                    'dec': "Declination (radians)",
                    'phi_ref': "Reference phase"}

# Fall back on the name of the quantity itself so that a missing entry in
# the dictionary above does not abort the plot with a KeyError
x_label = axis_labels_dict.get(x_qty, x_qty)
y_label = axis_labels_dict.get(y_qty, y_qty)

# Set up the figure: axes scales and labels
fig = plt.figure()
xscale = "log" if opts.x_log else "linear"
yscale = "log" if opts.y_log else "linear"
ax = fig.gca()
ax.set_xscale(xscale)
ax.set_yscale(yscale)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)

# Define p-value colour, honouring the --colormap option
cmap = plt.get_cmap(opts.colormap)
# Set color for out-of-range values
# cmap.set_over('g')

# Define the 'found' injection colour: the first colour of the colormap,
# wrapped in a 2D array so that it is applied as a single colour
fnd_col = np.array([cmap(0)])

# The two branches draw the same four populations: they differ in the
# drawing order (what is drawn last ends up on top) and in marker sizes
if not opts.missed_on_top:
    if missed_inj[x_qty].size and missed_inj[y_qty].size:
        ax.scatter(missed_inj[x_qty], missed_inj[y_qty],
                   c="black", marker="x", s=10)
    # FIXME: once vetoed is filled in properly
    # if vetoed[x_qty].size:
    if x_qty in vetoed.keys():
        ax.scatter(vetoed[x_qty], vetoed[y_qty], c="red", marker="x", s=10)
    if found_quieter[x_qty].size:
        # NOTE: the colour is taken from the reweighted SNR, although the
        # colour bar is labelled as a p-value (IFAR is not implemented yet)
        p = ax.scatter(found_quieter[x_qty][FM], found_quieter[y_qty][FM],
                       c=found_quieter['reweighted_snr'][FM],
                       cmap=cmap, vmin=-2, vmax=0, s=40,
                       edgecolor="w", linewidths=2.0)
        cb = plt.colorbar(p, label="p-value")
    if found_louder[x_qty].size:
        ax.scatter(found_louder[x_qty], found_louder[y_qty],
                   c=fnd_col, marker="+", s=30)
else:
    if found_louder[x_qty].size:
        ax.scatter(found_louder[x_qty], found_louder[y_qty],
                   c=fnd_col, marker="+", s=15)
    if found_quieter[x_qty].size:
        # NOTE: same colouring caveat as in the branch above
        p = ax.scatter(found_quieter[x_qty][MF], found_quieter[y_qty][MF],
                       c=found_quieter['reweighted_snr'][MF],
                       cmap=cmap, vmin=-2, vmax=0, s=40,
                       edgecolor="w", linewidths=2.0)
        cb = plt.colorbar(p, label="p-value")
    # FIXME: once vetoed is filled in properly
    # if vetoed[x_qty].size:
    if x_qty in vetoed.keys():
        ax.scatter(vetoed[x_qty], vetoed[y_qty], c="red", marker="x", s=40)
    if missed_inj[x_qty].size and missed_inj[y_qty].size:
        ax.scatter(missed_inj[x_qty], missed_inj[y_qty],
                   c="black", marker="x", s=40)
ax.grid()

# Handle axis limits when plotting spins.
# The limits are computed from the quantity actually plotted on each axis:
# missed_inj and found_inj only hold the x_qty and y_qty entries, so any
# other key cannot be looked up in them.
# NOTE(review): the lower limit of 0 presumably targets spin magnitudes;
# confirm it is also wanted for signed spin components.
for spin_qty, set_spin_lim in ((x_qty, ax.set_xlim), (y_qty, ax.set_ylim)):
    if "spin" in spin_qty and missed_inj[spin_qty].size:
        max_spin = missed_inj[spin_qty].max()
        # The found injections may be empty: only use them if present
        if found_inj[spin_qty].size:
            max_spin = max(max_spin, found_inj[spin_qty].max())
        # Round the upper limit up to the next multiple of 0.1
        set_spin_lim([0, np.ceil(10 * max_spin) / 10])

# Handle axis limits when plotting inclination
if "incl" in x_qty or "incl" in y_qty:
    max_inc = np.pi
    # Round the maximum inclination up to a multiple of 10 degrees
    max_inc_deg = np.rad2deg(max_inc)
    max_inc_deg = np.ceil(max_inc_deg/10.0)*10
    max_inc = np.deg2rad(max_inc_deg)
    if x_qty == "incl":
        ax.set_xlim(0, max_inc_deg)
    elif x_qty == "abs_incl":
        ax.set_xlim(0, max_inc_deg*0.5)
    if y_qty == "incl":
        ax.set_ylim(0, max_inc_deg)
    elif y_qty == "abs_incl":
        ax.set_ylim(0, max_inc_deg*0.5)
    # Cosine of the inclination: place the ticks at the cosines of a set of
    # angles. This is a substring test on each quantity: testing whether
    # the string "cos_" is an element of [x_qty, y_qty] is never true.
    cos_on_x = "cos_" in x_qty
    cos_on_y = "cos_" in y_qty
    if cos_on_x or cos_on_y:
        tt = np.asarray([0, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120,
                         130, 140, 150, 180])
        tks = np.cos(np.deg2rad(tt))
        tk_labs = ['cos(%d deg)' % tk for tk in tt]
        if cos_on_x:
            ax.set_xticks(tks)
            ax.set_xticklabels(tk_labs, fontsize=10)
            # Rotate the horizontal tick labels
            fig.autofmt_xdate()
            ax.set_xlim(np.cos(max_inc), 1)
        if cos_on_y:
            ax.set_yticks(tks)
            ax.set_yticklabels(tk_labs, fontsize=10)
            ax.set_ylim(np.cos(max_inc), 1)

# Take care of caption: use the default one unless the user provided one
plot_caption = opts.plot_caption
if plot_caption is None:
    plot_caption = (
        "Black cross indicates no trigger was found "
        "coincident with the injection.\n"
        "Red cross indicates a trigger was found "
        "coincident with the injection but it was vetoed.\n"
        "Yellow plus indicates that a trigger was found "
        "coincident with the injection and it was louder "
        "than all events in the offsource.\n"
        "Coloured circle indicates that a trigger was "
        "found coincident with the injection but it was "
        "not louder than all offsource events. The colour "
        "bar gives the p-value of the trigger."
    )

# Take care of title: build the default one unless the user provided one
plot_title = opts.plot_title
if plot_title is None:
    title_dict = {'mchirp': "chirp mass",
                  'mtotal': "total mass",
                  'q': "mass ratio",
                  'distance': "distance (Mpc)",
                  'eff_dist': "inverse sum of effective distances",
                  'eff_site_dist': "site specific effective distance",
                  'end_time': "time",
                  'coa_phase': "phase of complex SNR",
                  'latitude': "latitude",
                  'longitude': "longitude",
                  'incl': "inclination",
                  'cos_incl': "inclination",
                  'abs_incl': "inclination",
                  'cos_abs_incl': "inclination",
                  'mass1': "mass",
                  'mass2': "mass",
                  'polarization': "polarization",
                  'spin1_a': "spin",
                  'spin1x': "spin x-component",
                  'spin1y': "spin y-component",
                  'spin1z': "spin z-component",
                  'spin2_a': "spin",
                  'spin2x': "spin x-component",
                  'spin2y': "spin y-component",
                  'spin2z': "spin z-component",
                  # Remaining quantities that can be requested and read
                  # directly from the found-missed file
                  'spin1_azimuthal': "spin azimuthal angle",
                  'spin1_polar': "spin polar angle",
                  'spin2_azimuthal': "spin azimuthal angle",
                  'spin2_polar': "spin polar angle",
                  'ra': "right ascension",
                  'dec': "declination",
                  'phi_ref': "reference phase"}

    if "sky_error" in [x_qty, y_qty]:
        plot_title = "Sky error of recovered injections"
    else:
        # Fall back on the name of the quantity itself so that a missing
        # entry in the dictionary does not abort with a KeyError after
        # the plot has been drawn
        plot_title = "Injection recovery with respect to "
        plot_title += title_dict.get(x_qty, x_qty)
        plot_title += " and " + title_dict.get(y_qty, y_qty)

# Wrap up: save the figure along with the command line, title and caption
plt.tight_layout()
pycbc.results.save_fig_with_metadata(fig, outfile, cmd=' '.join(sys.argv),
                                     title=plot_title, caption=plot_caption)
plt.close()
logging.info("Plot complete.")
