#!/usr/bin/python3.13 -s
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Duncan Macleod, Francesco Pannarale, Erin Vincent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Cluster triggers from a pyGRB run
"""

import argparse
import operator
import os
import re
from collections import defaultdict
from itertools import compress
from functools import reduce

import tqdm

import numpy

import h5py

from gwdatafind.utils import filename_metadata

from ligo.segments import segmentlist
from ligo.segments.utils import fromsegwizard

from pycbc import add_common_pycbc_options, init_logging
from pycbc.inject import InjectionSet
from pycbc.io.hdf import HFile
from pycbc.results.pygrb_postprocessing_utils import template_hash_to_id

__author__ = "Duncan Macleod <duncan.macleod@ligo.org>"

# matches the "_SPLIT<digits>" tag in a file name or description and
# captures the "SPLIT<digits>" part (without the leading underscore)
SPLIT_FILENAME = re.compile(r"_(SPLIT\d+)")

# matches complete dataset names of the form "network/<X>1_event_id", where
# "<X>1" is one upper-case letter followed by "1" (e.g. "H1"), captured in
# the ``ifo`` group
NETWORK_IFO_EVENT_ID_REGEX = re.compile(
    r"\Anetwork/(?P<ifo>[A-Z]1)_event_id\Z",
)
# matches any dataset name that ends in "event_id"
EVENT_ID_REGEX = re.compile(r"event_id\Z")

# layout of the progress bar, using tqdm's ``bar_format`` placeholders
TQDM_BAR_FORMAT = ("{desc}: |{bar}| "
                   "{n_fmt}/{total_fmt} {unit} ({percentage:3.0f}%) "
                   "[{elapsed} | ETA {remaining}]{postfix}")
# keyword arguments forwarded to ``tqdm.tqdm`` when the progress bar is built
TQDM_KW = {
    "ascii": " -=#",
    "bar_format": TQDM_BAR_FORMAT,
    "smoothing": 0.05,
}

# -- utilities ----------------------------------

def find_split_files(filelist, split):
    """Yield the files of ``filelist`` that belong to a given split

    Parameters
    ----------
    filelist : iterable of `str`
        the candidate file names
    split : `str` or `None`
        the split tag to select, e.g. ``"SPLIT3"``; if `None`, the files
        whose name does not contain ``"SPLIT"`` at all are selected

    Yields
    ------
    fn : `str`
        each matching file name, in input order
    """
    if split is None:
        # handled separately: evaluating ``None in fn`` raises TypeError as
        # soon as a file name containing 'SPLIT' is met
        for fn in filelist:
            if 'SPLIT' not in fn:
                yield fn
        return

    # the tag must not be followed by a further digit, otherwise 'SPLIT1'
    # would also select the files of 'SPLIT10', 'SPLIT11', ...
    tag = re.compile(re.escape(split) + r"(?!\d)")
    for fn in filelist:
        if tag.search(fn):
            yield fn


def read_segment_files(segfiles):
    """Return the union of the segments stored in several files

    Parameters
    ----------
    segfiles : iterable of `str`
        the paths of the segment files, each parsed with `fromsegwizard`

    Returns
    -------
    segments : `segmentlist`
        the union (``|``) of the segments read from every file; empty if
        ``segfiles`` is empty
    """
    merged = segmentlist()
    for path in segfiles:
        with open(path, "r") as segfile:
            merged = merged | fromsegwizard(segfile)
    return segmentlist(merged)


def read_hdf5_triggers(inputfiles, verbose=False):
    """Read several HDF5 trigger files and concatenate their datasets

    Each dataset found in the input files is concatenated, file after file,
    along its first axis. Datasets whose name matches
    ``network/<IFO>_event_id``, or contains ``template_id``, ``search`` or
    ``gating``, are ignored. Files whose ``network/end_time_gc`` dataset is
    empty are skipped. The values of datasets whose name ends in
    ``event_id`` are shifted by the number of rows already merged for that
    dataset.

    Parameters
    ----------
    inputfiles : `list` of `str`
        the paths of the input HDF5 files to merge
    verbose : `bool`, optional
        unused, kept for backward compatibility

    Returns
    -------
    out_dict : `dict`
        maps the full name of each merged dataset (e.g.
        ``"network/end_time_gc"``) to a `numpy.ndarray` with the merged data

    Raises
    ------
    TypeError
        if the same dataset has a different dtype in two files
    """
    # dataset name -> (merged shape, dtype)
    datasets = {}

    def _scan_dataset(name, obj):
        # NOTE: this callback must always return None, any other value
        # would stop ``visititems`` early
        if not isinstance(obj, h5py.Dataset):
            return
        # the pattern is anchored on 'network/', no extra prefix test needed
        if NETWORK_IFO_EVENT_ID_REGEX.match(name):
            return
        if any(key in name for key in ("template_id", "search", "gating")):
            return
        shape = tuple(obj.shape)
        try:
            known_shape, known_dtype = datasets[name]
        except KeyError:
            datasets[name] = (shape, obj.dtype)
            return
        # explicit raise instead of assert, which is stripped under -O
        if obj.dtype != known_dtype:
            raise TypeError(
                "Cannot merge {0}/{1}, does not match dtype".format(
                    obj.file.filename, name,
                ))
        # grow along the first axis only, keeping any trailing dimensions
        datasets[name] = (
            (known_shape[0] + shape[0],) + known_shape[1:],
            known_dtype,
        )

    # get list of datasets
    for filename in inputfiles:
        with HFile(filename, 'r') as h5f:
            h5f.visititems(_scan_dataset)

    # create datasets
    out_dict = {
        dset: numpy.empty(shape, dtype=dtype)
        for dset, (shape, dtype) in datasets.items()
    }

    # number of rows already written, per dataset
    position = defaultdict(int)

    # copy dataset contents
    for filename in inputfiles:
        with HFile(filename, 'r') as h5in:
            # Skipping empty files
            if not len(h5in['network']['end_time_gc'][:]):
                continue
            for dset, merged in out_dict.items():
                data = h5in[dset][:]
                size = data.shape[0]
                pos = position[dset]
                if EVENT_ID_REGEX.search(dset):
                    # shift the ids by the rows that precede this file
                    data = data + pos
                merged[pos:pos+size] = data
                position[dset] = pos + size

    # rows that were counted for skipped files were never written and only
    # hold uninitialised memory from numpy.empty: drop them
    return {
        dset: merged[:position[dset]] if merged.ndim else merged
        for dset, merged in out_dict.items()
    }

# -- parse command line -------------------------

parser = argparse.ArgumentParser(
    description=__doc__,
)

add_common_pycbc_options(parser)

# input/output
# the trigger and injection files are both used unconditionally below, so
# they are declared required to get a clear usage error when one is missing
parser.add_argument(
    "-f",
    "--input-files",
    nargs="+",
    required=True,
    help="path(s) to input trigger file(s)",
)
parser.add_argument(
    "-j",
    "--inj-files",
    nargs="+",
    required=True,
    help="path(s) to input injection file(s)",
)
parser.add_argument(
    "--bank-file",
    required=True,
    help="full template bank file for mapping hashes to ids",
)
parser.add_argument(
    "-o",
    "--output-dir",
    default=os.getcwd(),
    help="output directory (default: %(default)s)",
)
parser.add_argument(
    "-i",
    "--ifo-tag",
    help="the ifo tag, H1L1 or H1L1V1 for instance",
)

# parameters
parser.add_argument(
    "-W",
    "--time-window",
    type=float,
    default=0.,
    help="the found time window (default: %(default)s)",
)
parser.add_argument(
    "-c",
    "--rank-column",
    default="coherent_snr",
    help="column over which to rank events (default: %(default)s)",
)
parser.add_argument(
    "-b",
    "--exclude-segments",
    action="append",
    default=[],
    help="ignore injections in segments found within "
         "these files, e.g. buffer segments (may be given "
         "more than once)",
)

args = parser.parse_args()

init_logging(args.verbose)

# print-like callable that does nothing visible unless --verbose is given
vprint = print if args.verbose else str
# shorthand for the found time window
win = args.time_window

# -- find injections ----------------------------

# segments in which injections are ignored (empty if no file was given)
exclude = read_segment_files(args.exclude_segments)

nexcluded = 0
# indices, into ``allinjections``, of the missed and of the found injections
missed = []
found = []

# dictionary that will contain the top-level groups of the trigger files
# (presumably the ifo names and 'network') as keys and information on
# triggers to be written to file; each group maps a column name to a list
# that is appended to later.
# The first trigger file defines the groups and is closed right after use.
with HFile(args.input_files[0], "r") as firstfile:
    out_trigs = {key: defaultdict(list) for key in firstfile.keys()}

allinjections = None

# For each injection file: read its injections, merge the triggers of the
# matching split, and flag every injection as found (at least one trigger
# within the time window) or missed.
with tqdm.tqdm(args.inj_files, desc="Finding injections",
               disable=not args.verbose, unit="files",
               postfix=dict(found=0, missed=0, excluded=nexcluded),
               **TQDM_KW) as bar:
    for injf in bar:
        # get list of trigger files
        try:
            split, = SPLIT_FILENAME.search(injf).groups()
        except AttributeError:  # no regex match
            split = None
        trigfiles = list(find_split_files(args.input_files, split))

        # read injections and filter by exclude segments
        tmpinj = InjectionSet(injf).table
        injtime = numpy.array([inj.tc for inj in tmpinj])
        # force a boolean mask: for an empty injection set numpy would
        # build a float array, which can neither be inverted with ``~``
        # nor be used as an index
        keep = numpy.asarray(
            [t not in exclude for t in injtime],
            dtype=bool,
        )
        nexcluded += int((~keep).sum())
        injections = tmpinj[keep]
        injtime = injtime[keep]

        # read triggers, ordered by time
        triggers = read_hdf5_triggers(trigfiles)
        time = triggers['network/end_time_gc']
        time_sorting = time.argsort()
        sorted_time = time[time_sorting]
        snr = triggers['network/'+args.rank_column][time_sorting]
        event_id = triggers['network/event_id'][time_sorting]

        # split each dataset name into its path components once per file
        # rather than once per found injection
        columns = [(name, name.split('/')) for name in triggers]

        # determine found or missed: [lo, hi) is the range of time-sorted
        # triggers lying within the time window of each injection
        _left = numpy.searchsorted(
            sorted_time,
            injtime - args.time_window,
            side='left',
        )
        _right = numpy.searchsorted(
            sorted_time,
            injtime + args.time_window,
            side='right',
        )

        # get indices of found/missed injections and finding trigger;
        # the indices continue from the injections already recorded, so
        # that they index into ``allinjections``
        first = 0 if allinjections is None else len(allinjections)
        for i, (lo, hi) in enumerate(zip(_left, _right), start=first):
            if hi == lo:
                missed.append(i)
                continue
            found.append(i)
            # the needed event_id is unique within its split_inj file;
            # pick the trigger with the largest rank value in the window
            # (argmax is 0 when the window holds a single trigger)
            eid = event_id[lo + snr[lo:hi].argmax()]
            # assign a new, overall unique event_id to a found
            # injection by simply counting found injections
            new_id = len(found) - 1
            for name, parts in columns:
                group, col = parts[0], parts[1]
                if col == 'event_id':
                    out_trigs[group][col].append(new_id)
                else:
                    out_trigs[group][col].append(triggers[name][eid])

        # record all injections
        if allinjections is None:
            allinjections = injections
        else:
            allinjections = allinjections.append(injections)

        # update progress bar
        bar.set_postfix(
            found=len(found),
            missed=len(missed),
            excluded=nexcluded,
        )

# build the name of the output hdf5 file from the metadata encoded in the
# name of the first injection file
ifotag, desc, segment = filename_metadata(args.inj_files[0])
# keep only the part of the description that precedes the split tag
desc = SPLIT_FILENAME.split(desc)[0]
# an explicit --ifo-tag takes precedence over the tag parsed from the name;
# abs(segment) is presumably the duration of the segment
name_fields = (
    args.ifo_tag or ifotag,
    desc,
    segment[0],
    abs(segment),
)
outfilename = os.path.join(
    args.output_dir,
    "{}-{}_FOUNDMISSED-{}-{}.h5".format(*name_fields),
)

# compression settings shared by every dataset created explicitly below
gzip_kw = dict(compression="gzip", compression_opts=9)

with HFile(outfilename, "w") as h5out:
    # write injections from recarray: one group per outcome, holding one
    # dataset per injection field
    for name, injlist in zip(("found", "missed"), (found, missed)):
        grp = h5out.create_group(name)
        injs = allinjections[injlist]
        for field in injs.fieldnames:
            values = injs[field]
            # h5py does not support unicode strings
            if 'U' in str(values.dtype):
                values = [str(s) for s in values]
            grp[field] = values

    # write triggers: one group per key of out_trigs, one dataset per column
    for group_name, columns in out_trigs.items():
        xgroup = h5out.create_group(group_name)
        for col, values in columns.items():
            # template ids are recomputed from the bank file further down
            if col == "template_id":
                continue
            xgroup.create_dataset(col, data=values, **gzip_kw)

    # map and write template ids
    ids = template_hash_to_id(h5out, args.bank_file)
    h5out.create_dataset("network/template_id", data=ids, **gzip_kw)

if args.verbose:
    print("Found/missed written to {}".format(outfilename))
