#!/usr/bin/python3.13 -s

# Copyright (C) 2016 Ian W. Harry, Alex Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Program for concatenating output files from pycbc_banksim_combine_banks with
a set of injection files. The *order* of the injection files *must* match the
bank files, and the number of injections in each must correspond one-to-one.
"""

import argparse
import numpy as np

from ligo.lw import utils, table

import pycbc
from pycbc import pnutils
from pycbc.waveform import TemplateBank
from pycbc.io.ligolw import LIGOLWContentHandler
from pycbc.io.hdf import HFile
from pycbc import load_source

__author__  = "Ian Harry <ian.harry@astro.cf.ac.uk>"
__version__ = pycbc.version.git_verbose_msg
__date__    = pycbc.version.date
__program__ = "pycbc_banksim_match_combine"


# Read command line options
parser = argparse.ArgumentParser(description=__doc__)

pycbc.add_common_pycbc_options(parser)
# The match files are the only input to this program, so the option is
# mandatory: without it options.match_files would be None and the script
# would fail later with an unhelpful TypeError when iterating over it.
parser.add_argument("--match-files", nargs='+', required=True,
                    help="Explicit list of match files.")
parser.add_argument("-o", "--output-file", required=True,
                    help="Output file name")
# Optional python file providing filter_injections(mass1, mass2, spin1z,
# spin2z); its result is used further down to select a subset of points.
parser.add_argument("--filter-func-file", default=None,
                    help="This can be provided to give a function to define "
                         "which points are covered by the template bank "
                         "bounds, and which are not. The file should contain "
                         "a function called filter_injections, which should "
                         "take as call profile, mass1, mass2, spin1z, spin2z, "
                         "as numpy arrays.")
options = parser.parse_args()

pycbc.init_logging(options.verbose)

# Column layout used when reading each match file with np.loadtxt.
# np.str_ is used for the text columns because the np.unicode_ alias was
# removed in NumPy 2.0; np.str_ is available in both NumPy 1.x and 2.x.
dtypem = np.dtype([('match', np.float64), ('bank', np.str_, 256),
                   ('bank_i', np.int32), ('sim', np.str_, 256),
                   ('sim_i', np.int32), ('sigmasq', np.float64)])

# Collect the results
# np.loadtxt returns a 0-d array when a file holds a single row, so every
# file is promoted to 1-d before joining. This keeps len(res) and iteration
# over res valid even when only one match is present. Joining once at the
# end also avoids re-copying the accumulated array for every file.
res = np.concatenate([np.atleast_1d(np.loadtxt(fil, dtype=dtypem))
                      for fil in options.match_files])

# Caches of already-loaded tables, keyed by the file name found in the match
# rows, so each bank / injection file is only read once.
btables, itables = {}, {}

f = HFile(options.output_file, "w")

# The parameter columns are kept in dictionaries so that, if the inputs are
# moved over to HDF, the column names can be decided dynamically.
num_points = len(res)

# One float64 array per parameter, one entry per match row.
bank_par_list = ['mass1', 'mass2', 'spin1x', 'spin1y', 'spin1z', 'spin2x',
                 'spin2y', 'spin2z']
bank_params = {name: np.zeros(num_points, dtype=np.float64)
               for name in bank_par_list}

inj_par_list = ['mass1', 'mass2', 'spin1x', 'spin1y', 'spin1z', 'spin2x',
                'spin2y', 'spin2z', 'coa_phase', 'inclination', 'latitude',
                'longitude', 'polarization']
inj_params = {name: np.zeros(num_points, dtype=np.float64)
              for name in inj_par_list}

trig_par_list = ['match', 'sigmasq']
trig_params = {name: np.zeros(num_points, dtype=np.float64)
               for name in trig_par_list}

# Fill the per-point parameter arrays. Each template bank and injection file
# is loaded the first time a row refers to it and cached afterwards.
for idx, row in enumerate(res):
    bank_name = row['bank']
    sim_name = row['sim']

    if bank_name not in btables:
        btables[bank_name] = TemplateBank(bank_name).table

    if sim_name not in itables:
        indoc = utils.load_filename(sim_name, False,
                                    contenthandler=LIGOLWContentHandler)
        itables[sim_name] = table.Table.get_table(indoc, "sim_inspiral")

    # The template and the injection this match row refers to.
    bt = btables[bank_name][row['bank_i']]
    it = itables[sim_name][row['sim_i']]

    for val in trig_par_list:
        trig_params[val][idx] = row[val]
    for val in bank_par_list:
        # If not present set to 0.
        # For example spin1x is not always stored in aligned-spin banks
        bank_params[val][idx] = getattr(bt, val, 0.)
    for val in inj_par_list:
        inj_params[val][idx] = getattr(it, val)

# Store every parameter array in the output file, one dataset per parameter,
# grouped into bank_params/, inj_params/ and trig_params/.
for path_fmt, names, values in (
        ('bank_params/{}', bank_par_list, bank_params),
        ('inj_params/{}', inj_par_list, inj_params),
        ('trig_params/{}', trig_par_list, trig_params)):
    for name in names:
        f[path_fmt.format(name)] = values[name]

# When a filter file is supplied, call its filter_injections function on the
# injection masses and z-spins and keep the result as a numpy array; it is
# used further down to index the per-point arrays.
if options.filter_func_file:
    filter_module = load_source('filter_func', options.filter_func_file)
    bool_arr = np.array(filter_module.filter_injections(
        inj_params['mass1'], inj_params['mass2'],
        inj_params['spin1z'], inj_params['spin2z']))

# Summary values computed over the whole set of points.
# Signal recovery fraction: sum of (match * sigma)^3 over sum of sigma^3,
# with sigma the square root of the stored sigmasq.
sigma = trig_params['sigmasq']**0.5
match = trig_params['match']
sig_rec_fac = np.sum((match * sigma)**3.) / np.sum(sigma**3.)

f['sig_rec_fac'] = sig_rec_fac
f['eff_fitting_factor'] = sig_rec_fac**(1./3.)

# Same quantities with each point additionally weighted by mchirp^(-5/6).
mchirp = pnutils.mass1_mass2_to_mchirp_eta(inj_params['mass1'],
                                           inj_params['mass2'])[0]
mc_factor = mchirp**(-5./6.)
sig_rec_fac_mcweighted = (np.sum((match * mc_factor * sigma)**3.)
                          / np.sum((mc_factor * sigma)**3.))
f['sig_rec_fac_chirp_mass_weighted'] = sig_rec_fac_mcweighted
f['eff_fitting_factor_chirp_mass_weighted'] = \
    sig_rec_fac_mcweighted**(1./3.)

# Repeat the summary values using only the points selected by the filter
# function, and record which points were selected.
if options.filter_func_file:
    num_filtered = len(inj_params['mass1'][bool_arr])
    if num_filtered == 0:
        # Nothing was selected: write sentinel values. The chirp-mass
        # weighted entries are written too, so the output file has the same
        # set of datasets whether or not any point passes the filter.
        f['frac_points_within_bank'] = 0
        f['filtered_sig_rec_fac'] = -1
        f['filtered_eff_fitting_factor'] = -1
        f['filtered_sig_rec_fac_chirp_mass_weighted'] = -1
        f['filtered_eff_fitting_factor_chirp_mass_weighted'] = -1
    else:
        f['frac_points_within_bank'] = \
            num_filtered / float(len(inj_params['mass1']))
        filt_match = trig_params['match'][bool_arr]
        filt_sigma = trig_params['sigmasq'][bool_arr]**0.5
        srfn_filt = np.sum((filt_match * filt_sigma)**3.)
        srfd_filt = np.sum(filt_sigma**3)
        f['filtered_sig_rec_fac'] = srfn_filt / srfd_filt
        f['filtered_eff_fitting_factor'] = (srfn_filt / srfd_filt)**(1./3.)
        # Use a separate name so the full-set mchirp array is not replaced
        # by its filtered subset.
        filt_mchirp = mchirp[bool_arr]
        srfn_filt_mcweighted = np.sum((filt_match * filt_mchirp**(-5./6.) *
                                       filt_sigma)**3.)
        srfd_filt_mcweighted = np.sum((filt_mchirp**(-5./6.) *
                                       filt_sigma)**3.)
        f['filtered_sig_rec_fac_chirp_mass_weighted'] = \
            srfn_filt_mcweighted / srfd_filt_mcweighted
        f['filtered_eff_fitting_factor_chirp_mass_weighted'] = \
            (srfn_filt_mcweighted / srfd_filt_mcweighted)**(1./3.)

    f['filtered_points'] = bool_arr

f.close()
