#!/usr/bin/python3.13 -s
""" Make a table of found injection information
"""
import argparse
import numpy as np
import sys
from itertools import combinations

import pycbc.results
import pycbc.detector
import pycbc.pnutils
import pycbc.events
from pycbc.io.hdf import HFile
from pycbc import add_common_pycbc_options, init_logging
from pycbc.types import MultiDetOptionAction


# Command-line interface. Both file options are used unconditionally further
# down the script, so they are declared required to fail early with a clear
# argparse message rather than with an error from deep inside the I/O code.
parser = argparse.ArgumentParser(description=__doc__)
add_common_pycbc_options(parser)
parser.add_argument('--injection-file', required=True,
                    help='HDF File containing the matched injections')
parser.add_argument('--single-trigger-files', nargs='*',
                    action=MultiDetOptionAction,
                    help="HDF format single detector trigger files")
parser.add_argument('--show-missed', action='store_true',
                    help='Tabulate the missed injections instead of the '
                         'found ones')
parser.add_argument('--output-file', required=True,
                    help='File to write the injection table to')
args = parser.parse_args()

init_logging(args.verbose)

def _blank_nan(values):
    """Format each value to two decimals, using a blank cell for NaN."""
    return ['%.2f' % v if not np.isnan(v) else ' ' for v in values]


# Injection file handle. It keeps its own name for the whole script because
# `inj` (and `found` below) are groups inside this file that are read later.
f = HFile(args.injection_file, 'r')
inj = f['injections']
found_cols, found_names, found_formats = [], [], []

ifos = f.attrs['ifos'].split(' ')

if args.show_missed:
    title = "Missed Injections"
    idx = f['missed/after_vetoes'][:]
else:
    title = "Found Injections"
    found = f['found_after_vetoes']
    idx = found['injection_index'][:]

    # Only detectors that have their own group under found_after_vetoes
    detectors_used = [det for det in ifos if det in found]

    # One time-difference column (in ms) per pair of those detectors
    tdiff, tdiff_str, tdiff_format = [], [], []
    for det_1, det_2 in combinations(detectors_used, 2):
        time_1 = np.array(found[det_1 + '/time'][:])
        time_2 = np.array(found[det_2 + '/time'][:])
        tdiff_vals = (time_1 - time_2) * 1000
        # A negative time in either detector is shown as a blank cell
        tdiff_vals[np.logical_or(time_1 < 0, time_2 < 0)] = np.nan
        tdiff.append(_blank_nan(tdiff_vals))
        tdiff_str.append('%s - %s time (ms)' % (det_1, det_2))
        tdiff_format.append('##.##')

    ids = {det: found[det + '/trigger_id'][:] for det in detectors_used}

    found_cols = [found['stat'][:], found['ifar_exc'][:]] + tdiff
    found_names = ['Ranking Stat.', 'Exc. IFAR'] + tdiff_str
    found_formats = ['##.##', '##.##'] + tdiff_format

    if args.single_trigger_files:
        for ifo in args.single_trigger_files:
            ids_ifo = np.array(ids[ifo])
            # A trigger id of -1 is treated as "no trigger"; it would
            # otherwise silently index the last trigger in the file.
            no_trigger = ids_ifo == -1

            # Use a dedicated handle (not `f`) and close it once the needed
            # values have been copied into memory.
            with HFile(args.single_trigger_files[ifo], 'r') as trig_file:
                trigs = trig_file[ifo]
                snr_vals = trigs['snr'][:][ids_ifo]
                chisq_raw = trigs['chisq'][:][ids_ifo]
                chisq_dof = trigs['chisq_dof'][:][ids_ifo]

            snr_vals[no_trigger] = np.nan
            # presumably the reduced chi-squared expected by newsnr -- confirm
            chisq_vals = chisq_raw / (2 * chisq_dof - 2)
            chisq_vals[no_trigger] = np.nan
            newsnr_vals = pycbc.events.ranking.newsnr(snr_vals, chisq_vals)

            found_names += [ifo + " SNR", ifo + " CHISQ", ifo + " NewSNR"]
            found_cols += [_blank_nan(snr_vals), _blank_nan(chisq_vals),
                           _blank_nan(newsnr_vals)]
            found_formats += ['##.##', '##.##', '##.##']

# Effective distance of every selected injection in each detector.
# The detector names from the file are used as-is: rebuilding them from the
# first letter plus '1' would merge detectors sharing an initial and pick the
# wrong detector for any name that does not end in '1'.

# These reads do not depend on the detector, so do them once. The order
# matches the positional arguments passed to Detector.effective_distance.
inj_params = [inj[name][:][idx] for name in
              ('distance', 'ra', 'dec', 'polarization', 'tc', 'inclination')]

eff_dist_str = []
eff_distance = []
eff_dist_format = []
for ifo_name in ifos:
    detector = pycbc.detector.Detector(ifo_name)
    eff_distance.append(detector.effective_distance(*inj_params))
    eff_dist_str.append('Eff Dist (%s)' % ifo_name)
    eff_dist_format.append('##.##')

# Masses of the selected injections and the derived mass parameters
m1 = inj['mass1'][:][idx]
m2 = inj['mass2'][:][idx]
mchirp, eta = pycbc.pnutils.mass1_mass2_to_mchirp_eta(m1, m2)

# Largest effective distance over the detectors, expressed as a chirp distance
dec_dist = np.max(eff_distance, axis=0)
dec_chirp_dist = pycbc.pnutils.chirp_distance(dec_dist, mchirp)

# Spin components, in the same order as their column headers below
spin_keys = ('spin1x', 'spin1y', 'spin1z', 'spin2x', 'spin2y', 'spin2z')
spins = [inj[key][:][idx] for key in spin_keys]

# Fixed injection columns first, then per-detector effective distances, then
# whatever found-injection columns were collected earlier in the script.
base_columns = [dec_chirp_dist, inj['tc'][:][idx], m1, m2, mchirp, eta]
base_columns += spins
base_columns.append(inj['distance'][:][idx])
columns = base_columns + eff_distance + found_cols

base_names = ['DChirp Dist', 'Inj Time', 'Mass1', 'Mass2', 'Mchirp', 'Eta',
              's1x', 's1y', 's1z',
              's2x', 's2y', 's2z',
              'Dist']
names = base_names + eff_dist_str + found_names

# Every fixed column shares the same two-decimal format
format_strings = ['##.##'] * len(base_names) + eff_dist_format + found_formats

columns = [np.array(col) for col in columns]
html_table = pycbc.results.html_table(columns, names,
                                      format_strings=format_strings,
                                      page_size=20)

caption = "A table of %s and their coincident statistic information." % title.lower()
pycbc.results.save_fig_with_metadata(str(html_table), args.output_file,
                                     title=title,
                                     caption=caption,
                                     cmd=' '.join(sys.argv))
