#!/usr/bin/python3.13 -s

# Copyright (C) 2021 Alex Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

""" This program converts standard injection file formats (e.g. XML format
or the LIGO HDF format) into a PyCBC hdf injection format
"""

import argparse
import numpy
import shutil

import pycbc
from pycbc.inject import InjectionSet, CBCHDFInjectionSet
from pycbc.types import float64, float32
from pycbc.io.record import FieldArray
from pycbc.io.hdf import HFile

from pycbc.io.ligolw import LIGOLWContentHandler
from ligo.lw import utils as ligolw_utils, lsctables


# Handlers for alternative injection formats

def legacy_approximant_name(apx):
    """Convert the old style xml approximant name to a name
    and phase_order.

    Parameters
    ----------
    apx : object
        The approximant specification. It is converted with ``str`` before
        being handed to lalsimulation.

    Returns
    -------
    name
        The value returned by ``lalsimulation.GetStringFromApproximant`` for
        the approximant parsed from ``apx``.
    order
        The value returned by ``lalsimulation.GetOrderFromString``, or -1
        if that call raises.
    """
    import lalsimulation as lalsim
    apx = str(apx)
    try:
        order = lalsim.GetOrderFromString(apx)
    except Exception:
        # Best-effort parsing: any failure falls back to -1. Catching
        # Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate instead of being reported as a parse failure.
        print("Warning: Could not read phase order from string, using default")
        order = -1
    name = lalsim.GetStringFromApproximant(lalsim.GetApproximantFromString(apx))
    return name, order


class XMLInjectionSet(object):

    """Reads injections from LIGOLW XML files

    Parameters
    ----------
    sim_file : string
        Path to a LIGOLW XML file containing a SimInspiralTable
        with injection definitions.

    Attributes
    ----------
    indoc
        The document returned by ``ligolw_utils.load_filename``.
    table
        The SimInspiralTable extracted from ``indoc``.
    """

    def __init__(self, sim_file):
        self.indoc = ligolw_utils.load_filename(
            sim_file, False, contenthandler=LIGOLWContentHandler)
        self.table = lsctables.SimInspiralTable.get_table(self.indoc)

    def end_times(self):
        """Return the end times of all injections"""
        return [inj.time_geocent for inj in self.table]

    def pack_data_into_pycbc_format_input(self):
        """Collect the table columns into a dict of numpy arrays.

        Returns
        -------
        data : dict
            One entry per non-empty column of the table, keyed by column
            name, with ``simulation_id`` and ``process_id`` removed and the
            derived entries ``approximant``, ``phase_order``, ``tc``,
            ``dec`` and ``ra`` added.
        """
        first_row = self.table[0]
        data = {}
        for key in first_row.__slots__:
            # Some XML files can have empty columns which are read as None in
            # python. For these cases we ignore the columns so they do not
            # appear in the output HDF file. (Such XML files cannot currently
            # be read by LALSuite C code.)
            if getattr(first_row, key) is not None:
                data[str(key)] = numpy.array(
                    [getattr(row, key) for row in self.table]
                )

        # Drop the ID columns. A default is supplied so that a column which
        # was skipped above (because it was empty) does not raise KeyError.
        for k in ['simulation_id', 'process_id']:
            data.pop(k, None)

        # Each waveform string yields a (name, order) pair; transposing the
        # stacked pairs gives one array of names and one of orders.
        data['approximant'], data['phase_order'] = numpy.array(
                [legacy_approximant_name(wf) for wf in data['waveform']]
        ).T
        # Combine the seconds and nanoseconds columns into a single time.
        data['tc'] = (
            data['geocent_end_time'] + 1e-9 * data['geocent_end_time_ns']
        )
        data['dec'] = data['latitude']
        data['ra'] = data['longitude']
        return data


