#!/usr/bin/python3.13 -s
""" Plot found and missed injections.
"""
import numpy
import logging
import os.path
import argparse
import sys
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plot

import pycbc.results.followup, pycbc.pnutils, pycbc.results
import pycbc.pnutils
from pycbc import init_logging, add_common_pycbc_options
from pycbc.detector import Detector
from pycbc.io.hdf import HFile

# Axis labels for every plottable quantity, keyed by the parameter name used
# on the command line. The insertion order is kept as-is because the keys are
# also handed to argparse as a list of choices.
labels = dict([
    ('mchirp', 'Chirp Mass'),
    ('mtotal', 'Total Mass'),
    ('mass_ratio', 'Mass Ratio'),
    ('decisive_distance', 'Decisive Distance (Mpc)'),
    ('dec_chirp_distance', 'Decisive Chirp Distance (Mpc)'),
    ('min_eff_distance', 'Minimum Effective Distance (Mpc)'),
    ('min_eff_chirp_distance', 'Minimum Effective Chirp Distance (Mpc)'),
    ('chirp_distance', 'Chirp Distance (Mpc)'),
    ('comb_optimal_snr', 'Combined Optimal SNR'),
    ('decisive_optimal_snr', 'Decisive Optimal SNR'),
    ('max_optimal_snr', 'Maximum Optimal SNR'),
    ('redshift', 'Redshift'),
    ('time', 'Time (s)'),
    ('effective_spin', 'Effective Inspiral Spin'),
])

# Command-line interface. The module docstring doubles as the --help summary.

# Quantities that can go on the x axis. Every other labelled quantity is a
# distance-like measure for the y axis; offering only those for
# --distance-type rejects a bad choice up front instead of failing later with
# a KeyError once the data has been read.
axis_choices = ['mchirp', 'effective_spin', 'time', 'mtotal', 'mass_ratio']
distance_choices = [key for key in labels if key not in axis_choices]

parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument('--injection-file',
                    help="The hdf injection file to plot", required=True)
parser.add_argument('--axis-type', default='mchirp', choices=axis_choices,
                    help="Injection parameter to plot on the x axis. "
                         "Default = 'mchirp'")
parser.add_argument('--log-x', action='store_true', default=False,
                    help="Use a logarithmic x axis")
parser.add_argument('--distance-type', default='decisive_optimal_snr',
                    choices=distance_choices,
                    help="Variable related to injected distance. Decisive "
                         "distance and dec chirp distance only available for "
                         "2-ifo search")
parser.add_argument('--plot-all-distance', action='store_true', default=False,
                    help="Plot all values of distance or SNR. If not given, "
                         "the plot will be truncated below 1")
parser.add_argument('--colormap', default='cividis_r',
                    help="Type of colormap to be used for the plots.")
parser.add_argument('--log-distance', action='store_true', default=False,
                    help="Use a logarithmic y axis")
# NOTE(review): --dynamic is accepted but not read anywhere in this script;
# presumably kept so existing command lines keep working - confirm.
parser.add_argument('--dynamic', action='store_true', default=False)
parser.add_argument('--gradient-far', action='store_true',
                    help="Show far of found injections as a gradient")
parser.add_argument('--output-file', required=True,
                    help="Path of the plot file to write")
# The first value is applied as the lower limit and the second as the upper
# limit, so the help text lists them in that order.
parser.add_argument('--ifar-limits', nargs=2, type=float,
                    help="Supply lower and upper limits for IFAR colors, in "
                         "that order. 0 indicates no limit.")
parser.add_argument('--far-type', choices=('inclusive', 'exclusive'),
                    default='inclusive',
                    help="Type of far to plot for the color. Choices are "
                         "'inclusive' or 'exclusive'. Default = 'inclusive'")
parser.add_argument('--missed-on-top', action='store_true',
                    help="Plot missed injections on top of found ones and "
                         "high FAR on top of low FAR")
args = parser.parse_args()

# args.verbose is presumably provided by the common PyCBC options added above
init_logging(args.verbose)

