#!/usr/bin/python3.13 -s

import sys
import numpy, argparse, matplotlib
from matplotlib import colors
matplotlib.use('Agg')
import pylab, pycbc.results
from pycbc.io import (
    get_chisq_from_file_choice, chisq_choices, SingleDetTriggers, HFile
)
from pycbc import conversions, init_logging, add_common_pycbc_options
from pycbc.detector import Detector

def snr_from_chisq(chisq, newsnr, q=6.):
    """Return the SNR on a constant-newsnr contour for each chisq value.

    Parameters
    ----------
    chisq : numpy.ndarray
        One-dimensional array of chi-squared values.
    newsnr : float or str
        Value of the contour; anything accepted by ``float()``.
    q : float, optional
        Exponent used in the reweighting. Default 6.

    Returns
    -------
    numpy.ndarray
        Float array with one entry per element of ``chisq``. Entries where
        ``chisq <= 1`` are equal to ``newsnr``; the others are ``newsnr``
        divided by ``(0.5 * (1 + chisq ** (q / 2))) ** (-1 / q)``.
    """
    target = float(newsnr)
    contour = numpy.full(len(chisq), target)
    # Only chisq values above 1 are reweighted
    above = numpy.nonzero(chisq > 1.)[0]
    reweight = (0.5 * (1. + chisq[above] ** (q / 2.))) ** (-1. / q)
    contour[above] = target / reweight
    return contour

# Command line interface. The common PyCBC options (which provide
# ``args.verbose``, used for logging below) are added by
# add_common_pycbc_options.
parser = argparse.ArgumentParser(
    description='Plot chi-squared against SNR for the single detector '
                'background triggers and found injection triggers of one '
                'detector.')
add_common_pycbc_options(parser)
parser.add_argument('--found-injection-file', required=True,
                    help='HDF format found injection file. Required')
parser.add_argument('--single-injection-file', required=True,
                    help='Single detector trigger files from the injection set'
                         ': one file per ifo')
parser.add_argument('--coinc-statistic-file', required=True,
                    help='HDF format statistic file. Required')
parser.add_argument('--single-trigger-file', required=True,
                    help='Single detector trigger files from the zero lag run.'
                         ' Required')
# type=float makes argparse reject a non-numeric contour value immediately,
# instead of failing only after all of the trigger data has been loaded.
parser.add_argument('--newsnr-contours', nargs='*', default=[], type=float,
                    help="List of newsnr values to draw contours. Optional")
parser.add_argument('--background-front', action='store_true', default=False,
                    help='If set, plot background on top of injections rather '
                         'than vice versa')
parser.add_argument('--colorbar-choice', choices=('effective_spin', 'mchirp',
                    'eta', 'effective_distance', 'mtotal', 'optimal_snr',
                    'redshift'), default='effective_distance',
                    help='Parameter to use for the colorbar. '
                         'Default=effective_distance')
parser.add_argument('--chisq-choice', choices=chisq_choices,
                    default='traditional',
                    help='Which chisquared to plot. Default=traditional')
parser.add_argument('--output-file', required=True,
                    help='Path of the output plot file. Required')
args = parser.parse_args()

init_logging(args.verbose)

# The first top-level key of the zero lag single detector trigger file is
# taken as the ifo to plot.
with HFile(args.single_trigger_file, 'r') as stf:
    ifo = tuple(stf.keys())[0]

# Read the background trigger ids for this ifo and the list of ifos that
# took part in the coincident analysis.
with HFile(args.coinc_statistic_file, 'r') as csf:
    all_bkg_tids = csf['background_exc'][ifo]['trigger_id'][:]
    ifos = csf.attrs['ifos'].split(' ')

# Trigger ids == -1 indicate that the event was found in other detector(s),
# so drop them
bkg_tids = all_bkg_tids[all_bkg_tids >= 0]

bkg_trigs = SingleDetTriggers(args.single_trigger_file, ifo,
                              premask=bkg_tids)