class LVKNewStyleInjectionSet(object):
    """Reads injections from new-style LVK injection files

    Parameters
    ----------
    sim_file : string
        Path to the injection file. It is opened read-only with ``HFile``
        and its columns are read from the ``subdir`` group.

    Attributes
    ----------
    inj_file
        The open file handle.
    """

    # Maps column names in the input file to the names used in the output.
    # Columns that are not listed here keep their original name.
    translation_dict = {
        'mass1_det': 'mass1',
        'mass2_det': 'mass2',
        'f22_ref_spin': 'f_ref',
        't_co_gps': 'tc',
        'd_lum': 'distance',
        'f22_start': 'f_lower',
        'cbc_model': 'approximant',
        'longAscNodes': 'long_asc_nodes',
        'eccentricity': 'eccentricity',
        'meanPerAno': 'mean_per_ano',
        't_co_gps_add': None,
        'ModeArray': 'FIXME', # FIXME (Will need special handling. Need an example)
        'ModeArrayJframe': 'FIXME', # FIXME (See ModeArray above)
    }

    # Name of the group in the input file that holds the columns.
    subdir = 'cbc_waveform_params'

    def __init__(self, sim_file):
        self.inj_file = HFile(sim_file, 'r')

    def columns_in_pycbc_format(self):
        """Yield ``(input_name, output_name)`` for each column in the file.

        The output name comes from ``translation_dict`` when the column is
        listed there (and may then be None); otherwise it is the input name.
        """
        for field_name in self.inj_file[self.subdir]:
            yield field_name, self.translation_dict.get(field_name, field_name)

    def get_coalescence_time(self):
        """Return ``t_co_gps``, plus ``t_co_gps_add`` when that column exists.
        """
        tcs = self.inj_file[f'{self.subdir}/t_co_gps'][:]
        if 't_co_gps_add' in self.inj_file[self.subdir]:
            # Add out of place so NumPy promotes to a common dtype. An
            # in-place ``+=`` keeps the dtype of ``t_co_gps``, which fails
            # for an integer column and loses precision for a narrower
            # float column.
            tcs = tcs + self.inj_file[f'{self.subdir}/t_co_gps_add'][:]
        return tcs

    def pack_data_into_pycbc_format_input(self):
        """Read every column into a dict keyed by the output column name.

        Returns
        -------
        data : dict
            The column arrays. ``t_co_gps`` is replaced by the combined
            coalescence time, ``t_co_gps_add`` gets no entry of its own, and
            columns whose dtype char is 'U' or 'O' are converted with
            ``astype('S')``.
        """
        data = {}
        for lvk_name, pycbc_name in self.columns_in_pycbc_format():
            if lvk_name == 't_co_gps_add':
                # Already folded into the coalescence time.
                continue
            if lvk_name == 't_co_gps':
                # Special case
                data[pycbc_name] = self.get_coalescence_time()
                continue
            values = self.inj_file[f'{self.subdir}/{lvk_name}']
            if values.dtype.char in ['U', 'O']:
                data[pycbc_name] = values[:].astype('S')
            else:
                data[pycbc_name] = values[:]
        return data


# Command line interface
parser = argparse.ArgumentParser()
pycbc.add_common_pycbc_options(parser)
parser.add_argument('--injection-file', required=True,
                    help="The injection file to load. Either a LIGOLW XML "
                         "file (ending in '.xml', '.xml.gz' or '.xmlgz') "
                         "containing a SimInspiral table, or an HDF "
                         "injection file in the LVK or PyCBC format.")
parser.add_argument('--output-file', required=True,
                    help="The output file name. Must end in '.hdf'.")
args = parser.parse_args()

pycbc.init_logging(args.verbose)

# Pick a reader for the input; None means "no conversion needed".
injclass = None
if args.injection_file.endswith(('.xml', '.xml.gz', '.xmlgz')):
    injclass = XMLInjectionSet(args.injection_file)
else:
    # Assume a HDF file - check if new LVK format first.
    # The context manager closes the probe handle even if the check raises.
    with HFile(args.injection_file, 'r') as inj_file:
        file_attrs = inj_file.attrs
        is_lvk = 'file_format' in file_attrs or b'file_format' in file_attrs
    if is_lvk:
        # Assume LVK as PyCBC doesn't have this attribute
        injclass = LVKNewStyleInjectionSet(args.injection_file)

if injclass is not None:
    # The injection file is of known type
    data = injclass.pack_data_into_pycbc_format_input()
    samples = FieldArray.from_kwargs(**data)
    CBCHDFInjectionSet.write(args.output_file, samples)
else:
    # Assume the injection is a HDF file (PyCBC format), only make a copy
    shutil.copy(args.injection_file, args.output_file)