logging.info('Read in the data')
f = HFile(args.injection_file, 'r')

# Injection times, plus the index arrays that pick the found and the missed
# injections (after vetoes) out of the per-injection arrays.
time = f['injections/tc'][:]
found = f['found_after_vetoes/injection_index'][:]
missed = f['missed/after_vetoes'][:]

# Dataset holding the requested kind of IFAR for the found injections, and
# the matching word used later in the colorbar label.
far_sources = {
    'inclusive': ('found_after_vetoes/ifar', 'Inclusive'),
    'exclusive': ('found_after_vetoes/ifar_exc', 'Exclusive'),
}
if args.far_type in far_sources:
    ifar_dataset, far_title = far_sources[args.far_type]
    ifar_found = f[ifar_dataset][:]

# Clip the IFARs into the requested range. A limit of 0 means "no limit" on
# that side; the two flags record which sides were actually clipped.
lower_lim = upper_lim = False
if args.ifar_limits:
    ifar_floor, ifar_ceiling = args.ifar_limits
    if ifar_floor < 0 or ifar_ceiling < 0:
        parser.error("IFAR limits must be zero (no limit) or positive")
    lower_lim = ifar_floor > 0
    if lower_lim:
        ifar_found = numpy.maximum(ifar_found, ifar_floor)
    upper_lim = ifar_ceiling > 0
    if upper_lim:
        ifar_found = numpy.minimum(ifar_found, ifar_ceiling)

# Injected source parameters, one entry per injection
s1z = f['injections/spin1z'][:]
s2z = f['injections/spin2z'][:]
dist = f['injections/distance'][:]
m1 = f['injections/mass1'][:]
m2 = f['injections/mass2'][:]

# Quantities that can be plotted on the x axis, one value per injection
mchirp, eta = pycbc.pnutils.mass1_mass2_to_mchirp_eta(m1, m2)
mtotal = m1 + m2
vals = {
    'mchirp': mchirp,
    # Mass-weighted aligned spin, written out by hand
    # (original note: should use the pnutils formula instead)
    'effective_spin': (m1 * s1z + m2 * s2z) / mtotal,
    'time': time,
    'mtotal': mtotal,
    # Heavier over lighter mass, so the ratio is never below 1
    'mass_ratio': numpy.maximum(m1 / m2, m2 / m1),
}

# Distance-like quantities that can be plotted on the y axis, one value per
# injection. An entry is only filled in when the data it needs is available.
dvals = {}

# Detectors whose optimal SNRs are stored in the file
if 'ifos' in f.attrs:
    ifos = f.attrs['ifos'].split(' ')
else:
    logging.warning("Ifos not found in input file, assuming H-L")
    ifos = ['H1', 'L1']

# NOTE: Effective distance is hardcoded to these detectors. It's also normally
#       meaningless for precessing injections.
try:
    # Read each input once instead of once per detector. The order matches
    # the positional arguments passed to Detector.effective_distance.
    eff_dist_args = [f['injections/%s' % name][:]
                     for name in ('distance', 'ra', 'dec', 'polarization',
                                  'tc', 'inclination')]
    # One row per injection, one column per detector
    eff_dists = numpy.array(
        [Detector(ifo).effective_distance(*eff_dist_args)
         for ifo in ['H1', 'L1', 'V1']]).T

    # "Decisive" distance for coincidences is the second smallest
    # effective distance
    dvals['decisive_distance'] = numpy.sort(eff_dists)[:, 1]
    dvals['dec_chirp_distance'] = \
        pycbc.pnutils.chirp_distance(dvals['decisive_distance'],
                                     vals['mchirp'])
    # When single-detector triggers are included in the analysis,
    # minimum distance is the relevant quantity for injection efficiency
    dvals['min_eff_distance'] = numpy.min(eff_dists, axis=1)
    dvals['min_eff_chirp_distance'] = \
        pycbc.pnutils.chirp_distance(dvals['min_eff_distance'],
                                     vals['mchirp'])