# SingleDetTriggers will have discarded the ordering of the tids, and any
# duplicates. numpy.unique() does the same, so its return_inverse array
# maps the loaded values back onto the original list of trigger ids.
_, restore_order = numpy.unique(bkg_tids, return_inverse=True)
bkg_snr = bkg_trigs.snr[restore_order]
bkg_chisq = get_chisq_from_file_choice(bkg_trigs,
                                       args.chisq_choice)[restore_order]

# Triggers for which the chisq was not calculated are not plotted
has_chisq = bkg_chisq > 0
bkg_snr, bkg_chisq = bkg_snr[has_chisq], bkg_chisq[has_chisq]

fig = pylab.figure()
pylab.scatter(
    bkg_snr,
    bkg_chisq,
    s=4,
    marker='o',
    color='black',
    linewidth=0,
    alpha=0.6,
    label='Background',
    # Drawn above the injections only when --background-front is given
    zorder=args.background_front,
)

# Add the found injection points. Every dataset needed from the found
# injection file is copied into memory inside the ``with`` block, so the
# file handle is closed as soon as the reads are done.
with HFile(args.found_injection_file, 'r') as f:
    inj_tid = f['found_after_vetoes'][ifo]['trigger_id'][:]
    # Remove trigger ids == -1, as this indicates that it was found in
    # other detector(s)
    valid_inj = inj_tid >= 0
    inj_tid = inj_tid[valid_inj]

    # Effective distance in this ifo for every injection in the file
    eff_dists = Detector(ifo).effective_distance(
        f['injections/distance'][:],
        f['injections/ra'][:],
        f['injections/dec'][:],
        f['injections/polarization'][:],
        f['injections/tc'][:],
        f['injections/inclination'][:],
    )

    # Position of each remaining found injection in the injections/* arrays
    inj_idx = f['found_after_vetoes/injection_index'][valid_inj]
    eff_dist = eff_dists[inj_idx]
    m1 = f['injections/mass1'][:][inj_idx]
    m2 = f['injections/mass2'][:][inj_idx]
    s1 = f['injections/spin1z'][:][inj_idx]
    s2 = f['injections/spin2z'][:][inj_idx]

    # The redshift dataset is only read when it is the requested coloring
    if args.colorbar_choice == 'redshift':
        redshift = f['injections/redshift'][:][inj_idx]
    else:
        redshift = None

    # The optimal SNR for this ifo is only offered if the file contains it
    if 'optimal_snr_{}'.format(ifo) in f['injections']:
        opt_snr_str = 'injections/optimal_snr_{}'.format(ifo)
        opt_snr = f[opt_snr_str][:][inj_idx]
    else:
        opt_snr = None

mchirp = conversions.mchirp_from_mass1_mass2(m1, m2)
eta = conversions.eta_from_mass1_mass2(m1, m2)
weighted_spin = conversions.chi_eff(m1, m2, s1, s2)

# choices to color the found injections: (values, colorbar label, norm)
coloring = {'effective_distance': (eff_dist, "Effective Distance (Mpc)",
                                   colors.LogNorm()),
            'mchirp': (mchirp, "Chirp Mass", colors.LogNorm()),
            'eta': (eta, "Symmetric Mass Ratio", colors.LogNorm()),
            'effective_spin': (weighted_spin, "Weighted Aligned Spin", None),
            'mtotal': (m1 + m2, "Total Mass", colors.LogNorm()),
            'redshift': (redshift, "Redshift", None)
            }

if opt_snr is not None:
    coloring['optimal_snr'] = (opt_snr, 'Optimal SNR', colors.LogNorm())

# Report an unavailable coloring clearly rather than failing later with a
# bare KeyError when the coloring is looked up.
if args.colorbar_choice not in coloring:
    parser.error(
        '--colorbar-choice {} cannot be used: the required dataset is not '
        'present in {}'.format(args.colorbar_choice,
                               args.found_injection_file)
    )
# Load the SNR and chisq of the single detector triggers associated with
# the found injections; empty arrays stand in when nothing was found.
if len(inj_tid) == 0:
    inj_snr = numpy.array([])
    inj_chisq = numpy.array([])
