#!/usr/bin/python3.13 -s

"""Followup utility to optimize the SNR of a PyCBC Live trigger."""

import os
import argparse
import logging
import numpy

# we will make plots on a likely headless machine, so make sure matplotlib's
# backend is set appropriately
from matplotlib import use as mpl_use_backend
mpl_use_backend('agg')

import pycbc
from pycbc import (
    fft, scheme
)
from pycbc.types import MultiDetOptionAction, load_frequencyseries
import pycbc.conversions as cv
from pycbc.io.gracedb import CandidateForGraceDB
from pycbc.io.hdf import load_hdf5_to_dict, HFile
from pycbc.detector import Detector
from pycbc.psd import interpolate
from pycbc.live import snr_optimizer


# Build the command-line interface. Options are registered in the original
# order so that the generated --help output is unchanged.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)

# Inputs written by PyCBC Live for the trigger being followed up
parser.add_argument(
    '--params-file',
    required=True,
    help='Location of the attributes file created by PyCBC Live',
)
parser.add_argument(
    '--data-files',
    type=str,
    nargs='+',
    action=MultiDetOptionAction,
    metavar='IFO:DATA_FILE',
    help='Locations of the overwhitened data files produced by PyCBC Live.',
)
parser.add_argument(
    '--psd-files',
    type=str,
    nargs='+',
    action=MultiDetOptionAction,
    metavar='IFO:PSD_FILE',
    help='Locations of the PSD files produced by PyCBC Live.',
)

# Waveform and search-space settings
parser.add_argument(
    '--approximant',
    required=True,
    help='Waveform approximant string.',
)
parser.add_argument(
    '--snr-threshold',
    type=float,
    default=4.0,
    help=('If the SNR in ifo X is below this threshold do not consider it '
          'part of the coincidence. Not implemented'),
)
parser.add_argument(
    '--chirp-time-f-lower',
    type=float,
    default=20.,
    help='Starting frequency for chirp time window (Hz).',
)
parser.add_argument(
    '--chirp-time-window',
    type=float,
    default=2.,
    help='Chirp time window (s).',
)

# GraceDB upload settings
parser.add_argument(
    '--gracedb-server',
    metavar='URL',
    help=('URL of GraceDB server API for uploading events. If not provided, '
          'the default URL is used.'),
)
parser.add_argument(
    '--gracedb-search',
    default='AllSky',
    help='String going into the "search" field of the GraceDB events',
)
parser.add_argument(
    '--gracedb-labels',
    metavar='LABEL',
    nargs='+',
    help='Apply the given list of labels to events uploaded to GraceDB.',
)
parser.add_argument(
    '--production',
    action='store_true',
    help='Upload a production event rather than a test event',
)
parser.add_argument(
    '--enable-gracedb-upload',
    action='store_true',
    help='Upload triggers to GraceDB',
)

# Output location and resource limits
parser.add_argument(
    '--output-path',
    required=True,
    help='Path to a directory to store results in',
)
parser.add_argument(
    '--cores',
    type=int,
    help='Restrict calculation to given number of CPU cores',
)

# Option groups contributed by the optimizer, processing scheme and FFT
# modules, in that order
for add_options in (snr_optimizer.insert_snr_optimizer_options,
                    scheme.insert_processing_option_group,
                    fft.insert_fft_option_group):
    add_options(parser)

args = parser.parse_args()

pycbc.init_logging(args.verbose)

# Seed numpy's global random state only when an explicit seed was supplied;
# None and the literal 'random' both leave it untouched.
if args.snr_opt_seed not in (None, 'random'):
    logging.info('Setting snr optimizer random seed.')
    numpy.random.seed(int(args.snr_opt_seed))

# Input checking: --snr-threshold is accepted on the command line but any
# value other than the default is rejected.
if args.snr_threshold != 4:
    parser.error("Sorry, the SNR threshold option doesn't work yet")
