#!/usr/bin/python3.11

"""Make table of the foreground events. Also includes ability to write FARs and
p-values for intermediary hierarchical removal steps.
"""

import sys
import argparse
import logging
import numpy
import h5py
import pycbc
import pycbc.results
import pycbc.version
from pycbc.io import hdf
from pycbc.pnutils import mass1_mass2_to_mchirp_eta


# Help text for --use-hierarchical-level, kept separate so the option
# declarations below stay readable.
HIERARCHICAL_LEVEL_HELP = (
    "Indicate which FARs to write to the table based on the number of "
    "hierarchical removals done. "
    "Choosing None defaults to giving the FARs after all hierarchical "
    "removals were done depending on previous configuration choices. "
    "Choosing 0 selects writing the FARs prior to any hierarchical removals. "
    "Choosing 1 means writing the FARs after doing 1 hierarchical removal. "
    "The program will fail if the user selects a number above the number of "
    "hierarchical removals done. [default=None]"
)

# Command-line interface: the module docstring doubles as the description.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--version",
    action="version",
    version=pycbc.version.git_verbose_msg,
)
# Required inputs: the coincident trigger file and the template bank.
parser.add_argument("--trigger-file", required=True)
parser.add_argument("--bank-file", required=True)
# Optional single-detector trigger files (one or more paths).
parser.add_argument("--single-detector-triggers", nargs="+")
# Each repetition of --verbose increments the count passed to logging setup.
parser.add_argument("--verbose", action="count")
# The output file's extension selects the format written later on.
parser.add_argument("--output-file", required=True)
# Passed through as the n_loudest limit when loading foreground triggers.
parser.add_argument("--num-to-write", type=int)
parser.add_argument(
    "--use-hierarchical-level",
    type=int,
    help=HIERARCHICAL_LEVEL_HELP,
)
args = parser.parse_args()

# Configure logging from the --verbose count given on the command line.
pycbc.init_logging(args.verbose)

# Read how many hierarchical removal iterations the trigger file records.
# The attribute lookup falls back to None when the file does not carry
# a 'hierarchical_removal_iterations' attribute.
with h5py.File(args.trigger_file, 'r') as trigger_hdf:
    h_iterations = trigger_hdf.attrs.get('hierarchical_removal_iterations',
                                         None)

# Hierarchical removal level requested on the command line; None when the
# option was not supplied.
h_num_rm = args.use_hierarchical_level

# A level was requested, but the trigger file either records no hierarchical
# removal iterations at all or records fewer than the requested number.
level_unavailable = h_num_rm is not None and (
    h_iterations is None or h_num_rm > h_iterations
)

if level_unavailable:
    # Write a one-row table of requested vs. performed removals in place of
    # the event table, then exit successfully.
    summary_table = pycbc.results.html_table(
        [numpy.array([h_num_rm]), numpy.array([h_iterations])],
        ["Hierarchical Removals Requested", "Hierarchical Removals Performed"],
        format_strings=['#', '#'],
        page_size=10,
    )
    pycbc.results.save_fig_with_metadata(
        str(summary_table),
        args.output_file,
        title='No more events louder than all background at this removal level.',
        cmd=' '.join(sys.argv),
    )
    sys.exit(0)

# XML output cannot be produced without the single-detector trigger lists,
# so reject that combination before any triggers are loaded.
writing_xml = args.output_file.endswith(('.xml', '.xml.gz'))
if writing_xml and args.single_detector_triggers is None:
    raise ValueError(
        "If creating xml files must provide the single detector "
        "trigger lists with --single-detector-triggers."
    )

# Keyword arguments shared by every ForegroundTriggers load below.
trigger_load_kwargs = {
    'sngl_files': args.single_detector_triggers,
    'n_loudest': args.num_to_write,
}

# Only read a removal-stage group when a non-negative level was requested
# and the trigger file reports a non-zero number of hierarchical removal
# iterations. Anything else falls back to the regular behavior.
read_removal_stage = (
    h_num_rm is not None
    and h_iterations is not None
    and h_num_rm >= 0
    and h_iterations != 0
)

if not read_removal_stage:
    fortrigs = hdf.ForegroundTriggers(args.trigger_file, args.bank_file,
                                      **trigger_load_kwargs)
else:
    # Triggers from the group corresponding to the requested removal stage.
    fortrigs = hdf.ForegroundTriggers(args.trigger_file, args.bank_file,
                                      group='foreground_h%s' % h_num_rm,
                                      **trigger_load_kwargs)
    # Triggers from the default group, used later for the exclusive ifar,
    # which is independent of any removal.
    fortrigs_exclusive = hdf.ForegroundTriggers(args.trigger_file,
                                                args.bank_file,
                                                **trigger_load_kwargs)