else:
    found_trigs = SingleDetTriggers(args.single_injection_file, ifo,
                                    premask=inj_tid)

    # As for the background, SingleDetTriggers discards the order and
    # duplicates of the tids; the inverse index from numpy.unique maps the
    # loaded values back onto the original list.
    _, reorder = numpy.unique(inj_tid, return_inverse=True)
    inj_snr = found_trigs.snr[reorder]
    inj_chisq = get_chisq_from_file_choice(found_trigs,
                                           args.chisq_choice)[reorder]

inj_pos = inj_chisq > 0

color_values, _label, color_norm = coloring[args.colorbar_choice]
if len(color_values) > 0:
    # Only plot positive chisq
    pylab.scatter(
        inj_snr[inj_pos],
        inj_chisq[inj_pos],
        c=color_values[inj_pos],
        norm=color_norm,
        s=20,
        marker='^',
        linewidth=0,
        label="Injections",
        zorder=(not args.background_front),
    )
else:
    # Catch not enough found injections case
    coloring[args.colorbar_choice] = (None, None, None)

# Chi-squared values along which the newsnr contours are traced: 200
# log-spaced points from 0.9x the smallest to 1.1x the largest plotted
# chisq. Injection chisq values only contribute where they are positive, so
# the background alone sets the range when there are no injection triggers.
plotted_chisq = numpy.concatenate([bkg_chisq, inj_chisq[inj_pos]])
if plotted_chisq.size:
    # numpy.logspace takes base-10 exponents, hence log10 and not log
    r = numpy.logspace(numpy.log10(plotted_chisq.min() * 0.9),
                       numpy.log10(plotted_chisq.max() * 1.1), 200)
else:
    # Nothing was plotted, so there is no range to draw contours over
    r = numpy.array([])

for cval in args.newsnr_contours:
    snrv = snr_from_chisq(r, cval)
    pylab.plot(snrv, r, '--', color='grey', linewidth=1)

# Both SNR and chisq are shown on logarithmic axes
ax = pylab.gca()
ax.set_xscale('log')
ax.set_yscale('log')

try:
    cb = pylab.colorbar()
    cb.set_label(coloring[args.colorbar_choice][1], size='large')
except (TypeError, ZeroDivisionError):
    # Catch case of no injection triggers; with injection triggers present
    # the failure is unexpected and is re-raised
    if len(inj_chisq):
        raise

pylab.title('%s Coincident Triggers' % ifo, size='large')
pylab.xlabel('SNR', size='large')
# Raw string: '\c' is not a valid escape sequence in a normal string literal
pylab.ylabel(r'Reduced $\chi^2$', size='large')
try:
    pylab.xlim(min(inj_snr.min(), bkg_snr.min()) * 0.99,
               max(inj_snr.max(), bkg_snr.max()) * 1.4)
    pylab.ylim(min(bkg_chisq.min(), inj_chisq[inj_pos].min()) * 0.7,
               max(bkg_chisq.max(), inj_chisq.max()) * 1.4)
except ValueError:
    # Raised by min()/max() of an empty array if there are no injection
    # triggers; the automatically chosen axis limits are kept in that case
    pass
pylab.legend(loc='lower right', prop={'size': 10})
pylab.grid(which='major', ls='solid', alpha=0.7, linewidth=.5)
pylab.grid(which='minor', ls='solid', alpha=0.7, linewidth=.1)

# Metadata stored alongside the figure: a title, a caption and the command
# line that produced it.
ifo_combo = ''.join(ifos).upper()
if args.background_front:
    inj_layer = 'behind'
else:
    inj_layer = 'ontop'

title_template = '%s %s chisq vs SNR. %s background with injections %s'
title = title_template % (ifo.upper(), args.chisq_choice, ifo_combo,
                          inj_layer)

caption_template = """Distribution of SNR and %s chi-squared veto for single detector
triggers. Black points are %s background triggers. Triangles are injection
triggers colored by %s of the injection. Dashed lines show contours of
constant NewSNR."""
colorbar_label = coloring[args.colorbar_choice][1]
caption = caption_template % (args.chisq_choice, ifo_combo, colorbar_label)

pycbc.results.save_fig_with_metadata(
    fig,
    args.output_file,
    title=title,
    caption=caption,
    cmd=' '.join(sys.argv),
    fig_kwds={'dpi': 200},
)