snr_optimizer.check_snr_optimizer_options(args, parser)
scheme.verify_processing_options(args, parser)
fft.verify_fft_options(args, parser)

# Configure the processing scheme and the FFT backend from the command line
scheme_context = scheme.from_cli(args)
fft.from_cli(args)

# Minimum mass for points tried in the optimization, allowing for minimal
# slop relative to the lightest template
MIN_CPT_MASS = snr_optimizer.MIN_CPT_MASS

# Notional minimum and maximum possible eta
MIN_ETA = 0.01
MAX_ETA = 0.24999999999

logging.info('Starting optimize SNR')

# Load the data and the matching PSD for every detector named in
# --data-files. The PSD is attached to the data series as its `.psd`
# attribute, which is how later stages access it.
ifos = list(args.data_files.keys())
if not ifos:
    # --data-files is not a required option, so catch its absence here
    # rather than failing later with an unrelated error.
    parser.error('At least one IFO:DATA_FILE pair must be given via '
                 '--data-files')

data = {}
for ifo in ifos:
    psd_file = args.psd_files[ifo]
    if psd_file is None:
        # NOTE(review): assumes the multi-detector option container yields
        # None for a detector without an entry - confirm against
        # MultiDetOptionAction.
        parser.error(f'No PSD file given for {ifo}; add it to --psd-files')
    data[ifo] = load_frequencyseries(args.data_files[ifo])
    data[ifo].psd = load_frequencyseries(psd_file)

# Open the attributes file given by --params-file; it stays open because
# later stages keep reading from it.
fp = HFile(args.params_file, 'r')

# The GraceDB ID of the original event is optional in the file, but it is
# mandatory when the followup is going to be uploaded.
original_gid = fp['gid'].asstr()[()] if 'gid' in fp else None
if original_gid:
    logging.info('Following up GID %s', original_gid)
elif args.enable_gracedb_upload:
    raise RuntimeError('Params must include original gracedb ID in order '
                       'to upload followup!')

# Collect the coincident time of every detector that has one stored in the
# file; detectors without an entry are skipped.
coinc_times = {}
for ifo in ifos:
    try:
        coinc_times[ifo] = fp['coinc_times'][ifo][()]
    except KeyError:
        continue
coinc_ifos = list(coinc_times)

# Remaining settings, from the command line and from the file
approximant = args.approximant
flen = fp['flen'][()]
flow = float(fp['flow'][()])
f_end = float(fp['f_end'][()])
delta_f = fp['delta_f'][()]
sample_rate = fp['sample_rate'][()]

# Positional arguments passed on to the optimizer and to the SNR
# computation; the order matters.
extra_args = [
    data,
    coinc_times,
    coinc_ifos,
    flen,
    approximant,
    flow,
    f_end,
    delta_f,
    sample_rate,
]

# Determine chirp mass bounds from constant chirp time window
tau0flow = args.chirp_time_f_lower
cand_mass1 = fp['mass1'][()]
cand_mass2 = fp['mass2'][()]
tau0 = cv.tau0_from_mass1_mass2(cand_mass1, cand_mass2, tau0flow)
# Arbitrarily set max mchirp to 200 Msun if tau0 - window is too small
mintau0 = max(tau0 - args.chirp_time_window,
              cv.tau0_from_mchirp(200., tau0flow))
maxtau0 = tau0 + args.chirp_time_window
# A longer chirp time corresponds to a smaller chirp mass, hence the swap
minchirp = cv.mchirp_from_tau0(maxtau0, tau0flow)
maxchirp = cv.mchirp_from_tau0(mintau0, tau0flow)
# Check basic sanity. This is an explicit exception rather than an assert
# because the values derive from user input (candidate masses and the
# chirp-time options) and asserts are skipped when Python runs with -O.
if not 0.5 < minchirp < maxchirp < 250.:
    raise ValueError(
        f'Unreasonable chirp mass bounds ({minchirp}, {maxchirp}) derived '
        f'from candidate masses ({cand_mass1}, {cand_mass2}) and the '
        'chirp time options; expected 0.5 < min < max < 250'
    )

