#!/usr/bin/python3.13 -s
# Copyright (C) 2012-17  Alex Nitz, Ian Harry
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Calculate the fitting factors of simulated signals with a template bank."""


import sys
import logging
from numpy import complex64, float32, sqrt, argmax, real, array
from argparse import ArgumentParser
from math import ceil, log
from tqdm import tqdm

from ligo.lw import utils as ligolw_utils
from ligo.lw import lsctables

from pycbc.pnutils import mass1_mass2_to_mchirp_eta, f_SchwarzISCO
from pycbc.pnutils import nearest_larger_binary_number
from pycbc.pnutils import mass1_mass2_to_tau0_tau3
from pycbc.waveform import get_td_waveform, get_fd_waveform, td_approximants, fd_approximants
from pycbc.waveform.utils import taper_timeseries
from pycbc import DYN_RANGE_FAC
from pycbc.types import FrequencySeries, TimeSeries, zeros, real_same_precision_as, complex_same_precision_as
from pycbc.filter import match, sigmasq, resample_to_delta_t
from pycbc.filter import overlap_cplx, matched_filter
from pycbc.filter import compute_max_snr_over_sky_loc_stat
from pycbc.filter import compute_max_snr_over_sky_loc_stat_no_phase
from pycbc.io.ligolw import LIGOLWContentHandler
import pycbc.psd, pycbc.scheme, pycbc.fft, pycbc.strain
from pycbc.detector import overhead_antenna_pattern as generate_fplus_fcross
from pycbc.waveform import TemplateBank

## Remove the need for these functions ########################################
    
def generate_detector_strain(template_params, h_plus, h_cross):
    """Combine the two polarizations into a single detector strain.

    The sky angles are taken from the ``latitude``, ``longitude`` and
    ``polarization`` attributes of ``template_params``. Any of these
    attributes that is missing is treated as zero.

    Parameters
    ----------
    template_params : object
        Parameter row which may carry the three optional attributes above.
    h_plus, h_cross :
        Plus and cross polarizations of the waveform.

    Returns
    -------
    The sum ``h_plus * f_plus + h_cross * f_cross`` where the antenna
    factors come from
    ``generate_fplus_fcross(longitude, latitude, polarization)``.
    """
    lat = getattr(template_params, 'latitude', 0)
    lon = getattr(template_params, 'longitude', 0)
    psi = getattr(template_params, 'polarization', 0)

    # Note the argument order: longitude first, then latitude.
    antenna_plus, antenna_cross = generate_fplus_fcross(lon, lat, psi)

    return h_plus * antenna_plus + h_cross * antenna_cross

def make_padded_frequency_series(vec, filter_N=None, delta_f=None):
    """Convert vec (TimeSeries or FrequencySeries) to a FrequencySeries. If
    filter_N and/or delta_f are given, the output will take those values. If
    not told otherwise the code will attempt to pad a timeseries first such that
    the waveform will not wraparound. However, if delta_f is specified to be
    shorter than the waveform length then wraparound *will* be allowed.

    Parameters
    ----------
    vec : TimeSeries or FrequencySeries
        Input data. A TimeSeries is resized *in place* before it is
        transformed, so the caller's object is modified.
    filter_N : int, optional
        Number of time-domain samples the output corresponds to; the output
        has ``filter_N // 2 + 1`` frequency samples. If omitted, a power of
        two derived from ``len(vec)`` is used.
    delta_f : float, optional
        Frequency spacing of the output. Required when ``vec`` is a
        TimeSeries. Ignored when ``vec`` is a FrequencySeries, in which case
        the spacing of ``vec`` is used.

    Returns
    -------
    FrequencySeries
        complex64 series scaled by ``DYN_RANGE_FAC``.

    Raises
    ------
    TypeError
        If ``vec`` is neither a TimeSeries nor a FrequencySeries.
    ValueError
        If ``vec`` is a TimeSeries and ``delta_f`` is missing, or if the
        transformed series does not extend to the requested Nyquist
        frequency.
    """
    if filter_N is None:
        power = ceil(log(len(vec), 2)) + 1
        N = 2 ** power
    else:
        N = filter_N
    n = N // 2 + 1

    if isinstance(vec, FrequencySeries):
        vectilde = FrequencySeries(zeros(n, dtype=complex_same_precision_as(vec)),
                                   delta_f=1.0, copy=False)
        # Copy as much of the input as fits; the remainder stays zero padded
        cplen = min(len(vectilde), len(vec))
        vectilde[0:cplen] = vec[0:cplen]
        delta_f = vec.delta_f

    elif isinstance(vec, TimeSeries):
        if delta_f is None:
            raise ValueError("delta_f must be given to convert a TimeSeries")
        # First determine if the timeseries is too short for the specified df
        # and increase if necessary
        new_length = int(nearest_larger_binary_number(len(vec)))
        while new_length * vec.delta_t < 1./delta_f:
            new_length = new_length * 2
        vec.resize(new_length)
        # Then convert to frequencyseries
        v_tilde = vec.to_frequencyseries()
        # Then convert frequencyseries to required length and spacing by keeping
        # only every nth sample if delta_f needs increasing, and cutting at
        # Nyquist if the max frequency is too high.
        # NOTE: This assumes that the input and output data is using binary
        #       lengths.
        i_delta_f = v_tilde.get_delta_f()
        v_tilde = v_tilde.numpy()
        df_ratio = int(delta_f / i_delta_f)
        n_freq_len = int((n-1) * df_ratio +1)
        # The strided slice below only yields n samples when the transformed
        # data holds at least n_freq_len samples; otherwise the output would
        # silently come out shorter than requested.
        if n_freq_len > len(v_tilde):
            raise ValueError("Time series does not extend to the requested "
                             "Nyquist frequency: need %d frequency samples "
                             "but only %d are available"
                             % (n_freq_len, len(v_tilde)))
        v_tilde = v_tilde[:n_freq_len:df_ratio]
        vectilde = FrequencySeries(v_tilde, delta_f=delta_f, dtype=complex64)

    else:
        raise TypeError("vec must be a TimeSeries or FrequencySeries, not %s"
                        % type(vec).__name__)

    return FrequencySeries(vectilde * DYN_RANGE_FAC, delta_f=delta_f,
                           dtype=complex64)

def get_waveform(approximant, phase_order, amplitude_order, spin_order,
                 wf_params, start_frequency, sample_rate, length,
                 filter_rate, sky_max_template=False):
    """Generate a waveform and return it as padded frequency series.

    Parameters
    ----------
    approximant : str or bytes
        Name of the approximant; bytes are decoded as UTF-8.
    phase_order, amplitude_order, spin_order : int
        Orders forwarded to the waveform generator.
    wf_params : object
        Waveform parameters forwarded to the waveform generator. If it has a
        ``taper`` attribute, time-domain waveforms are tapered with it.
    start_frequency : float
        Passed to the waveform generator as ``f_lower``.
    sample_rate : float
        Sample rate used for time-domain approximants
        (``delta_t = 1 / sample_rate``).
    length : int
        Number of time-domain samples of the filter; sets the length of the
        returned frequency series.
    filter_rate : float
        Sample rate of the filter; ``filter_rate / length`` is the frequency
        spacing of the output.
    sky_max_template : bool, optional
        If False (default) the two polarizations are combined with
        ``generate_detector_strain`` and one series is returned. If True the
        plus and cross polarizations are returned separately.

    Returns
    -------
    FrequencySeries or tuple of two FrequencySeries
        See ``sky_max_template``.

    Raises
    ------
    ValueError
        If ``approximant`` is neither a frequency-domain nor a time-domain
        approximant.
    """
    # Only decode genuine byte strings; str subclasses pass through unchanged
    if isinstance(approximant, bytes):
        approximant = approximant.decode('utf-8')

    delta_f = filter_rate / length
    if approximant in fd_approximants():
        hp, hc = get_fd_waveform(wf_params, approximant=approximant,
                                 phase_order=phase_order, delta_f=delta_f,
                                 spin_order=spin_order,
                                 f_lower=start_frequency,
                                 amplitude_order=amplitude_order)

    elif approximant in td_approximants():
        hp, hc = get_td_waveform(wf_params,
                                 approximant=approximant,
                                 phase_order=phase_order,
                                 spin_order=spin_order,
                                 delta_t=1./sample_rate,
                                 f_lower=start_frequency,
                                 amplitude_order=amplitude_order)
        if hasattr(wf_params, 'taper'):
            hp = taper_timeseries(hp, wf_params.taper)
            hc = taper_timeseries(hc, wf_params.taper)

    else:
        raise ValueError("Approximant %r is neither a frequency-domain nor a "
                         "time-domain approximant" % (approximant,))

    # Use the length handed in by the caller rather than a module global
    if not sky_max_template:
        hvec = generate_detector_strain(wf_params, hp, hc)
        return make_padded_frequency_series(hvec, length, delta_f=delta_f)

    return (make_padded_frequency_series(hp, length, delta_f=delta_f),
            make_padded_frequency_series(hc, length, delta_f=delta_f))

# Every approximant name known to either waveform generator, without duplicates
aprs = sorted(set(td_approximants() + fd_approximants()))

#File output Settings
parser = ArgumentParser(description=__doc__)
parser.add_argument("--match-file", dest="out_file", metavar="FILE",
                    required=True, help="File to output match results")
pycbc.add_common_pycbc_options(parser)

#Template Settings
parser.add_argument("--template-file", dest="bank_file", metavar="FILE",
                    required=True, help="SimInspiral or SnglInspiral XML file "
                                        "containing the template parameters")
parser.add_argument("--total-mass-divide", type=float,
                    help="Total mass to switch from --template-approximant to "
                         "--highmass-approximant.")
parser.add_argument("--highmass-approximant", choices=aprs,
                    help="Waveform approximant for highmass templates.")
parser.add_argument("--template-approximant", choices=aprs, required=True,
                    help="Waveform approximant for templates")
parser.add_argument("--template-phase-order", default=-1, type=int,
                    help="PN order to use for the template phase")
parser.add_argument("--template-amplitude-order", default=-1, type=int,
                    help="PN order to use for the template amplitude")
parser.add_argument("--template-spin-order", default=-1, type=int,
                    help="PN order to use for the template spin terms")
parser.add_argument("--template-start-frequency", type=float,
                    help="Starting frequency for templates [Hz]")
parser.add_argument("--template-sample-rate", type=float,
                    help="Sample rate for templates [Hz]")

#Signal Settings
parser.add_argument("--signal-file", dest="sim_file", metavar="FILE",
                    required=True, help="SimInspiral or SnglInspiral XML file "
                                        "containing the signal parameters")
parser.add_argument("--signal-approximant", choices=aprs, required=True,
                    help="Waveform approximant for signals")
parser.add_argument("--signal-phase-order", default=-1, type=int,
                    help="PN order to use for the signal phase")
parser.add_argument("--signal-spin-order", default=-1, type=int,
                    help="PN order to use for the signal spin terms")
parser.add_argument("--signal-amplitude-order", default=-1, type=int,
                    help="PN order to use for the signal amplitude")
parser.add_argument("--signal-start-frequency", type=float,
                    help="Starting frequency for signals [Hz]")
parser.add_argument("--signal-sample-rate", type=float,
                    help="Sample rate for signals [Hz]")
parser.add_argument("--use-sky-location", action='store_true',
                    help="Inject into a theoretical detector at the celestial "
                         "North pole of a non-rotating Earth rather than overhead")

#Filtering Settings
parser.add_argument('--filter-low-frequency-cutoff', metavar='FREQ', type=float,
                    required=True, help='low frequency cutoff of matched filter')
parser.add_argument('--filter-low-freq-cutoff-column', metavar='NAME', type=str,
                    help='If given, use a per-template low-frequency cutoff '
                         'from column NAME of the template table instead of '
                         'the value given via --filter-low-frequency-cutoff. '
                         'Signals are still normalized using '
                         '--filter-low-frequency-cutoff and the value in '
                         'column NAME must be larger than or equal to it.')
parser.add_argument("--filter-sample-rate", type=float, required=True,
                    help="Filter sample rate [Hz]")
parser.add_argument("--filter-signal-length", type=int, required=True,
                    help="Length of signal for filtering, should be longer "
                         "than all waveforms and include some padding")
parser.add_argument("--sky-maximization-method", required=True,
                    choices=["precessing", "hom"],
                    help="Statistic used to maximize the SNR over the sky "
                         "location: 'precessing' uses "
                         "compute_max_snr_over_sky_loc_stat, 'hom' uses "
                         "compute_max_snr_over_sky_loc_stat_no_phase.")


# add PSD options
pycbc.psd.insert_psd_option_group(parser, output=False)

# Insert the data reading options
pycbc.strain.insert_strain_option_group(parser)

#hardware support
pycbc.scheme.insert_processing_option_group(parser)
pycbc.fft.insert_fft_option_group(parser)

#Restricted maximization
parser.add_argument("--mchirp-window", type=str, metavar="FRACTION",
                    help="Ignore templates whose chirp mass deviates from "
                         "signal's one more than given fraction. Provide two "
                         "comma separated numbers to have different bounds "
                         "above and below the signal's, with below bound "
                         "listed first.")
parser.add_argument("--tau0-window", type=float, metavar="TIME", default=None,
                    help="Ignore templates whose Newtonian order chirp time "
                         "(tau0) varies from the signals by more than the "
                         "supplied amount. If option is not provided no "
                         "window on tau0 is used. The "
                         "filter-low-frequency-cutoff is used to calculate "
                         "the value of tau0 for all cases.")

options = parser.parse_args()

pycbc.init_logging(options.verbose)

pycbc.psd.verify_psd_options(options, parser)

if options.psd_estimation:
    pycbc.strain.verify_strain_options(options, parser)

# Compare against None (as the template loop does) so that an explicit divide
# of 0 cannot slip through without a high-mass approximant.
if options.total_mass_divide is not None \
        and options.highmass_approximant is None:
    parser.error("You must provide a highmass-approximant if you want total-mass-divide.")

# Build the chirp mass predicate once, according to the form of the option:
# absent, one number (symmetric) or "lower,upper" (asymmetric).
if options.mchirp_window is None:
    def outside_mchirp_window(template_mchirp, signal_mchirp):
        """No chirp mass window requested: never reject a template."""
        return False
else:
    try:
        mchirp_bounds = [float(bound)
                         for bound in options.mchirp_window.split(",")]
    except ValueError:
        parser.error("--mchirp-window must be a number or two comma "
                     "separated numbers, got '%s'." % options.mchirp_window)

    if len(mchirp_bounds) == 2:
        # asymmetric chirp mass window, bound below the signal listed first
        mchirp_window_lower, mchirp_window_upper = mchirp_bounds
        def outside_mchirp_window(template_mchirp, signal_mchirp):
            """Return True if the fractional chirp mass difference exceeds
            the upper bound above, or the lower bound below, the signal."""
            delta = (template_mchirp - signal_mchirp) / signal_mchirp
            return delta > mchirp_window_upper or -delta > mchirp_window_lower
    elif len(mchirp_bounds) == 1:
        # symmetric chirp mass window
        mchirp_window = mchirp_bounds[0]
        def outside_mchirp_window(template_mchirp, signal_mchirp):
            """Return True if the chirp masses differ by more than the
            given fraction of the signal chirp mass."""
            return abs(signal_mchirp - template_mchirp) > \
                    (mchirp_window * signal_mchirp)
    else:
        parser.error("--mchirp-window takes at most two comma separated "
                     "numbers, got '%s'." % options.mchirp_window)

if options.tau0_window is None:
    def outside_tau0_window(template_tau0, signal_tau0, window):
        """No tau0 window requested: never reject a template."""
        return False
else:
    def outside_tau0_window(template_tau0, signal_tau0, window):
        """Return True if the tau0 values differ by more than window."""
        return abs(signal_tau0 - template_tau0) > window

# h(t) is only read when the PSD has to be estimated from data
strain = None
if options.psd_estimation:
    logging.info("Obtaining h(t) for PSD generation")
    strain = pycbc.strain.from_cli(options, pycbc.DYN_RANGE_FAC)

# Templates and signals fall back on the filter sample rate when no dedicated
# rate was given on the command line
template_sample_rate = options.filter_sample_rate
if options.template_sample_rate is not None:
    template_sample_rate = options.template_sample_rate

signal_sample_rate = options.filter_sample_rate
if options.signal_sample_rate is not None:
    signal_sample_rate = options.signal_sample_rate

ctx = pycbc.scheme.from_cli(options)

logging.info('Reading template bank')
temp_bank = TemplateBank(options.bank_file)
template_table = temp_bank.table

logging.info("  %d templates", len(template_table))

logging.info('Reading simulation list')
indoc = ligolw_utils.load_filename(options.sim_file, False,
                                   contenthandler=LIGOLWContentHandler)
# Prefer a SimInspiral table; use the SnglInspiral table if that lookup
# raises ValueError
try:
    signal_table = lsctables.SimInspiralTable.get_table(indoc)
except ValueError:
    signal_table = lsctables.SnglInspiralTable.get_table(indoc)
logging.info("  %d signal waveforms", len(signal_table))

logging.info("Matches will be written to %s", options.out_file)

# Filter length in time samples, in frequency samples, and frequency spacing
filter_N = int(options.filter_signal_length * options.filter_sample_rate)
filter_n = 1 + filter_N // 2
filter_delta_f = 1.0 / float(options.filter_signal_length)

logging.info("Reading and Interpolating PSD")
psd = pycbc.psd.from_cli(
    options,
    filter_n,
    filter_delta_f,
    options.filter_low_frequency_cutoff,
    strain=strain,
    dyn_range_factor=pycbc.DYN_RANGE_FAC,
    precision='single',
)
  
# Generate every signal once, then filter each template against all signals
# inside the requested windows. The per-signal list of sky-maximized
# statistics is stored in ``signals`` (one entry per template, 0 when the
# template was skipped).
with ctx:
    pycbc.fft.from_cli(options)

    logging.info("Pregenerating Signals")

    # Entries are (signal series, sigmasq of the signal, matches, parameters)
    signals = []
    prog = tqdm(total=len(signal_table), disable=(not options.verbose))
    for signal_params in signal_table:
        prog.update(1)
        if not options.use_sky_location:
            signal_params.latitude = 0.
            signal_params.longitude = 0.
        stilde = get_waveform(options.signal_approximant,
                              options.signal_phase_order,
                              options.signal_amplitude_order,
                              options.signal_spin_order,
                              signal_params,
                              options.signal_start_frequency,
                              signal_sample_rate,
                              filter_N, options.filter_sample_rate)
        s_norm = sigmasq(stilde, psd=psd,
                         low_frequency_cutoff=options.filter_low_frequency_cutoff)
        # Normalize the signal and divide by the PSD here, so that the
        # matched filter calls below are made without a PSD.
        stilde /= sqrt(float(s_norm))
        stilde /= psd
        signals.append((stilde, s_norm, [], signal_params))
    prog.close()

    # Used for the mchirp/tau0 windows below
    sig_m1 = array([signal_params.mass1 for signal_params in signal_table])
    sig_m2 = array([signal_params.mass2 for signal_params in signal_table])
    sig_tau0, _ = mass1_mass2_to_tau0_tau3(sig_m1, sig_m2,
                                           options.filter_low_frequency_cutoff)
    sig_mchirp, _ = mass1_mass2_to_mchirp_eta(sig_m1, sig_m2)

    logging.info("Calculating Mchirp and Tau0")
    template_m1 = array([template_params.mass1
                         for template_params in template_table])
    template_m2 = array([template_params.mass2
                         for template_params in template_table])
    template_tau0, _ = mass1_mass2_to_tau0_tau3(template_m1, template_m2,
                                                options.filter_low_frequency_cutoff)
    template_mchirp, _ = mass1_mass2_to_mchirp_eta(template_m1, template_m2)

    logging.info("Calculating Overlaps")

    flow_warned = False
    prog = tqdm(total=len(template_table), disable=(not options.verbose))
    for index, template_params in enumerate(template_table):
        prog.update(1)
        # NOTE(review): --filter-low-freq-cutoff-column is parsed but not
        # consulted here; the cutoff always comes from the f_lower field.
        f_lower = template_params.f_lower
        # If not set fall back on filter low-freq cutoff
        if f_lower < 0.000001:
            f_lower = options.filter_low_frequency_cutoff
        if f_lower < options.filter_low_frequency_cutoff:
            # Not entirely clear what to do here?
            if not flow_warned:
                logging.warning("Template's flower is smaller than "
                                "--filter-low-frequency-cutoff. Raising "
                                "flower of template to match.")
                flow_warned = True
            f_lower = options.filter_low_frequency_cutoff

        # The template is generated lazily: only once a signal actually
        # needs to be filtered against it.
        template_ready = False
        for sidx, (stilde, s_norm, matches, signal_params) in enumerate(signals):
            # Check if we need to look at this
            skip = (stilde is None
                    or outside_tau0_window(template_tau0[index],
                                           sig_tau0[sidx],
                                           options.tau0_window)
                    or outside_mchirp_window(template_mchirp[index],
                                             sig_mchirp[sidx]))
            if skip:
                matches.append(0)
                continue

            # Generate the template if we haven't already done so
            if not template_ready:
                # FIXME: I would like to remove the approximant options and
                #        have this entirely controlled by the template bank.
                #        However, while we are still using the high-mass divide
                #        in XML banks, this must be retained.
                try:
                    this_approximant = template_params['approximant']
                except Exception:
                    # The exception raised for a missing field depends on the
                    # row type, so catch Exception broadly; unlike a bare
                    # except this lets KeyboardInterrupt/SystemExit through.
                    this_approximant = options.template_approximant
                    if options.total_mass_divide is not None and \
                            (template_params.mass1 + template_params.mass2) \
                            >= options.total_mass_divide:
                        this_approximant = options.highmass_approximant
                hplus, hcross = get_waveform(this_approximant,
                                      options.template_phase_order,
                                      options.template_amplitude_order,
                                      options.template_spin_order,
                                      template_params,
                                      options.template_start_frequency,
                                      template_sample_rate,
                                      filter_N, options.filter_sample_rate,
                                      sky_max_template=True)

                hp_norm = sigmasq(hplus, psd=psd, low_frequency_cutoff=f_lower)
                hc_norm = sigmasq(hcross, psd=psd, low_frequency_cutoff=f_lower)
                hplus /= sqrt(float(hp_norm))
                hcross /= sqrt(float(hc_norm))
                # The plus/cross correlation must cover the same band as the
                # normalization above, i.e. start at the template's f_lower.
                hpc_corr = overlap_cplx(hplus, hcross, psd=psd,
                                        low_frequency_cutoff=f_lower,
                                        normalized=False)
                hpc_corr_R = real(hpc_corr)
                template_ready = True

            # Filter over the band the template was normalized on; starting
            # below f_lower would add power the normalization did not count.
            I_plus = matched_filter(hplus, stilde,
                                    low_frequency_cutoff=f_lower,
                                    sigmasq=1.)

            I_cross = matched_filter(hcross, stilde,
                                     low_frequency_cutoff=f_lower,
                                     sigmasq=1.)

            if options.sky_maximization_method == 'precessing':
                det_stat = compute_max_snr_over_sky_loc_stat\
                    (I_plus, I_cross, hpc_corr_R, hpnorm=1., hcnorm=1.,
                     thresh=0.1, analyse_slice=slice(0,len(I_plus.data)))
            elif options.sky_maximization_method == 'hom':
                det_stat = compute_max_snr_over_sky_loc_stat_no_phase\
                    (I_plus, I_cross, hpc_corr_R, hpnorm=1., hcnorm=1.,
                     thresh=0.1, analyse_slice=slice(0,len(I_plus.data)))
            else:
                err_msg = "I really shouldn't be here! Who gone broked me?"
                raise ValueError(err_msg)

            # Keep the largest value of the statistic over all time samples
            matches.append(det_stat[argmax(det_stat.data)])
    prog.close()
logging.info("Determining maximum overlaps and outputting results")

# One line per signal: best match, bank file, index of the best template,
# signal file, signal index and the sigmasq of the signal.
with open(options.out_file, "w") as fout:
    for sim_index, (_, sim_sigmasq, sim_matches, _) in enumerate(signals):
        best_match = max(sim_matches)
        best_template = sim_matches.index(best_match)
        # Two spaces follow the match value, as in the established format
        fout.write("%5.5f  %s %s %s %d %5.5f\n" % (best_match,
                                                   options.bank_file,
                                                   str(best_template),
                                                   options.sim_file,
                                                   sim_index,
                                                   sim_sigmasq))
