#!/usr/bin/python3.13 -s

"""Finds the maximum frequency of waveforms by generating them. Will also report
the duration of time domain waveforms.
"""

__prog__ = 'pycbc_get_ffinal'
__author__ = 'Collin Capano <collin.capano@ligo.org>'

import os, sys
import numpy
import argparse

import pycbc
from pycbc import waveform

from ligo.lw import lsctables
from ligo.lw import utils as ligolw_utils
from ligo.lw import table
from ligo.lw.utils import process

# Every approximant name accepted by --approximant: the time-domain names
# followed by the frequency-domain names.
appx_options = waveform.td_approximants() + waveform.fd_approximants()

# Build the --approximant help as one sentence. Assembling it from two
# separately formatted fragments previously produced a doubled period
# ("... TD approximants: X.. FD approximants: Y.").
approximant_help = (
    "What approximant to use to generate the waveform(s). Options are: "
    "TD approximants: %s. FD approximants: %s."
    % (', '.join(waveform.td_approximants()),
       ', '.join(waveform.fd_approximants())))

# Top-level command-line interface; the module docstring doubles as the
# program description shown by --help.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument("-i", "--input",
                    help="Input file. If specified, any single waveform "
                         "parameters given by the below options will be "
                         "ignored.")
parser.add_argument("-o", "--output",
                    help="Output file. Required if specifying an input "
                         "file.")
parser.add_argument("-a", "--approximant", required=True,
                    help=approximant_help,
                    choices=appx_options)
parser.add_argument('-f', '--f-min', type=float,
                    help="Frequency at which to start the waveform "
                         "generation (in Hz).")
parser.add_argument('-r', '--sample-rate', type=int,
                    help="Sample rate to use (in Hz). Required for TD "
                         "approximants.")
parser.add_argument('-s', '--segment-length', type=int,
                    help="The inverse of deltaF (in s). Required for FD "
                         "approximants.")
parser.add_argument('-m', '--max-sample-rate', type=int,
                    help="Optional. Maximum sample rate to use (in Hz). "
                         "If the Nyquist frequency of the given sample rate "
                         "is lower than the ringdown frequency, an error "
                         "will occur for some waveform approximants. If "
                         "max-sample-rate is specified, the code will try "
                         "increasing the sample-rate by a factor of 2 until "
                         "it finds a frequency that works or until it "
                         "exceeds the specified maximum rate.")
# Options describing one waveform directly on the command line. They are
# grouped separately in --help and, per the group description, are ignored
# when an input file is given.
waveform_opts = parser.add_argument_group(
    "Waveform Options",
    "Optional arguments for specifying a single waveform. These will be "
    "ignored if an input file is given.")

# Component masses: floats with no default.
for _name in ('mass1', 'mass2'):
    waveform_opts.add_argument(
        '--' + _name, type=float,
        help="Specify %s of a single waveform." % _name)

# Spin components and tidal parameters: floats defaulting to 0.
for _name in ('spin1x', 'spin1y', 'spin1z',
              'spin2x', 'spin2y', 'spin2z',
              'lambda1', 'lambda2'):
    waveform_opts.add_argument(
        '--' + _name, type=float, default=0,
        help='Specify %s of a single waveform.' % _name)

# Order settings: integers defaulting to -1.
for _name in ('phase-order', 'spin-order', 'tidal-order',
              'amplitude-order'):
    waveform_opts.add_argument(
        '--' + _name, type=int, default=-1,
        help='Specify %s of a single waveform.' % _name)

# Parse the command line, then derive and validate the module-level settings
# used by the rest of the script.
opts = parser.parse_args()

pycbc.init_logging(opts.verbose)

# check options
if opts.input is None and (opts.mass1 is None or opts.mass2 is None):
    parser.error("Must specify input file or at least mass1,mass2 "
                 "of a single waveform")

infile = opts.input
if opts.input is not None and opts.output is None:
    parser.error("Must specify output if giving input file")
outfile = opts.output
if opts.approximant is None:
    parser.error("Must specify an approximant")

# True for a frequency-domain approximant, False for a time-domain one.
fd_approx = opts.approximant in waveform.fd_approximants()
if not fd_approx and opts.approximant not in waveform.td_approximants():
    raise ValueError("Unrecognized approximant {}".format(opts.approximant))
if not fd_approx and opts.sample_rate is None:
    raise ValueError("TD approximants require a sample-rate")
min_sample_rate = opts.sample_rate
# Without an explicit maximum, the starting sample rate is also the ceiling.
if opts.max_sample_rate is not None:
    max_sample_rate = opts.max_sample_rate
else:
    max_sample_rate = min_sample_rate
if fd_approx and opts.segment_length is None:
    raise ValueError("FD approximants require a segment-length")