except KeyError:
    # Deliberately best-effort: if the file lacks a column needed for the
    # effective distance these entries are skipped, and the other distance
    # measures can still be used.
    pass

dvals['chirp_distance'] = pycbc.pnutils.chirp_distance(dist, vals['mchirp'])
if args.distance_type == 'redshift':
    dvals['redshift'] = f['injections/redshift'][:]

if 'snr' in args.distance_type:  # only evaluate SNRs if needed
    # Squared optimal SNRs stacked with one row per ifo, so the per-injection
    # combinations below are reductions over axis 0.
    opt_snrsq = numpy.array(
        [f['injections/optimal_snr_%s' % ifo][:] ** 2. for ifo in ifos])
    dvals['comb_optimal_snr'] = numpy.sqrt(opt_snrsq.sum(axis=0))
    dvals['max_optimal_snr'] = numpy.sqrt(opt_snrsq.max(axis=0))
    if len(ifos) > 1:
        # Decisive optimal SNR is the 2nd largest optimal SNR
        dvals['decisive_optimal_snr'] = \
            numpy.sqrt(numpy.sort(opt_snrsq, axis=0)[-2])
    else:
        # A second-largest value does not exist for a single ifo. Skip it so
        # the combined and maximum SNRs stay usable.
        logging.warning("Decisive optimal SNR requires at least two ifos; "
                        "not computing it")

# The requested measure can be absent from dvals, e.g. when the injection
# file lacks the columns needed to compute it. Report that clearly instead of
# failing with a bare KeyError.
if args.distance_type not in dvals:
    parser.error("--distance-type %s is not available: it is not a distance "
                 "measure, or the injection file lacks the data needed to "
                 "compute it" % args.distance_type)

# y values of the found and of the missed injections
fdvals = dvals[args.distance_type][found]
mdvals = dvals[args.distance_type][missed]

if args.missed_on_top:
    fig_title = 'Missed and Found Injections'
else:
    fig_title = 'Found and Missed Injections'

fig = plot.figure()
# The set that should end up on top gets zorder 1, the other one zorder 0
zmissed = int(args.missed_on_top)
zfound = int(not args.missed_on_top)

mpoints = plot.scatter(vals[args.axis_type][missed], mdvals, s=16,
                       linewidth=0.5, marker='x', color='red',
                       label='missed', zorder=zmissed)

fvals = vals[args.axis_type][found]

# Order the found injections by increasing IFAR, or by decreasing IFAR when
# missed injections go on top (the --missed-on-top help text asks for high
# FAR on top of low FAR in that case).
ifsort = numpy.argsort(ifar_found)
if args.missed_on_top:
    ifsort = ifsort[::-1]

fvals = fvals[ifsort]
fdval = fdvals[ifsort]
ifsorted = ifar_found[ifsort]

if not args.gradient_far:
    # Three discrete colour codes: 1 by default, 0.5 above 100 years and
    # 0 above 1000 years.
    color = numpy.ones(len(found))
    color[ifsorted > 100] = 0.5
    color[ifsorted > 1000] = 0

    # Pin the normalisation to [0, 1]. An autoscaled Normalize() stretches
    # whichever codes are present over the whole colormap, so e.g. a plot
    # with only quiet injections would draw them in the ">= 1000 years"
    # colour and contradict the caption.
    norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
    # The colour names below presumably refer to the default 'cividis_r'
    # colormap.
    caption = (fig_title + ": Red x's are missed injections. "
              "Blue circles are found with IFAR < 100 years, gray are < "
              "1000 years, and yellow are found with IFAR >=1000 years. ")
else:
    # Colour by false alarm rate on a logarithmic scale
    color = 1.0 / ifsorted
    if len(color) < 2:
        # Fewer than two found injections: plot without a colour mapping
        color = None

    norm = matplotlib.colors.LogNorm()
    caption = (fig_title + ": Red x's are missed injections. "
               "Circles are found injections. The color indicates the value "
               "of the false alarm rate." )