# Establish minimum eta: find the most asymmetric mass point
# nb function name is 'mass2' but it doesn't enforce m2 < m1!
maxm1 = cv._mass2_from_mchirp_mass1(maxchirp, MIN_CPT_MASS)
mineta = max(cv.eta_from_mass1_mass2(maxm1, MIN_CPT_MASS), MIN_ETA)

# Use max chirp and max eta to find the upper bound on mass2
maxm2 = cv.mass2_from_mchirp_eta(maxchirp, MAX_ETA)

# Astrophysical spin bounds dependent on system masses - https://inspirehep.net/literature/1124003
# If masses can be > 3 then allow them BH spin range
minspin1z = -0.4 if maxm1 < 3 else -0.9
maxspin1z = 0.4 if maxm1 < 3 else 0.9
minspin2z = -0.4 if maxm2 < 3 else -0.9
maxspin2z = 0.4 if maxm2 < 3 else 0.9

# Boundary of the optimization space
bounds = {
    'mchirp': (minchirp, maxchirp),
    'eta': (mineta, MAX_ETA),
    'spin1z': (minspin1z, maxspin1z),
    'spin2z': (minspin2z, maxspin2z)
}

# Build the initial point handed to the optimizer. When requested, it is
# the found candidate expressed as (mchirp, eta, spin1z, spin2z) in a
# single-row 2D array; otherwise no initial point is supplied.
initial_point = None
if args.snr_opt_include_candidate:
    m1_cand, m2_cand = fp['mass1'][()], fp['mass2'][()]
    initial_point = numpy.array([
        cv.mchirp_from_mass1_mass2(m1_cand, m2_cand),
        cv.eta_from_mass1_mass2(m1_cand, m2_cand),
        fp['spin1z'][()],
        fp['spin2z'][()],
    ])[numpy.newaxis]

# Run the optimizer selected by --snr-opt-method inside the processing
# scheme, then recompute the SNR time series for the best point found.
with scheme_context:
    logging.info('Starting optimization')
    run_optimizer = snr_optimizer.optimize_funcs[args.snr_opt_method]
    opt_params = run_optimizer(bounds, args, extra_args, initial_point)
    logging.info('Optimization complete')

    # Detectors that have data but no stored coincident time are assigned
    # the time of the first coincident detector.
    for ifo in set(ifos) - set(coinc_ifos):
        coinc_times[ifo] = coinc_times[coinc_ifos[0]]

    # Slot 2 of extra_args held the coincident detectors only; replace it
    # with the full detector list for the final SNR computation.
    extra_args[2] = ifos

    _, snr_series_dict = snr_optimizer.compute_network_snr_core(
        opt_params, *extra_args
    )

# opt_params is ordered as (mchirp, eta, spin1z, spin2z); convert the mass
# parameters back to component masses.
spin1z = opt_params[2]
spin2z = opt_params[3]
mtotal = cv.mtotal_from_mchirp_eta(opt_params[0], opt_params[1])
mass1 = cv.mass1_from_mtotal_eta(mtotal, opt_params[1])
mass2 = cv.mass2_from_mtotal_eta(mtotal, opt_params[1])

# Prepare for GraceDB upload
coinc_results = {}
skyloc_data = {}

# Per-detector peak information, plus the overall loudest peak
loudest_snrs = {}
loudest_snr_times = {}
loudest_snr_idxs = {}
loudest_ifo = None
loudest_snr_allifos = 0
loudest_snr_time_allifos = None

# Determine which ifo has the loudest SNR peak ... We'll use this to determine
# the time window for other ifos (e.g. if one ifo has a loud SNR peak, and the
# other has SNR < 4, we would need to determine the loudest SNR for the quieter
# instrument within light-travel time of the loud SNR peak)

