#!/usr/bin/python3.13 -s

"""
Simple wrapper around bayestar-localize-coincs to get sky localization for a
specific compact binary merger event.

Uses an XML containing SNR timeseries and PSD and calls BAYESTAR to produce a
sky localization.

The wrapper needs to exist as the XML format doesn't allow us to indicate
the low frequency cutoff or waveform used. We also want to be able to control
the output filename, which is not possible in bayestar-localize-coincs.

Unrecognised options will be passed straight to the bayestar subprocess
"""

import argparse
import subprocess
import shutil
import logging
import os
import glob
import sys
import random
import tempfile

from ligo.lw import lsctables, utils as ligolw_utils

import pycbc
from pycbc import init_logging
from pycbc.waveform import bank as wavebank
from pycbc.io import WaveformArray
from pycbc.io.ligolw import LIGOLWContentHandler

# Command-line interface. The module docstring doubles as the --help text.
parser = argparse.ArgumentParser(description=__doc__)
# Presumably registers the shared PyCBC flags (args.verbose is read when
# logging is initialised) -- confirm against pycbc.
pycbc.add_common_pycbc_options(parser)
parser.add_argument('--bayestar-executable',
                    help="The bayestar-localize-coincs executable to be run. "
                         "If not given, will use whatever is available in "
                         "the current environment.")
parser.add_argument('--event-xml', required=True,
                    help="XML file containing event information, SNR "
                         "timeseries and PSD to pass to bayestar")
parser.add_argument('--low-frequency-cutoff', type=float, default=20,
                    help="Low-frequency cutoff used in the matched-filtering "
                         "to generate the SNR timeseries")
parser.add_argument('--output-file', required=True,
                    help="Filename to output the fits file to.")
# Presumably registers --approximant (args.approximant is read later when
# the waveform is chosen) -- confirm against pycbc.waveform.bank.
wavebank.add_approximant_arg(parser)
# parse_known_args() rather than parse_args(): options this wrapper does not
# recognise are collected in `unknown` and forwarded verbatim to BAYESTAR.
args, unknown = parser.parse_known_args()

# Logging starts at info level by default; --verbose increases it further
pycbc.init_logging(args.verbose, default_level=1)

logging.info("Starting")

# Use the executable named on the command line if one was given, otherwise
# fall back to whichever bayestar-localize-coincs the environment provides
bayestar_exe = args.bayestar_executable
if not bayestar_exe:
    bayestar_exe = 'bayestar-localize-coincs'

# Work out which approximant is being used
# Load the file
xmldoc = ligolw_utils.load_filename(
    args.event_xml,
    contenthandler=LIGOLWContentHandler
)

# Grab the single inspiral table(s) which contain the template information
sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(xmldoc)

# turn this into a waveform array
row = WaveformArray.from_ligolw_table(
    sngl_inspiral_table,
)

# Work out which waveform is used based on the template information
# Note that as there are multiple single inspiral tables (one per ifo)
# with the same information, we use [0] here
waveform = wavebank.parse_approximant_arg(args.approximant, row)[0]

# BAYESTAR uses TaylorF2 instead of SPAtmplt
if waveform == 'SPAtmplt':
    waveform = 'TaylorF2'

# BAYESTAR picks its own output filename, so let it write into a scratch
# directory and move the result to the requested location afterwards.
# The directory is only created once the inputs have been parsed, and the
# context manager removes it on every exit path (success, subprocess
# failure, or an unexpected number of FITS files), so nothing is leaked.
with tempfile.TemporaryDirectory() as tmpdir:
    # Set up the command to pass to bayestar.
    # The XML file path being passed twice is a legacy requirement,
    # not a mistake.
    cmd = [bayestar_exe,
           args.event_xml,
           args.event_xml,
           '--waveform', waveform,
           '--f-low', str(args.low_frequency_cutoff),
           '-o', tmpdir]

    # Pass any unknown options straight to the subprocess
    cmd += unknown

    logging.info("Running %s", ' '.join(cmd))
    # Raises subprocess.CalledProcessError if BAYESTAR exits non-zero
    subprocess.check_output(cmd)

    # Find the fits file in the temporary directory:
    # It would be nice to do this better! - maybe use the input xml
    # file to find the number which is going to be used?
    fits_filenames = glob.glob(os.path.join(tmpdir, '*.fits*'))
    if len(fits_filenames) != 1:
        raise ValueError(
            f'argh, got {len(fits_filenames)} FITS files after running '
            'BAYESTAR, I want one!'
        )

    # Must happen inside the `with` block: the scratch directory (and
    # anything still in it) is deleted as soon as the block exits.
    logging.info("Moving output to %s", args.output_file)
    shutil.move(fits_filenames[0], args.output_file)

logging.info("Done")
