#!/usr/bin/python3.13 -s
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Duncan Macleod
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Cluster triggers from a pyGRB run
"""

import argparse
import os
import shutil
import sys
import logging

import tqdm

import numpy

import h5py

from gwdatafind.utils import filename_metadata

from pycbc import init_logging, add_common_pycbc_options
from pycbc.io.hdf import HFile

__author__ = "Duncan Macleod <duncan.macleod@ligo.org>"

TQDM_BAR_FORMAT = ("{desc}: |{bar}| "
                   "{n_fmt}/{total_fmt} {unit} ({percentage:3.0f}%) "
                   "[{elapsed} | ETA {remaining}]{postfix}")
TQDM_KW = {
    "ascii": " -=#",
    "bar_format": TQDM_BAR_FORMAT,
    "smoothing": 0.05,
}


# -- utilities ----------------------------------

def slice_hdf5(inputfile, outfile, include, verbose=False):
    """Create a new HDF5 file containing a slice of the network events

    Every dataset directly under ``network`` is sliced with ``include``.
    For each other top-level group (one per entry that is not
    ``network``), the datasets are restricted to the rows whose
    ``event_id`` is referenced by the retained
    ``network/<group>_event_id`` values. Sub-groups whose name contains
    ``"search"`` are copied unchanged; other sub-groups are not copied.

    Parameters
    ----------
    inputfile : str
        path of the HDF5 file to read
    outfile : str
        path of the HDF5 file to create (opened in ``"w"`` mode)
    include : array-like
        index array (integer positions or a boolean mask) selecting
        the ``network`` events to keep
    verbose : bool, optional
        if `True`, show a progress bar and print the output path
    """
    # asarray (rather than ``numpy.array(..., copy=False)``) because
    # NumPy >= 2 raises ValueError when ``copy=False`` cannot be honoured,
    # e.g. when ``include`` is a plain list
    include = numpy.asarray(include)
    if include.dtype == numpy.bool_:
        nevents = include.sum()
    else:
        nevents = include.size
        if not nevents:
            # an empty sequence converts to a float array, which numpy
            # refuses to use as an index; make it an (empty) integer index
            include = include.astype(numpy.intp)

    with HFile(inputfile, "r") as h5in:
        ifos = [k for k in h5in.keys() if k != "network"]

        # find which single-ifo events to keep
        ifo_index = {
            ifo: numpy.unique(
                h5in["network/{}_event_id".format(ifo)][:][include],
            ) for ifo in ifos
        }

        # progress-bar total: datasets and groups one level below the
        # top-level groups
        nsets = sum(
            isinstance(item, (h5py.Dataset, h5py.Group))
            for group in h5in.values()
            for item in group.values()
        )
        msg = "Slicing {} network events into new file".format(nevents)

        # the bar is a context manager so that it is closed even if
        # reading or writing fails part-way through
        with HFile(outfile, "w") as h5out, tqdm.tqdm(
            total=nsets, desc=msg, disable=not verbose,
            unit="datasets", **TQDM_KW
        ) as bar:

            def _copy_sliced(group, index):
                """Write ``group``'s datasets, sliced by ``index``"""
                for old in group.values():
                    if isinstance(old, h5py.Dataset):
                        h5out.create_dataset(
                            old.name,
                            data=old[:][index],
                            compression=old.compression,
                            compression_opts=old.compression_opts,
                        )
                        bar.update()
                    elif isinstance(old, h5py.Group):
                        if "search" in old.name:
                            h5in.copy(h5in[old.name], h5out, old.name)
                            bar.update()

            _copy_sliced(h5in["network"], include)
            for ifo in ifos:
                # boolean mask of the single-ifo rows that are referenced
                # by the retained network events
                idx = numpy.isin(h5in[ifo]["event_id"][()], ifo_index[ifo])
                _copy_sliced(h5in[ifo], idx)

        if verbose:
            print("Slice written to {}".format(outfile))


# -- parse command line -------------------------

parser = argparse.ArgumentParser(description=__doc__)
add_common_pycbc_options(parser)

# options controlling the clustering itself
parser.add_argument(
    "-W", "--time-window",
    default=0.1,
    type=float,
    help="the cluster time window (default %(default)s)",
)
parser.add_argument(
    "-c", "--rank-column",
    default="coherent_snr",
    help="column over which to rank events (default: %(default)s)",
)

# options controlling where data are read from and written to
parser.add_argument(
    "-t", "--trig-file",
    required=True,
    help="path to input trigger file",
)
parser.add_argument(
    "-o", "--output-dir",
    default=os.getcwd(),
    help="output directory (default: %(default)s)",
)

args = parser.parse_args()

