#!/usr/bin/python3.13 -s
import logging
import numpy
import h5py
import argparse
import matplotlib
import sys
matplotlib.use('Agg')
import pylab

import pycbc.results
from pycbc.events import veto
from pycbc.io import (
    get_chisq_from_file_choice, chisq_choices, SingleDetTriggers, HFile
)

# Command-line interface. The common PyCBC options are added by
# pycbc.add_common_pycbc_options (presumably including the verbosity flag
# that is read as args.verbose further down).
parser = argparse.ArgumentParser()
pycbc.add_common_pycbc_options(parser)
# The trigger file is always opened and the figure is always saved, so both
# paths are mandatory. Marking them required makes a missing option a clear
# usage error rather than a failure inside the file-handling code.
parser.add_argument('--trigger-file', required=True,
                    help='Single ifo trigger file')
parser.add_argument('--veto-file',
                    help='Optional, file of veto segments to remove triggers')
parser.add_argument('--segment-name', default=None, type=str,
                    help='Optional, name of segment list to use for vetoes')
parser.add_argument('--min-snr', type=float,
                    help='Optional, Minimum SNR to plot')
parser.add_argument('--output-file', required=True,
                    help='Path of the plot file to write')
# Contour values are kept as the strings given on the command line so that
# the labels drawn on the plot show exactly what the user typed.
parser.add_argument(
    '--newsnr-contours',
    nargs='*',
    help="List of newsnr values to draw contours at. Only valid for "
         "--chisq-choice traditional (which is default)"
)
parser.add_argument('--chisq-choice', choices=chisq_choices,
                    default='traditional',
                    help='Which chisquared to plot. Default=traditional')
args = parser.parse_args()

# Configure logging from the parsed verbosity option.
pycbc.init_logging(args.verbose, default_level=1)


# Normalise the contour option: no option given means "draw no contours".
# If contours were requested they are only meaningful for the traditional
# chisq, so any other choice is rejected as a usage error.
if args.newsnr_contours is None:
    args.newsnr_contours = []
elif args.chisq_choice != 'traditional':
    parser.error(
        "Newsnr contours are being plotted assuming a traditional "
        "chisquared, but this is not being plotted! We are plotting %s"
        % args.chisq_choice
    )

# Identify the detector: the first top-level key of the trigger file is
# taken as the IFO name. The size of its 'snr' dataset gives the trigger
# count before any filtering.
with HFile(args.trigger_file, 'r') as trigger_h5:
    ifo = list(trigger_h5.keys())[0]
    n_triggers = trigger_h5[ifo]['snr'].size

logging.info("Plotting SNR vs %s chisq for %s triggers",
             args.chisq_choice, ifo)

# An SNR filter is only applied when a threshold was given on the command
# line; otherwise every trigger is read.
filter_rank = 'snr' if args.min_snr is not None else None

if filter_rank is None and n_triggers > 1e7:
    logging.warning(
        "Getting %d triggers with no SNR cut - this will be expensive!",
        n_triggers
    )

trigs = SingleDetTriggers(
    args.trigger_file,
    ifo,
    filter_rank=filter_rank,
    filter_threshold=args.min_snr,
    veto_file=args.veto_file,
    segment_name=args.segment_name,
)

logging.info("Reading SNR and %s chisq for %d triggers",
             args.chisq_choice, trigs.mask_size)
snr = trigs.snr
chisq = get_chisq_from_file_choice(trigs, args.chisq_choice)


def snr_from_chisq(chisq, newsnr, q=6.):
    """Return the SNR along a curve of constant ``newsnr``.

    For every entry of ``chisq`` that is greater than 1 the result is
    ``newsnr * (0.5 * (1 + chisq ** (q / 2))) ** (1 / q)``; for all other
    entries the result is ``newsnr`` itself.

    Parameters
    ----------
    chisq : numpy.ndarray
        One-dimensional array of chisq values.
    newsnr : float or str
        The constant newsnr value; anything accepted by ``float()``.
    q : float, optional
        Exponent used in the reweighting. Default 6.

    Returns
    -------
    numpy.ndarray
        Float array with the same length as ``chisq``.
    """
    target = float(newsnr)
    # Start from the unweighted value everywhere ...
    contour = numpy.full(len(chisq), target)
    # ... and only rescale the entries whose chisq exceeds one.
    above_one = numpy.nonzero(chisq > 1.)[0]
    reweight = (0.5 * (1. + chisq[above_one] ** (q / 2.))) ** (-1. / q)
    contour[above_one] = target / reweight
    return contour

# Draw the SNR vs chisq trigger density, with optional constant-newsnr
# contours on top, and save the figure together with its metadata.
fig = pylab.figure(1)

# Chisq values at which the contours are evaluated. numpy.logspace expects
# base-10 exponents, so the data limits are converted with log10; a natural
# log here would sample a range that does not match the data.
r = numpy.logspace(numpy.log10(chisq.min()), numpy.log10(chisq.max()), 300)

for i, cval in enumerate(args.newsnr_contours):
    logging.info("Plotting newsnr %s contour", cval)
    snrv = snr_from_chisq(r, cval)
    pylab.plot(snrv, r, color='black', lw=0.5)
    # Only the first contour label carries the rho-hat symbol; the others
    # show the bare value.
    if i == 0:
        label = "$\\hat{\\rho} = %s$" % cval
    else:
        label = "$%s$" % cval
    # Put the label at the first contour point beyond 80% of the loudest
    # SNR; if the contour never gets that far, use its first point.
    try:
        label_pos_idx = numpy.where(snrv > snr.max() * 0.8)[0][0]
    except IndexError:
        label_pos_idx = 0
    pylab.text(snrv[label_pos_idx], r[label_pos_idx], label, fontsize=6,
               horizontalalignment='center', verticalalignment='center',
               bbox=dict(facecolor='white', lw=0, pad=0, alpha=0.9))

# Log-log density of triggers; bins with no triggers are left blank.
pylab.hexbin(snr, chisq, gridsize=300, xscale='log', yscale='log', lw=0.04,
             mincnt=1, norm=matplotlib.colors.LogNorm())

ax = pylab.gca()
pylab.grid()
ax.set_xscale('log')
cb = pylab.colorbar()
# Clip the axes to the data so contour segments outside it are not shown.
pylab.xlim(snr.min(), snr.max() * 1.1)
pylab.ylim(chisq.min(), chisq.max() * 1.1)
cb.set_label('Trigger Density')
pylab.xlabel('Signal-to-Noise Ratio')
pylab.ylabel('Reduced $\\chi^2$')
pycbc.results.save_fig_with_metadata(
    fig,
    args.output_file,
    title="%s :SNR vs Reduced %s &chi;<sup>2</sup>" % (ifo, args.chisq_choice),
    caption="Distribution of SNR and %s &chi;&sup2; for single detector triggers: "
            "Black lines show contours of constant NewSNR."
            % (args.chisq_choice,),
    cmd=' '.join(sys.argv),
    fig_kwds={'dpi': 300},
)