# Each detector is searched in a window of `duration` seconds centred on its
# coincident time; the window size does not depend on the detector.
duration = 0.095
half_dur_samples = int(sample_rate * duration / 2)

for ifo in ifos:
    snr_series = snr_series_dict[ifo]

    # Sample index of the coincident time within this SNR series
    onsource_idx = int(round(
        float(coinc_times[ifo] - snr_series.start_time)
        * snr_series.sample_rate
    ))
    window_start = onsource_idx - half_dur_samples
    window_stop = onsource_idx + half_dur_samples + 1

    # argmax is relative to the window, so shift it back to series indexing
    max_snr_idx = numpy.argmax(abs(snr_series[window_start:window_stop]))
    max_snr_idx = max_snr_idx + onsource_idx - half_dur_samples
    max_snr = snr_series[max_snr_idx]
    max_snr_time = snr_series.start_time + max_snr_idx*(1./sample_rate)

    loudest_snrs[ifo] = max_snr
    loudest_snr_times[ifo] = float(max_snr_time)
    loudest_snr_idxs[ifo] = max_snr_idx

    # Track the loudest peak over all detectors
    if abs(max_snr) > loudest_snr_allifos:
        loudest_ifo = ifo
        loudest_snr_allifos = abs(max_snr)
        loudest_snr_time_allifos = loudest_snr_times[ifo]

# Trim each detector's SNR series in snr_series_dict to a short segment
# centred on that detector's loudest sample within the light-travel-time
# window around the loudest peak over all detectors.

# Number of samples kept on each side of the peak (a tenth of a second)
half_width = int(sample_rate/10)

for ifo in ifos:
    full_series = snr_series_dict[ifo]

    # Light travel time between this detector and the loudest one, plus a
    # small buffer
    lttbd = Detector(ifo).light_travel_time_to_detector(Detector(loudest_ifo))
    lttbd += 0.005

    snr_slice = full_series.time_slice(loudest_snr_time_allifos - lttbd,
                                       loudest_snr_time_allifos + lttbd)

    # argmax is relative to snr_slice; convert it to an index into the
    # full series (the + 0.5 rounds the offset to the nearest sample)
    idx_offset = snr_slice.start_time - full_series.start_time
    idx_offset = int(float(idx_offset) * sample_rate + 0.5)
    max_snr_idx = numpy.argmax(abs(snr_slice)) + idx_offset

    new_snr_slice = slice(max_snr_idx - half_width,
                          max_snr_idx + half_width + 1)
    snr_series_dict[ifo] = full_series[new_snr_slice]

# Fill in the per-detector results and sky-localization inputs, and
# accumulate the squared network SNR.
netsnr = 0

# The template duration is the same for every detector, so read it from the
# params file once rather than on every iteration.
template_duration = fp['template_duration'][()]

