#!/usr/bin/python3.13 -s

# Copyright (C) 2015-2023 Ian Harry, Gareth Cabourn Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Take a coinc xml file containing multiple events and prepare one of them
for upload to gracedb.
"""

import argparse
import logging
import numpy as np
import matplotlib
matplotlib.use('agg')

import lal
import lal.series
from ligo.lw import lsctables
from ligo.lw import utils as ligolw_utils
from ligo.segments import segment, segmentlist

import pycbc
from pycbc.io import HFile
from pycbc.io.ligolw import (
    LIGOLWContentHandler,
    snr_series_to_xml,
)
from pycbc.psd import interpolate
from pycbc.types import FrequencySeries, TimeSeries, load_timeseries
from pycbc.types import MultiDetOptionAction
from pycbc.results import generate_asd_plot, generate_snr_plot

# Command-line interface. The arguments are registered in the order in which
# they should be listed by --help.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument(
    "--psd-files",
    nargs="+",
    required=True,
    help="HDF file(s) containing the PSDs to upload",
)
parser.add_argument(
    "--snr-timeseries",
    nargs="+",
    required=True,
    help="HDF file(s) containing the SNR timeseries to upload",
)
parser.add_argument(
    "--input-file",
    required=True,
    type=str,
    help="Input LIGOLW XML file of coincidences.",
)
parser.add_argument(
    "--event-id",
    type=int,
    required=True,
    help="ID for the event being prepared.",
)
parser.add_argument(
    "--output-file",
    help="Output filename for the locally stored XML. ",
)
parser.add_argument(
    "--snr-timeseries-plot",
    help=(
        "Output filename for the plot of SNR timeseries. "
        "Must be a plot filetype as used by matplotlib."
    ),
)
parser.add_argument(
    "--psd-plot",
    help=(
        "Output filename for the plot of the PSD. "
        "Must be a plot filetype as used by matplotlib."
    ),
)
parser.add_argument(
    "--channel-name",
    action=MultiDetOptionAction,
    required=True,
    help=(
        "Channel name for adding to the uploaded single "
        "inspiral table."
    ),
)
parser.add_argument(
    "--delta-f",
    type=float,
    default=0.25,
    help=(
        "Frequency spacing of the PSD to be uploaded, Hz. "
        "Default 0.25"
    ),
)

args = parser.parse_args()

# Logging defaults to info level; --verbose increases it from there
pycbc.init_logging(args.verbose, default_level=1)

# Read the input XML document containing the coincidences
xmldoc = ligolw_utils.load_filename(
    args.input_file,
    contenthandler=LIGOLWContentHandler,
)

class psd_segment(segment):
    """
    A segment which additionally carries a PSD in its ``psd`` attribute.

    The first constructor argument is the PSD; all remaining arguments are
    passed on to ``segment``.
    """

    def __new__(cls, psd, *args):
        # The PSD is not part of the segment itself, so only the remaining
        # arguments are forwarded to the parent class
        return super().__new__(cls, *args)

    def __init__(self, psd, *args):
        self.psd = psd

def read_psds(psd_files):
    """
    Read the PSD files and collect their PSDs as segments, per IFO.

    Each file must contain exactly one top-level group, whose name is used
    as the IFO. The group holds a "psds" group with entries named "0", "1",
    ... as well as "start_time" and "end_time" datasets.

    The files are left open: the returned segments hold references to the
    datasets inside them, and the rest of the script reads from those
    datasets and their attributes later on.

    Parameters
    ----------
    psd_files : iterable of str
        Paths of the HDF files containing the PSDs.

    Returns
    -------
    dict
        Maps IFO to a segmentlist of psd_segment objects, one for each PSD
        in that IFO's file.
    """
    logging.info("Reading PSDs")
    psds_by_ifo = {}

    for filename in psd_files:
        hdf_file = HFile(filename, "r")
        # Single-element unpacking: fails unless there is exactly one group
        (ifo, ifo_group), = hdf_file.items()

        psd_group = ifo_group["psds"]
        psd_datasets = [
            psd_group[str(index)] for index in range(len(psd_group.keys()))
        ]

        ifo_segments = segmentlist()
        for dataset, start, end in zip(
            psd_datasets, ifo_group["start_time"], ifo_group["end_time"]
        ):
            ifo_segments.append(psd_segment(dataset, start, end))
        psds_by_ifo[ifo] = ifo_segments

    return psds_by_ifo

# Read all of the PSD files given on the command line; the result is a dict
# keyed on IFO
psds = read_psds(args.psd_files)

def read_snr_timeseries(snr_timeseries_files):
    """
    Load the SNR timeseries from the single_template output files.

    Parameters
    ----------
    snr_timeseries_files : iterable of str
        Paths of the HDF files, each containing an 'snr' timeseries and
        carrying 'ifo' and 'event_time' attributes.

    Returns
    -------
    dict
        Maps (ifo, event_time) tuples to the loaded SNR timeseries.
    """
    logging.info("Reading SNR timeseries")
    series_by_event = {}
    for path in snr_timeseries_files:
        # Only the attributes are needed from this handle; the timeseries
        # itself is loaded separately below
        with HFile(path, 'r') as hdf_file:
            key = (hdf_file.attrs['ifo'], hdf_file.attrs['event_time'])
        series_by_event[key] = load_timeseries(path, 'snr')
    return series_by_event

# Read all of the SNR timeseries files given on the command line; the result
# is a dict keyed on (ifo, event time)
snr_timeseries = read_snr_timeseries(args.snr_timeseries)

# Look up the tables describing the coincident events in the input document
coinc_table = lsctables.CoincTable.get_table(xmldoc)
coinc_inspiral_table = lsctables.CoincInspiralTable.get_table(xmldoc)
coinc_event_map_table = lsctables.CoincMapTable.get_table(xmldoc)
sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(xmldoc)

# Detach the full tables from the document; versions restricted to the
# chosen event are added back in their place further down
for _full_table in (
    sngl_inspiral_table,
    coinc_event_map_table,
    coinc_inspiral_table,
    coinc_table,
):
    xmldoc.childNodes[-1].removeChild(_full_table)

# Pick out the first coinc_table row whose ID matches the requested event
event = next(
    (row for row in coinc_table if args.event_id == row.coinc_event_id),
    None,
)
if event is None:
    raise ValueError(f"event {args.event_id} not found in coinc table")

# Create new, empty tables which will hold only the rows belonging to the
# chosen event
coinc_event_table_curr = lsctables.New(lsctables.CoincTable)
coinc_event_table_curr.append(event)
coinc_inspiral_table_curr = lsctables.New(lsctables.CoincInspiralTable)
coinc_event_map_table_curr = lsctables.New(lsctables.CoincMapTable)
sngl_inspiral_table_curr = lsctables.New(lsctables.SnglInspiralTable)

for coinc_insp in coinc_inspiral_table:
    if coinc_insp.coinc_event_id == event.coinc_event_id:
        coinc_inspiral_table_curr.append(coinc_insp)

# The event time is taken from the coinc inspiral row, so fail with a clear
# message (rather than an IndexError) if the event does not have one
if len(coinc_inspiral_table_curr) == 0:
    raise ValueError(
        f"event {args.event_id} not found in coinc inspiral table"
    )

# Event end time in seconds, combining the integer-second and nanosecond
# columns of the first matching row
time = coinc_inspiral_table_curr[0].end_time \
        + coinc_inspiral_table_curr[0].end_time_ns * 1e-9

# Keep the coinc map rows for this event, and remember the IDs of the
# single-detector triggers which they point at
sngl_ids = []
for coinc_map in coinc_event_map_table:
    if coinc_map.coinc_event_id == event.coinc_event_id:
        coinc_event_map_table_curr.append(coinc_map)
        sngl_ids.append(coinc_map.event_id)

# Gather the SNR timeseries for this event, keyed on IFO only. Timeseries
# may be present for IFOs which did not contribute a trigger.
#
# Every timeseries has to correspond to the event: the time recorded with it
# must lie within 0.5s of the event time.
snr_ts = {}
for (ifo, snr_time), series in snr_timeseries.items():
    offset = abs(snr_time - time)
    if not offset < 0.5:
        raise ValueError("SNR timeseries for IFO %s does not look like it "
            "corresponds to this event, event time %.3f, SNR timeseries is "
            "around time %.3f" % (ifo, time, snr_time))
    snr_ts[ifo] = series

# For each IFO with an SNR timeseries, take the PSD from the segment found
# for the event time. Two versions are kept: the stored PSD as read from the
# file (psds_event) and a copy resampled to the requested frequency spacing
# (psddict).
psds_event = {}
psddict = {}
for ifo in snr_ts:
    ifo_psds = psds[ifo]
    psd = ifo_psds[ifo_psds.find(time)].psd

    psd_fs = interpolate(
        FrequencySeries(psd, delta_f=psd.attrs["delta_f"], dtype=np.float64),
        args.delta_f,
    )

    psds_event[ifo] = psd
    psddict[ifo] = psd_fs

# PSDs of the IFOs which contributed a trigger, as LAL frequency series
lal_psddict = {}

for sngl in sngl_inspiral_table:
    # Only the single-detector triggers belonging to this event are kept
    if sngl.event_id not in sngl_ids:
        continue

    # A contributing IFO needs both an SNR timeseries and a PSD; report a
    # missing one clearly instead of failing on a bare dictionary lookup
    if sngl.ifo not in psds_event or sngl.ifo not in snr_ts:
        raise ValueError(
            f"No SNR timeseries and PSD available for IFO {sngl.ifo}, "
            f"which contributed a trigger to event {args.event_id}"
        )

    sngl_inspiral_table_curr.append(sngl)
    sngl.channel = args.channel_name[sngl.ifo]

    # Two versions of the PSD dictionary have useful info here: the stored
    # PSD carries the attributes, the resampled one carries the data
    psd = psds_event[sngl.ifo]
    psd_fs = psddict[sngl.ifo]

    # The low frequency cutoff is an attribute of the file which the PSD
    # belongs to; bins below it are not written out
    flow = psd.file.attrs['low_frequency_cutoff']
    kmin = int(flow / args.delta_f)

    fseries = lal.CreateREAL8FrequencySeries(
        "psd",
        lal.LIGOTimeGPS(int(psd.attrs["epoch"])),
        kmin * args.delta_f,
        args.delta_f,
        lal.StrainUnit**2 / lal.HertzUnit,
        len(psd_fs) - kmin)
    # NOTE(review): dividing by DYN_RANGE_FAC squared presumably removes a
    # dynamic range scaling applied to the stored PSDs - confirm
    fseries.data.data = psd_fs[kmin:] / np.square(pycbc.DYN_RANGE_FAC)
    lal_psddict[sngl.ifo] = fseries

    # Add this IFO's SNR timeseries to the document, labelled with the ID
    # of the trigger
    snr_series_to_xml(snr_ts[sngl.ifo], xmldoc, sngl.event_id)

# Attach the tables restricted to this event to the document, followed by
# the PSDs of the contributing IFOs
for _event_table in (
    coinc_event_table_curr,
    coinc_inspiral_table_curr,
    coinc_event_map_table_curr,
    sngl_inspiral_table_curr,
):
    xmldoc.childNodes[-1].appendChild(_event_table)
lal.series.make_psd_xmldoc(lal_psddict, xmldoc.childNodes[-1])

# Write out the finished document
ligolw_utils.write_filename(xmldoc, args.output_file)
logging.info("Saved XML file %s", args.output_file)

# Optional plot made from the resampled PSDs
if args.psd_plot:
    logging.info("Saving ASD plot %s", args.psd_plot)
    generate_asd_plot(psddict, args.psd_plot)

# Optional plot of the SNR timeseries, with the (time, SNR) of the trigger
# from each contributing IFO
if args.snr_timeseries_plot:
    triggers = {}
    for row in sngl_inspiral_table_curr:
        trigger_time = row.end_time + row.end_time_ns * 1e-9
        triggers[row.ifo] = (trigger_time, row.snr)
    # Event time rounded down to a whole second
    base_time = int(np.floor(time))
    logging.info("Saving SNR plot %s", args.snr_timeseries_plot)
    generate_snr_plot(snr_ts, args.snr_timeseries_plot, triggers, base_time)

logging.info('Done!')