# Write the foreground events in the format selected by the output file's
# extension: an HTML table, a coinc XML document, or an HDF file.
if args.output_file.endswith('.html'):
    # Per-event columns, all read through `fortrigs` so they share one row
    # ordering.
    ifar = fortrigs.get_coincfile_array('ifar')
    fap = fortrigs.get_coincfile_array('fap')
    stat = fortrigs.get_coincfile_array('stat')
    mass1 = fortrigs.get_bankfile_array('mass1')
    mass2 = fortrigs.get_bankfile_array('mass2')
    spin1z = fortrigs.get_bankfile_array('spin1z')
    spin2z = fortrigs.get_bankfile_array('spin2z')
    time = fortrigs.get_end_time()

    # Same test that decides whether `fortrigs_exclusive` was loaded.
    if h_num_rm is not None and h_iterations is not None \
            and h_num_rm >= 0 and h_iterations != 0:
        # FIXME: Currently hierarchical removal is not in the multi-detector
        #        workflow. If this functionality is added, and we want to make
        #        this table for hierarchically removed triggers, we would need
        #        to change the code below to support that.

        # This information is not stored at all levels of the trigger file
        # so grab from the top level and populate the variables for output later

        ifar_exc_orig = fortrigs_exclusive.get_coincfile_array('ifar_exc')
        fap_exc_orig = fortrigs_exclusive.get_coincfile_array('fap_exc')

        trig_id1 = fortrigs_exclusive.get_coincfile_array('trigger_id1')
        trig_id1_hrm = fortrigs.get_coincfile_array('trigger_id1')

        # Index the top-level events by their 'trigger_id1' value. If an id
        # repeats, the first occurrence is kept.
        top_level_row = {}
        for row, trig_id in enumerate(trig_id1):
            top_level_row.setdefault(int(trig_id), row)

        # Fill the exclusive values row-by-row in the order of the
        # removal-level events, so they stay aligned with every other column
        # of the table. Events with no top-level match keep the value 0.
        # NOTE(review): both trigger sets are loaded with the same n_loudest
        # limit, so a removal-level event may be missing from the top-level
        # selection -- confirm whether that can happen in practice.
        ifar_exc = numpy.zeros(len(ifar), dtype=numpy.float64)
        fap_exc = numpy.zeros(len(ifar), dtype=numpy.float64)
        for row, trig_id in enumerate(trig_id1_hrm):
            source_row = top_level_row.get(int(trig_id))
            if source_row is not None:
                ifar_exc[row] = ifar_exc_orig[source_row]
                fap_exc[row] = fap_exc_orig[source_row]

    else:
        ifar_exc = fortrigs.get_coincfile_array('ifar_exc')
        fap_exc = fortrigs.get_coincfile_array('fap_exc')

    # Only the chirp mass is tabulated; the second return value is unused.
    mchirp, _ = mass1_mass2_to_mchirp_eta(mass1, mass2)

    columns = [ifar_exc, ifar, fap_exc, fap, stat, time,
               mchirp, mass1, mass2, spin1z, spin2z]

    # Not supposed to use FAP anymore, nomenclature is p-value. Should be
    # fixed consistently across PyCBC.
    names = ['Exc. IFAR (YR)', 'Inc. IFAR (YR)', 'Exc. FAP',
             'Inc. FAP', 'Ranking Statistic', 'End Time',
             'mchirp', 'm1', 'm2', 's1z', 's2z']
    format_strings = ['#.###E0', '#.###E0', '#.##E0', '#.##E0', '##.###',
                      None, '##.##', '##.##', '##.##', '##.##', '##.##']

    if args.single_detector_triggers:
        # Each dictionary entry is indexed as [0] for the values and [1] for
        # a per-event flag; a false flag leaves that detector's cells blank.
        single_snr = fortrigs.get_snglfile_array_dict('snr')
        single_chisq = fortrigs.get_snglfile_array_dict('chisq')
        single_chisq_dof = fortrigs.get_snglfile_array_dict('chisq_dof')

        for ifo in sorted(single_snr):
            snrs = []
            chisqs = []
            for idx in range(len(single_snr[ifo][0])):
                if single_snr[ifo][1][idx]:
                    snrs.append('%.2f' % single_snr[ifo][0][idx])
                    # Reduced chi-squared: chisq / (2 * chisq_dof - 2).
                    rchisq = single_chisq[ifo][0][idx] / \
                        (single_chisq_dof[ifo][0][idx] * 2 - 2)
                    chisqs.append('%.2f' % rchisq)
                else:
                    snrs.append(' ')
                    chisqs.append(' ')
            columns.append(numpy.array(snrs))
            names.append("%s SNR" % (ifo))
            format_strings.append("##.##")
            columns.append(numpy.array(chisqs))
            names.append("%s Red. Chisq" % (ifo))
            format_strings.append("##.##")

    logging.info('Making table of foreground triggers')
    html_table = pycbc.results.html_table(columns, names,
                                          format_strings=format_strings,
                                          page_size=10)

    kwds = {
        'title': 'Loudest Event Table',
        'cmd': ' '.join(sys.argv),
    }
    pycbc.results.save_fig_with_metadata(
        str(html_table),
        args.output_file,
        **kwds
    )

elif args.output_file.endswith(('.xml', '.xml.gz')):
    fortrigs.to_coinc_xml_object(args.output_file)

elif args.output_file.endswith('.hdf'):
    fortrigs.to_coinc_hdf_object(args.output_file)

else:
    raise NotImplementedError("Output file format not recognised")
