#!/usr/bin/python3.13 -s
# Copyright (C) 2020 Josh Willis
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""
Detailed comparison of a specified injection set between two PyCBC runs
"""

import logging
from glob import glob
from os import path
import argparse
import numpy as np
from h5py import File

from pycbc.events import ranking
import pycbc


def parse_injection_path(injname, basedir):
    """Locate the files describing one injection set inside a run directory.

    Parameters
    ----------
    injname : str
        Label of the injection set; the sub-directory searched is
        ``<injname>_INJ_coinc``.
    basedir : str
        Top-level directory of the run. Environment variables and ``~`` are
        expanded and the path is normalised before use.

    Returns
    -------
    tuple
        ``(injection_directory, found_injections_file, bank_file)``.

    Raises
    ------
    RuntimeError
        If the run directory, its ``bank`` sub-directory or the injection
        sub-directory is missing or is not a directory, or if there is not
        exactly one matching bank file or found-injections file.
    """
    rundir = path.normpath(path.expanduser(path.expandvars(basedir)))
    if not path.exists(rundir):
        raise RuntimeError("Directory {0} does not exist".format(basedir))
    if not path.isdir(rundir):
        raise RuntimeError("Path {0} does not specify a"
                           " directory".format(basedir))

    # The template bank lives in a fixed 'bank' sub-directory of the run
    bankdir = path.join(rundir, "bank")
    if not path.exists(bankdir):
        raise RuntimeError("There is no bank sub-directory of"
                           " {0}".format(basedir))
    if not path.isdir(bankdir):
        raise RuntimeError("Path {0} does not specify a"
                           " directory".format(bankdir))
    # Matching on the file name alone is not a very robust way to find the
    # bank, but it is all that is done here
    bank_candidates = glob(bankdir + "/*-BANK2HDF-*.hdf")
    if len(bank_candidates) != 1:
        raise RuntimeError("There is not exactly one complete HDF bank file in"
                           " the path {0}".format(bankdir))

    # Each injection set has its own '<label>_INJ_coinc' sub-directory
    injdir = path.join(rundir, "{0}_INJ_coinc".format(injname))
    if not path.exists(injdir):
        raise RuntimeError("There is no sub-directory {0}_INJ_coinc in"
                           " {1}".format(injname, basedir))
    if not path.isdir(injdir):
        raise RuntimeError("Path {0}/{1}_INJ_coinc is not a"
                           " directory".format(basedir, injname))
    inj_candidates = glob(injdir + "/*HDFINJFIND*")
    if not inj_candidates:
        raise RuntimeError("No found-injections file in directory"
                           " {0}".format(injdir))
    if len(inj_candidates) > 1:
        raise RuntimeError("More than one found-injections file in directory"
                           " {0}".format(injdir))

    return injdir, inj_candidates[0], bank_candidates[0]

# Below are the parameters we use in hashing---for injections especially, we may
# need to watch out for additional meaningful parameters. Note, it does not work
# to simply take all of the keys present in the injection data group of an
# HDFINJFIND file---some of them can contain tiny numerical differences that
# spoil their use in a hash.
hash_inj_params = ['coa_phase', 'distance', 'end_time', 'inclination',
                   'latitude', 'longitude', 'mass1', 'mass2',
                   'polarization', 'spin1x', 'spin1y', 'spin1z',
                   'spin2x', 'spin2y', 'spin2z']


class Injections(object):
    """Hashed view of an injection parameter table, used to compare runs.

    Parameters
    ----------
    inj_group : mapping
        Mapping from parameter name to a one-dimensional array-like (for
        example the 'injections' group of a found-injections HDF file). It
        must contain every name listed in ``hash_inj_params``.

    Attributes
    ----------
    injgroup
        The mapping passed in.
    inj_params : list
        Names of all parameters present in ``injgroup``.
    ninj : int
        Length of the first parameter column.
    inj_hashes : numpy.ndarray
        One hash per injection, built from the ``hash_inj_params`` columns.

    Raises
    ------
    RuntimeError
        If any of the computed hashes is zero.
    """

    def __init__(self, inj_group):
        self.injgroup = inj_group
        # keys() returns a view that cannot be indexed on Python 3, so
        # materialise it before taking the first name
        self.inj_params = list(self.injgroup.keys())
        self.ninj = len(self.injgroup[self.inj_params[0]][:])
        # Read each hashed column into memory once rather than iterating
        # the underlying datasets element by element
        columns = [self.injgroup[p][:] for p in hash_inj_params]
        self.inj_hashes = np.array([hash(row) for row in zip(*columns)])
        # NOTE(review): np.all only tests that every hash is non-zero, even
        # though the message talks about finiteness; kept as found.
        if not np.all(self.inj_hashes):
            raise RuntimeError("Not all injection hashes were finite")

    def __eq__(self, other):
        """Return True when both tables hold the same injections.

        Two instances are equal when they have the same number of
        injections, the same set of parameter names, and identical hashes
        in the same order.

        Raises
        ------
        ValueError
            If ``other`` is not an ``Injections`` instance.
        """
        if not isinstance(other, Injections):
            raise ValueError("{0} cannot be compared to Injections instance; is"
                             " not itself an instance".format(other))
        if self.ninj != other.ninj:
            return False
        if set(self.inj_params) != set(other.inj_params):
            return False
        return (self.inj_hashes == other.inj_hashes).all()

class RunInjectionResults(object):
    """Found and missed injection data of one run for one injection set.

    On construction the found-injections file, the bank file and one
    trigger-merge file per detector are opened read-only. The found/missed
    tables are loaded into memory and the single-detector quantities of the
    triggers associated with each found injection are looked up.

    Parameters
    ----------
    injname : str
        Label of the injection set.
    basedir : str
        Top-level directory of the run, as accepted by
        ``parse_injection_path``.
    single_detector_statistic : str, optional
        Name of the single-detector ranking passed to
        ``ranking.get_sngls_ranking_from_trigs``.
    nmissed : int, optional
        Stored as ``self.nmissed``; not otherwise used by this class.
    ifar_threshold : float or None, optional
        Stored as ``self.ifar_threshold``; not otherwise used by this class.

    Attributes
    ----------
    found, found_after_vetoes : dict
        Arrays keyed by column name, plus one nested dict per detector
        holding that detector's trigger columns.
    missed : dict
        Arrays for 'after_vetoes', 'all' and 'within_analysis'.

    Raises
    ------
    RuntimeError
        If a detector does not have exactly one trigger-merge file in the
        injection directory (or from ``parse_injection_path``).
    """

    def __init__(self, injname, basedir,
                 single_detector_statistic='newsnr',
                 nmissed=10, ifar_threshold=None):
        self.injname = injname
        # First, get the necessary files and load them up
        dirpath, injfile, bankfile = parse_injection_path(injname, basedir)
        self.dirpath = dirpath
        self.nmissed = nmissed
        self.ifar_threshold = ifar_threshold
        logging.info("Reading injection directory {0}".format(self.dirpath))
        self.injfile = File(injfile, "r")
        self.bankfile = File(bankfile, "r")
        self.inj_dets = [self.injfile.attrs['detector_1'],
                         self.injfile.attrs['detector_2']]
        # Per detector: the open trigger-merge file and its detector group
        self.trigger_merge_dict = {}
        for ifo in self.inj_dets:
            self.trigger_merge_dict[ifo] = {}
            merge_file_list = glob(self.dirpath +
                                   "/{0}-HDF_TRIGGER_MERGE*".format(ifo))
            if len(merge_file_list) != 1:
                raise RuntimeError("There is not exactly one trigger file for"
                                   " IFO {0} in directory {1}".format(
                                       ifo, self.dirpath))
            fileptr = File(merge_file_list[0], "r")
            self.trigger_merge_dict[ifo]['fileptr'] = fileptr
            self.trigger_merge_dict[ifo]['data'] = fileptr[ifo]
        # Next, load the information about the injections themselves.
        self.injgroup = self.injfile['injections']
        self.injections = Injections(self.injgroup)
        # Now, read the information from the found injection file. The trigger
        # merge files are typically quite large; we do not automatically load
        # them into memory
        self.found = {ifo: {} for ifo in self.inj_dets}
        self.found_after_vetoes = {ifo: {} for ifo in self.inj_dets}
        self.missed = {}
        for param in ['fap', 'fap_exc', 'ifar', 'ifar_exc',
                      'injection_index', 'stat']:
            self.found[param] = self.injfile['found'][param][:]
            self.found_after_vetoes[param] = \
                self.injfile['found_after_vetoes'][param][:]
        template_hashes = self.bankfile['template_hash'][:]
        # The ordering of detectors is arbitrary in the found
        # injection file, so we disentangle it and only record
        # via the IFO name
        d1 = self.injfile.attrs['detector_1']
        d2 = self.injfile.attrs['detector_2']
        for found_class in ['found', 'found_after_vetoes']:
            fdict = getattr(self, found_class)
            fgroup = self.injfile[found_class]
            idx = fgroup['template_id'][:]
            fdict['template_hash'] = template_hashes[idx]
            fdict[d1]['trigger_time'] = fgroup['time1'][:]
            fdict[d1]['trigger_id'] = fgroup['trigger_id1'][:]
            fdict[d2]['trigger_time'] = fgroup['time2'][:]
            fdict[d2]['trigger_id'] = fgroup['trigger_id2'][:]
        for param in ['after_vetoes', 'all', 'within_analysis']:
            self.missed[param] = self.injfile['missed'][param][:]
        # Record what the single detector statistic is
        self.single_detector_statistic = single_detector_statistic
        # Next, read information from the two trigger merge files
        logging.info("Reading single detector triggers from {0}; finding"
                     " minimum of {1}".format(
                         self.dirpath, self.single_detector_statistic))
        for fdict in [self.found, self.found_after_vetoes]:
            for ifo in self.inj_dets:
                data = self.trigger_merge_dict[ifo]['data']
                idx = fdict[ifo]['trigger_id']
                for param in ['chisq', 'chisq_dof', 'snr', 'sg_chisq',
                              'sigmasq', 'coa_phase', 'end_time']:
                    # Columns absent from the trigger file are skipped
                    if param in data:
                        vals = data[param][:][idx]
                        # A hack: the stored value is rescaled before it is
                        # handed to the ranking function.
                        # NOTE(review): presumably converts degrees of
                        # freedom to a number of bins -- confirm against
                        # the trigger file format.
                        if param == 'chisq_dof':
                            vals = vals / 2 + 1
                        fdict[ifo][param] = vals
                # Compute the single-detector ranking from the columns
                # gathered above
                curr_rank = ranking.get_sngls_ranking_from_trigs(
                    fdict[ifo],
                    single_detector_statistic
                )
                fdict[ifo][single_detector_statistic] = curr_rank
            statname = self.single_detector_statistic
            fdict['minimum_single_detector_statistic'] = np.minimum(
                fdict[self.inj_dets[0]][statname],
                fdict[self.inj_dets[1]][statname])

    def close(self):
        """Close the injection, bank and trigger-merge files."""
        self.injfile.close()
        self.bankfile.close()
        for ifo in self.inj_dets:
            self.trigger_merge_dict[ifo]['fileptr'].close()

    def __str__(self):
        return "RunInjectionResults instance for injection set {0} from" \
            " directory {1}".format(self.injname, self.dirpath)

    def __eq__(self, other):
        """Return True when both runs used the same detectors and injections.

        Raises
        ------
        ValueError
            If ``other`` is not a ``RunInjectionResults`` instance.
        """
        if not isinstance(other, RunInjectionResults):
            raise ValueError("{0} cannot be compared to RunInjectionResults"
                             " instance; is not itself an instance".format(
                                 other))
        return ((set(self.inj_dets) == set(other.inj_dets)) and
                (self.injections == other.injections))

    def write_to_hdf(self, outpath):
        """Write the loaded injection results to a new HDF file.

        The file at ``outpath`` is created (truncating any existing file)
        and receives a copy of the 'injections' group, a 'missed' group, and
        'found' / 'found_after_vetoes' groups that each contain one
        sub-group per detector.
        """
        # The context manager closes the file even if a write fails
        with File(outpath, "w") as outfp:
            outfp.attrs['injection_set'] = self.injname
            outfp.attrs['original_directory'] = self.dirpath
            outfp.attrs['single_detector_statistic'] = \
                self.single_detector_statistic
            outfp.attrs['detector_1'] = self.inj_dets[0]
            outfp.attrs['detector_2'] = self.inj_dets[1]
            self.injfile.copy('injections', outfp)
            missed_group = outfp.create_group('missed')
            for col, values in self.missed.items():
                missed_group.create_dataset(col, data=values)
            for fname in ['found', 'found_after_vetoes']:
                fgroup = outfp.create_group(fname)
                fdict = getattr(self, fname)
                # Detector entries are nested dicts and are written as
                # sub-groups below. dict.keys() is a view on Python 3 and
                # has no remove(), so filter instead of mutating it.
                for col, values in fdict.items():
                    if col in self.inj_dets:
                        continue
                    fgroup.create_dataset(col, data=values)
                for ifo in self.inj_dets:
                    igroup = fgroup.create_group(ifo)
                    for col, values in fdict[ifo].items():
                        igroup.create_dataset(col, data=values)


def populate_injection_group(fgroup, fidx, fdict, ifos):
    """Write a selection of found-injection data into an HDF group.

    Parameters
    ----------
    fgroup : group-like
        Destination providing ``create_dataset`` and ``create_group``.
    fidx : index array or boolean mask
        Selection applied to every array before it is written.
    fdict : dict
        Arrays keyed by column name, plus one nested dict of arrays for
        each name in ``ifos``.
    ifos : sequence of str
        Detector names; each gets its own sub-group in ``fgroup``.

    Raises
    ------
    KeyError
        If ``fdict`` has no entry for one of the detectors.
    """
    # dict.keys() is a view on Python 3 and has no remove(), so the
    # per-detector entries are skipped while iterating instead
    nested = set(ifos)
    for key, values in fdict.items():
        if key in nested:
            continue
        fgroup.create_dataset(key, data=values[fidx])
    for ifo in ifos:
        igroup = fgroup.create_group(ifo)
        for key, values in fdict[ifo].items():
            igroup.create_dataset(key, data=values[fidx])

def compare_found_injs(reference_run, comparison_run, outfp):
    """Sort found injections into both-runs / one-run-only and write them.

    For each of the 'found' and 'found_after_vetoes' classes a group of
    that name is created in ``outfp`` with four sub-groups:
    'found_in_both/reference', 'found_in_both/comparison',
    'found_reference_only' and 'found_comparison_only'. Each is filled by
    ``populate_injection_group`` with the matching rows of the run it
    refers to.

    Parameters
    ----------
    reference_run, comparison_run : RunInjectionResults
        The two runs to compare; the detector list is taken from the
        reference run.
    outfp : open HDF file or group
        Destination for the new groups.
    """
    ifos = reference_run.inj_dets
    for found_class in ('found', 'found_after_vetoes'):
        logging.info("Comparing found injections for class"
                     " '{0}'".format(found_class))
        class_group = outfp.create_group(found_class)
        ref_found = getattr(reference_run, found_class)
        com_found = getattr(comparison_run, found_class)
        ref_injs = ref_found['injection_index']
        com_injs = com_found['injection_index']

        logging.info("Calculating indices of found/missed injections between"
                     " reference and comparison runs")
        # With return_indices=True, intersect1d also reports where each
        # common (sorted, unique) element sits in either input; only those
        # positions are needed here
        _, in_both_ref, in_both_com = np.intersect1d(
            ref_injs, com_injs, return_indices=True)
        # Boolean masks selecting injections found by one run only
        only_in_ref = np.isin(ref_injs, com_injs, invert=True)
        only_in_com = np.isin(com_injs, ref_injs, invert=True)

        # Build the group hierarchy first, then fill it
        both_ref_group = class_group.create_group('found_in_both/reference')
        both_com_group = class_group.create_group('found_in_both/comparison')
        only_ref_group = class_group.create_group('found_reference_only')
        only_com_group = class_group.create_group('found_comparison_only')

        write_plan = [
            ("Writing injections found in both, reference run",
             both_ref_group, in_both_ref, ref_found),
            ("Writing injections found only in reference run",
             only_ref_group, only_in_ref, ref_found),
            ("Writing injections found in both, comparison run",
             both_com_group, in_both_com, com_found),
            ("Writing injections found only in comparison run",
             only_com_group, only_in_com, com_found),
        ]
        for message, group, selection, source in write_plan:
            logging.info(message)
            populate_injection_group(group, selection, source, ifos)

def compare_missed_injs(reference_run, comparison_run, outfp):
    """Record loud injections missed by one run but found by the other.

    For each run, the injections missed after vetoes (plus, when the run
    has an ``ifar_threshold``, found injections whose ifar lies below it)
    are ordered by decreasing min(optimal_snr_1, optimal_snr_2) and the
    first ``nmissed`` are kept. Those that the other run found (at or above
    its own threshold, if any) are written under 'missed_after_vetoes' in
    ``outfp``, together with their rank in the loudest list and the other
    run's found-after-vetoes data for them.

    Parameters
    ----------
    reference_run, comparison_run : RunInjectionResults
        The two runs to compare; the detector list is taken from the
        reference run.
    outfp : open HDF file or group
        Destination for the new groups.
    """
    def split_by_threshold(run):
        # Returns (indices treated as missed, indices treated as found)
        found = run.found_after_vetoes
        missed = run.missed['after_vetoes']
        found_injs = found['injection_index']
        if run.ifar_threshold is None:
            return missed, found_injs
        below = found['ifar'] < run.ifar_threshold
        return np.append(missed, found_injs[below]), found_injs[~below]

    def loudest_missed(run, missed):
        # Note that we are assuming that the injection table contains
        # optimal SNR columns, which may not always be true.
        snr_1 = run.injgroup['optimal_snr_1'][:][missed]
        snr_2 = run.injgroup['optimal_snr_2'][:][missed]
        descending = np.minimum(snr_1, snr_2).argsort()[::-1]
        return missed[descending][0:run.nmissed]

    def found_elsewhere(loudest, other_above, other_run):
        # Returns the injections of `loudest` that the other run found,
        # their positions within `loudest`, and a mask selecting them in
        # the other run's found-after-vetoes table
        hit = np.isin(loudest, other_above)
        missed_only = loudest[hit]
        ranks = np.where(hit)[0]
        other_injs = other_run.found_after_vetoes['injection_index']
        return missed_only, ranks, np.isin(other_injs, missed_only)

    logging.info("Comparing missed injections after vetoes")
    ifos = reference_run.inj_dets
    outfp.create_group('missed_after_vetoes')
    ref_missed, ref_above = split_by_threshold(reference_run)
    com_missed, com_above = split_by_threshold(comparison_run)

    # Now find the indices of the required loudest missed injections
    logging.info("Finding {0} loudest missed injections in"
                 " reference run".format(reference_run.nmissed))
    ref_loudest = loudest_missed(reference_run, ref_missed)
    logging.info("Finding {0} loudest missed injections in"
                 " comparison run".format(comparison_run.nmissed))
    com_loudest = loudest_missed(comparison_run, com_missed)

    # Now see which missed injections in one run were found in the other
    logging.info("Finding loud injections missed only in reference run")
    ref_only, ref_ranks, ref_in_com = found_elsewhere(
        ref_loudest, com_above, comparison_run)
    logging.info("Finding loud injections missed only in comparison run")
    com_only, com_ranks, com_in_ref = found_elsewhere(
        com_loudest, ref_above, reference_run)

    # We now have all of our indices, so create the hierarchy of data groups
    ref_only_group = outfp.create_group(
        "missed_after_vetoes/missed_only_reference")
    com_only_group = outfp.create_group(
        "missed_after_vetoes/missed_only_comparison")

    # Now fill these groups with the various data they should contain
    write_plan = [
        ("Writing loud injections missed only in reference run",
         ref_only_group, ref_only, ref_ranks, "comparison", ref_in_com,
         comparison_run),
        ("Writing loud injections missed only in comparison run",
         com_only_group, com_only, com_ranks, "reference", com_in_ref,
         reference_run),
    ]
    for message, group, injs, ranks, other_name, mask, other_run in write_plan:
        logging.info(message)
        group.create_dataset("injection_index", data=injs)
        group.create_dataset("loudest_rank", data=ranks)
        other_group = group.create_group(other_name)
        populate_injection_group(other_group, mask,
                                 other_run.found_after_vetoes, ifos)


# Command-line interface
parser = argparse.ArgumentParser(usage="", description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--injection-label", type=str, required=True,
                    help="Label of injection set")
parser.add_argument("--reference-dir", type=str, required=True,
                    help="Directory containing reference run of this injection"
                    " set. This should contain a 'bank/' sub-directory, as well"
                    " as a sub-directory '<injection_label>_INJ_coinc'")
parser.add_argument("--comparison-dir", type=str, required=True,
                    help="Directory containing comparison run of this"
                    " injection set.  This should contain a 'bank/' sub-"
                    "directory, as well as a sub-directory"
                    " '<injection_label>_INJ_coinc'")
parser.add_argument("--output-file", type=str, required=True,
                    help="Name of HDF output file in which to store results")
# This option has a default, so it must not also be marked as required;
# otherwise the default could never take effect.
parser.add_argument("--number-missed", type=int, default=10,
                    help="Number of the loudest missed injections to compare"
                    " between runs")
parser.add_argument('--ifar-threshold', type=float, default=None,
                    help="If given, also followup injections with ifar smaller "
                         "than this threshold.")
parser.add_argument("--single-detector-statistic", type=str, default='newsnr',
                    choices=ranking.sngls_ranking_function_dict.keys(),
                    help="Which single-detector statistic to calculate for"
                    " found injections")

args = parser.parse_args()
pycbc.init_logging(args.verbose)

# Load both runs; each opens several HDF files that stay open until close()
ref_found_injs = RunInjectionResults(args.injection_label, args.reference_dir,
                                     args.single_detector_statistic,
                                     args.number_missed, args.ifar_threshold)
com_found_injs = RunInjectionResults(args.injection_label, args.comparison_dir,
                                     args.single_detector_statistic,
                                     args.number_missed, args.ifar_threshold)

try:
    logging.info("Comparing injection parameters between reference and"
                 " comparison runs")
    if not (ref_found_injs == com_found_injs):
        raise RuntimeError("Reference and comparison runs did not perform the"
                           " same injections")

    # The context manager closes the output file even if a comparison fails
    with File(args.output_file, "w") as outfp:
        ref_found_injs.injfile.copy('injections', outfp)

        outfp.attrs['injection_label'] = args.injection_label
        outfp.attrs['reference_dir'] = args.reference_dir
        outfp.attrs['comparison_dir'] = args.comparison_dir
        outfp.attrs['single_detector_statistic'] = \
            args.single_detector_statistic
        outfp.attrs['detector_1'] = ref_found_injs.inj_dets[0]
        outfp.attrs['detector_2'] = ref_found_injs.inj_dets[1]
        outfp.attrs['number_missed'] = args.number_missed
        # None cannot be stored as an HDF5 attribute, so the threshold is
        # only recorded when one was actually given
        if args.ifar_threshold is not None:
            outfp.attrs['ifar_threshold'] = args.ifar_threshold

        compare_found_injs(ref_found_injs, com_found_injs, outfp)
        compare_missed_injs(ref_found_injs, com_found_injs, outfp)
finally:
    # Release the input files whether or not the comparison succeeded
    ref_found_injs.close()
    com_found_injs.close()

logging.info("Finished")