seg_length = opts.segment_length
wFmin = opts.f_min
if not fd_approx:
    # --f-min has no default, so it is None when omitted; check explicitly
    # rather than letting the comparison below fail with a TypeError.
    if wFmin is None:
        raise ValueError("TD approximants require an f-min")
    # The starting frequency must lie below the Nyquist frequency.
    if wFmin >= min_sample_rate / 2.:
        raise ValueError("f-min must be < half the sample-rate")
phi0 = 0.

# Build the list of templates to process: either a single template assembled
# from the command-line options, or the sngl_inspiral table of the input
# XML document.
if infile is None:
    # FIXME: try to get this from the waveform_opts group
    tmplt_args = [
        'mass1', 'mass2',
        'spin1x', 'spin1y', 'spin1z',
        'spin2x', 'spin2y', 'spin2z',
        'lambda1', 'lambda2',
        'phase_order', 'spin_order', 'tidal_order', 'amplitude_order',
    ]
    # Map each parameter name to the value parsed from the command line.
    tmplt = {arg: getattr(opts, arg) for arg in tmplt_args}
    sngl_insp_table = [tmplt]
else:
    # load the xml file
    xmldoc = ligolw_utils.load_filename(
        infile, compress='auto', verbose=opts.verbose)
    # Register this program and its options in the loaded document.
    this_process = process.register_to_xmldoc(
        xmldoc, __prog__, opts.__dict__)
    sngl_insp_table = table.Table.get_table(xmldoc, 'sngl_inspiral')

# Generate a waveform for every template and determine its final frequency
# (plus its duration for time-domain approximants). With an input file the
# results are stored on the table rows; otherwise they are printed.
if opts.verbose and infile is not None:
    print("Cycling over templates...", file=sys.stdout)

for cache_id, tmplt in enumerate(sngl_insp_table):
    if opts.verbose and infile is not None:
        print("Template %i/%i\r" %(cache_id+1,
            len(sngl_insp_table)), end=' ', file=sys.stdout)
        sys.stdout.flush()

    if fd_approx:
        # A command-line template is a plain dict and is expanded into
        # keyword arguments; a table row is handed over via `template`.
        if infile is None:
            hplus, hcross = waveform.get_fd_waveform(
                approximant=opts.approximant, delta_f=1./seg_length,
                f_lower=wFmin, **tmplt)
        else:
            hplus, hcross = waveform.get_fd_waveform(
                template=tmplt, approximant=opts.approximant,
                delta_f=1./seg_length, f_lower=wFmin)

        # Frequency of the last bin holding a non-zero sample.
        freq = hplus.delta_f * numpy.nonzero(abs(hplus.data))[0][-1]

    else:
        # cycle over the sample rates finding the lowest one that works
        sample_rate = min_sample_rate
        while True:
            try:
                if infile is None:
                    hplus, hcross = waveform.get_td_waveform(
                        approximant=opts.approximant, delta_t=1./sample_rate,
                        f_lower=wFmin, **tmplt)
                else:
                    hplus, hcross = waveform.get_td_waveform(
                        template=tmplt, approximant=opts.approximant,
                        delta_t=1./sample_rate, f_lower=wFmin)
            except Exception as err:
                # Catch Exception rather than everything, so Ctrl-C and
                # sys.exit() are not mistaken for a generation failure.
                sample_rate *= 2
                if sample_rate > max_sample_rate:
                    err_msg = "Unable to generate template %i;" %(cache_id)
                    err_msg += " try increasing max sample rate"
                    # Chain the underlying error so the real cause is not
                    # hidden behind the sample-rate hint.
                    raise ValueError(err_msg) from err
            else:
                # generation succeeded at this sample rate
                break

        # Remove possible zero padding at the end of the waveforms
        amp = hplus**2 + hcross**2
        end_pnt = numpy.nonzero(amp.data)[0][-1] + 1
        hplus = hplus[:end_pnt]
        hcross = hcross[:end_pnt]

        # find f_final
        freq = waveform.frequency_from_polarizations(hplus, hcross).max()

    if infile is None:
        print("f Final: %f Hz" %(freq), file=sys.stdout)
        if not fd_approx:
            print("Duration: %f s" %(hplus.duration), file=sys.stdout)
    else:
        sngl_insp_table[cache_id].f_final = freq

        if not fd_approx:
            sngl_insp_table[cache_id].template_duration = hplus.duration

# When templates came from an input file, write the updated document to the
# requested output file, then exit successfully.
if infile is not None:
    if opts.verbose:
        print("\nWriting output file...", file=sys.stdout)
        sys.stdout.flush()

    # write the xml file
    # Use the same `compress` keyword as the load_filename() call earlier in
    # this script instead of the legacy `gz` flag, so reading and writing
    # target the same ligo.lw API.
    # NOTE(review): assumes compress='auto' selects gzip for '.gz' file names
    # as the old gz flag did -- confirm against the installed python-ligo-lw.
    ligolw_utils.write_filename(xmldoc, outfile, compress='auto')

    if opts.verbose:
        print("Finished!", file=sys.stdout)
sys.exit(0)