for ifo in ifos:
    prefix = 'foreground/' + ifo + '/'
    snr_series = snr_series_dict[ifo]
    psd = data[ifo].psd

    coinc_results[prefix + 'mass1'] = mass1
    coinc_results[prefix + 'mass2'] = mass2
    coinc_results[prefix + 'spin1z'] = spin1z
    coinc_results[prefix + 'spin2z'] = spin2z
    coinc_results[prefix + 'f_lower'] = flow
    coinc_results[prefix + 'window'] = 0.1
    coinc_results[prefix + 'sample_rate'] = int(sample_rate)
    coinc_results[prefix + 'template_id'] = 0
    # Apparently GraceDB gets upset if chi-squared is not set
    coinc_results[prefix + 'chisq'] = 1.
    coinc_results[prefix + 'chisq_dof'] = 1
    coinc_results[prefix + 'template_duration'] = template_duration

    skyloc_data[ifo] = {
        'psd': interpolate(psd, 0.25),
        'snr_series': snr_series,
    }

    # Find loudest SNR within coincidence window of the largest peak:
    # light travel time to the loudest detector plus a small buffer
    lttbd = Detector(ifo).light_travel_time_to_detector(Detector(loudest_ifo))
    lttbd += 0.005

    snr_slice = snr_series.time_slice(loudest_snr_time_allifos - lttbd,
                                      loudest_snr_time_allifos + lttbd)
    max_snr_idx = numpy.argmax(abs(snr_slice))
    loudest_snr_time = snr_slice.sample_times[max_snr_idx]
    loudest_snr = snr_slice[max_snr_idx]

    coinc_results[prefix + 'end_time'] = loudest_snr_time
    coinc_results[prefix + 'snr_series'] = snr_series
    coinc_results[prefix + 'psd_series'] = psd
    coinc_results[prefix + 'delta_f'] = delta_f
    coinc_results[prefix + 'snr'] = abs(loudest_snr)
    netsnr += abs(loudest_snr)**2
    # The SNR computation returns sigmasq under 'sigmasq_<ifo>' keys in the
    # same dict as the SNR series
    coinc_results[prefix + 'sigmasq'] = snr_series_dict['sigmasq_' + ifo]
    coinc_results[prefix + 'coa_phase'] = numpy.angle(loudest_snr)

# Network statistic is the quadrature sum of the per-detector SNRs; the
# IFAR is copied unchanged from the original candidate.
coinc_results['foreground/stat'] = numpy.sqrt(netsnr)
coinc_results['foreground/ifar'] = fp['ifar'][()]

channel_names = {}
for ifo in fp['channel_names'].keys():
    channel_names[ifo] = fp['channel_names'][ifo].asstr()[()]

mc_area_args = load_hdf5_to_dict(fp, 'mc_area_args/')

kwargs = {'psds': {ifo: skyloc_data[ifo]['psd'] for ifo in ifos},
          'low_frequency_cutoff': flow,
          'skyloc_data': skyloc_data,
          'channel_names': channel_names,
          'mc_area_args': mc_area_args}

# Do not recalculate p_terr/p_astro, reuse previous value if available
# (source probabilities may be recalculated)
if 'p_terr' in fp:
    kwargs['p_terr'] = fp['p_terr'][()]

# This was the last read from the params file, so release the handle now
# instead of holding it open during the upload.
fp.close()

# Treat all ifos as having triggers
doc = CandidateForGraceDB(
    ifos,
    ifos,
    coinc_results,
    upload_snr_series=True,
    **kwargs
)

# The output file is named after the time of the loudest SNR peak
xml_path = os.path.join(
    args.output_path,
    f'coinc-{loudest_snr_time_allifos:.3f}.xml.gz'
)

if args.enable_gracedb_upload:
    logging.info('Uploading optimized candidate to GraceDB')

    comment = ('Automatic PyCBC followup of trigger '
               '<a href="/events/{0}/view">{0}</a> to find the template '
               'parameters that maximize the SNR. The FAR of this trigger is '
               'copied from {0} and does not reflect this trigger\'s template '
               'parameters.')
    comment = comment.format(original_gid)

    gid = doc.upload(xml_path, gracedb_server=args.gracedb_server,
                     testing=(not args.production), extra_strings=[comment],
                     search=args.gracedb_search, labels=args.gracedb_labels)
    if gid is not None:
        logging.info('Event uploaded as %s', gid)

        # add a note to the original G event pointing to the optimized one
        # (imported here so the GraceDB client is only needed when uploading)
        from ligo.gracedb.rest import GraceDb

        gracedb = GraceDb(args.gracedb_server) \
                if args.gracedb_server is not None else GraceDb()
        comment = ('Result of SNR maximization uploaded as '
                   '<a href="/events/{0}/view">{0}</a>').format(gid)
        gracedb.write_log(
            original_gid,
            comment,
            tag_name=['analyst_comments']
        )
else:
    logging.info('Saving optimized candidate')
    doc.save(xml_path)

logging.info('Done')
