#!/usr/bin/python3.13 -s

"""
Splits injections created using pycbc_create_injections into smaller sets.
Split sets are organized to maximize time between injections.
"""

import argparse
import numpy as np

import pycbc
from pycbc.inject import InjectionSet
from pycbc.io.hdf import HFile


# Parse command line
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
# nargs='+' (rather than '*') makes argparse reject a bare "-f" given with no
# file names. The number of output files is used as the number of splits, so
# an empty list would otherwise only fail later with a ZeroDivisionError.
parser.add_argument("-f", "--output-files", nargs='+', required=True,
                    help="Names of output files")
parser.add_argument("-i", "--input-file", required=True,
                    help="Injection file to be split")

args = parser.parse_args()

# Configure logging from the common PyCBC verbosity option
pycbc.init_logging(args.verbose)

# InjectionSet.write() requires static params as a dictionary, so read them
# from the attributes of the input HDF file. The "static_args" attribute
# itself is excluded from the dictionary. The file is opened in a context
# manager so the handle is closed as soon as the attributes have been copied.
with HFile(args.input_file, 'r') as inj_file:
    static_params = {key: inj_file.attrs[key] for key in inj_file.attrs
                     if key != 'static_args'}

# Read the same input file as an InjectionSet object
inj_set = InjectionSet(args.input_file)

# Define table of injection info
inj_table = inj_set.table

# Also get the names of variable params as write_args: every field of the
# table that is not one of the static params
write_args = [arg for arg in inj_table.fieldnames
              if arg not in static_params]

# Total number of injections and number of output files to share them between
num_injs = len(inj_table)
num_splits = len(args.output_files)

# Ideally, the number of injections is divisible by the number of splits
# with no remainder, but that is not always true. Every split starts with
# the same base count, and the first `remainder` splits take one extra.
ideal_split = num_injs // num_splits
remainder = num_injs % num_splits
injs_per_split = np.full(num_splits, ideal_split, dtype=int)
injs_per_split[:remainder] += 1

# Sanity check: were all injections accounted for?
assert injs_per_split.sum() == num_injs, \
    "Not all injections were accounted for!"

# Order the injection table in place by its 'tc' (time) field
inj_table.sort(order='tc')

# Write one output file per split
for split_idx, (out_name, count) in enumerate(zip(args.output_files,
                                                  injs_per_split)):
    # Take every num_splits-th injection of the time-sorted table, starting
    # at this split's offset, so that injections adjacent in time end up in
    # different files
    picks = split_idx + num_splits * np.arange(count)
    InjectionSet.write(out_name, inj_table[picks], write_args, static_params)