init_logging(args.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

win = args.time_window

# build the output path from the tags and segment that filename_metadata
# reports for the input file, with the duration as the last field
ifotag, filetag, segment = filename_metadata(args.trig_file)
start, end = segment
outname = "{}-{}_CLUSTERED-{}-{}.h5".format(
    ifotag, filetag, start, end - start,
)
outfile = os.path.join(args.output_dir, outname)

# indices (into the full trigger arrays) of the triggers that survive
# clustering, accumulated over all slides
all_clusters = []

# read only the columns that the clustering needs
with HFile(args.trig_file, "r") as h5f:
    all_times = h5f["network/end_time_gc"][()]
    all_snrs = h5f[f"network/{args.rank_column}"][()]
    slide_ids = h5f["network/slide_id"][()]

# nothing to cluster: the output is a plain copy of the input
if all_times.size == 0:
    shutil.copyfile(args.trig_file, outfile)
    logging.info(
        "trigger file is empty\n"
        + "copied input file to {}".format(outfile)
    )
    sys.exit(0)

# -- cluster ------------------------------------

# Clustering is done independently for each slide. Triggers are sorted
# into time bins of width ``win``; a trigger is kept when it is the
# loudest of its own bin and is not beaten by the loudest trigger of an
# adjacent bin that lies within ``win`` of it.
unique_slide_ids = numpy.unique(slide_ids)
max_slide_id = max(unique_slide_ids)
logging.info(
    f'Clustering {len(slide_ids)} triggers from '
    f'{len(unique_slide_ids)} slides'
)

# the binning depends only on the segment and the window, so it is the
# same for every slide
nbins = int((end - start) // win + 1)

for slide_id in unique_slide_ids:
    # indices to slice current slide
    slide_id_pos = numpy.where(slide_ids == slide_id)[0]
    # all time and snr values for the current slide
    time = all_times[slide_id_pos]
    snr = all_snrs[slide_id_pos]

    # generate clustering bins; a zero in loudtime marks a bin for which
    # no loudest trigger has been recorded
    bins = [[] for _ in range(nbins)]
    loudsnr = numpy.zeros(nbins)
    loudtime = numpy.zeros(nbins)
    # list to index clusters for current slide
    clusters = []

    # find loudest trigger in each bin, for the current slide
    for i in tqdm.tqdm(range(time.size),
                       desc="Initialising bins",
                       disable=not args.verbose,
                       total=time.size,
                       unit='triggers',
                       **TQDM_KW):
        t, s = time[i], snr[i]
        idx = int(float(t - start) // win)
        bins[idx].append(i)
        if s > loudsnr[idx]:
            loudsnr[idx] = s
            loudtime[idx] = t

    add_cluster = clusters.append
    nclusters = 0

    # cluster (iterating over the bar advances it, so no manual update)
    bar = tqdm.tqdm(bins,
                    desc="Clustering bins",
                    disable=not args.verbose,
                    total=nbins,
                    unit='bins',
                    postfix=dict(nclusters=0),
                    **TQDM_KW)
    for i, bin_ in enumerate(bar):
        if not bin_:  # empty
            continue

        # The neighbouring bins are derived from the current bin index.
        # Counters that are only advanced for non-empty bins drift after
        # the first empty bin and then point at the wrong neighbours.
        prev = i - 1
        nxt_ = i + 1
        # 0 (falsy) stands for "no neighbour": either the bin does not
        # exist (edge of the segment) or nothing was recorded for it
        prevt = loudtime[prev] if i > 0 else 0.
        nextt = loudtime[nxt_] if nxt_ < nbins else 0.

        for idx in bin_:
            t, s = time[idx], snr[idx]

            if s < loudsnr[i]:  # not loudest in own bin
                continue

            # check loudest event in previous bin
            if prevt and abs(prevt - t) < win and s < loudsnr[prev]:
                continue

            # check loudest event in next bin
            if nextt and abs(nextt - t) < win and s < loudsnr[nxt_]:
                continue

            loudest = True

            # check all events in previous bin
            # NOTE(review): this scan (and the one for the next bin) only
            # runs when the neighbour's loudest trigger is itself within
            # the window, in which case the checks above have already
            # established s >= loudsnr[neighbour]; confirm whether the
            # gate on prevt/nextt is intended
            if prevt and abs(prevt - t) < win:
                for id2 in bins[prev]:
                    if abs(time[id2] - t) < win and s < snr[id2]:
                        loudest = False
                        break

            # check all events in next bin
            if loudest and nextt and abs(nextt - t) < win:
                for id2 in bins[nxt_]:
                    if abs(time[id2] - t) < win and s < snr[id2]:
                        loudest = False
                        break

            # this is loudest in its vicinity, keep it
            if loudest:
                add_cluster(idx)
                nclusters += 1
                bar.set_postfix(nclusters=nclusters)

    # clusters is the indexing array for a specific slide_id
    # all_clusters is the (absolute) indexing of all clustered triggers
    # so look up the indices [clusters] within the absolute indexing array
    # slide_id_pos which is built at each slide_id
    all_clusters.extend(slide_id_pos[clusters])
    logging.info(
        f'Slide {slide_id}/{max_slide_id} has {len(slide_id_pos)}'
        f' triggers that were clustered to {len(clusters)}'
    )

logging.info(f'Total clustered triggers: {len(all_clusters)}')

# -- write output --------------------------------

# keep only the clustered triggers when writing the output file
clustered_index = numpy.asarray(all_clusters)
slice_hdf5(args.trig_file, outfile, clustered_index, verbose=args.verbose)