points = plot.scatter(fvals, fdval, c=color, linewidth=0, s=16, norm=norm,
                      marker='o', label='found', zorder=zfound,
                      cmap=args.colormap)
# Colorbar for the false-alarm-rate gradient
if args.gradient_far:
    try:
        # Keyed by (IFAR lower limit set, IFAR upper limit set). The plotted
        # colour is 1/IFAR, so an IFAR lower limit maps to the 'max' end of
        # the colorbar and an IFAR upper limit to the 'min' end.
        extend_by_limits = {
            (True, True): 'both',
            (True, False): 'max',
            (False, True): 'min',
            (False, False): 'neither',
        }
        ext = extend_by_limits[(bool(lower_lim), bool(upper_lim))]
        cbar = plot.colorbar(extend=ext)
        cbar.set_label('False Alarm Rate $(yr^{-1})$, %s' % far_title)

        # Set up tick labels - there will be 5 if possible
        log_far = numpy.log10(color)
        min_tick = numpy.ceil(min(log_far))
        max_tick = numpy.floor(max(log_far))
        tick_step = max(1, numpy.floor((max_tick - min_tick) / 5))
        tick_exponents = numpy.arange(min_tick, max_tick, tick_step)
        cbar.set_ticks(numpy.power(10, tick_exponents))
    except (TypeError, ZeroDivisionError):
        # Can't make colorbar if no quiet found injections; only swallow the
        # error when there are no found injections at all.
        if len(fvals):
            raise

# Finish the caption with a note on which set of points is drawn on top
caption += ("Missed injections are shown on top of found injections."
            if args.missed_on_top else
            "Found injections are shown on top of missed injections.")

ax = plot.gca()
plot.xlabel(labels[args.axis_type])
plot.ylabel(labels[args.distance_type])
plot.grid()

if args.log_x:
    # log x axis may fail for some choices, eg effective spin
    ax.set_xscale('log')
    # Pad the x range around all plotted points, missed and found
    tmpxvals = [*vals[args.axis_type][missed], *vals[args.axis_type][found]]
    xmax = 1.4 * max(tmpxvals)
    xmin = 0.7 * min(tmpxvals)
    plot.xlim(xmin, xmax)

if args.log_distance:
    ax.set_yscale('log')

# Padded y range covering every plotted point
tmpyvals = [*fdvals, *mdvals]
ymax = 1.2 * max(tmpyvals)
ymin = 0.9 * min(tmpyvals)

if args.plot_all_distance:
    # note: ymin=0 will clash with args.log_distance
    # in that case it *should* throw an error!
    y_floor = ymin
elif args.distance_type == 'redshift':
    # default y limit: min redshift 0
    y_floor = 0
else:
    # arbitrary limit of 1 distance-unit or snr-unit
    y_floor = 1
plot.ylim(y_floor, ymax)

# PNG output is saved at 200 dpi; other formats get no extra options
fig_kwds = {'dpi': 200} if '.png' in args.output_file else {}

if '.html' in args.output_file:
    # HTML output goes through mpld3: attach a mouse-position readout and an
    # interactive legend for the missed and found points.
    plot.subplots_adjust(left=0.1, right=0.8, top=0.9, bottom=0.1)
    import mpld3
    import mpld3.plugins
    import mpld3.utils
    mpld3.plugins.connect(fig, mpld3.plugins.MousePosition(fmt='.5g'))
    legend = mpld3.plugins.InteractiveLegendPlugin(
        [mpoints, points], ['missed', 'found'], alpha_unsel=0.1)
    mpld3.plugins.connect(fig, legend)

# Write the figure, passing along the title, caption and the command line
# that produced it
title = '%s: %s vs %s' % (fig_title, args.axis_type, args.distance_type)
cmd = ' '.join(sys.argv)
pycbc.results.save_fig_with_metadata(
    fig, args.output_file,
    fig_kwds=fig_kwds, title=title, cmd=cmd, caption=caption)
