#!/usr/bin/python3.13 -s

# Copyright (C) 2014 Alex Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Find multi-detector gravitational wave triggers and calculate the
coherent SNRs and related statistics.

To see an example on how to analyze the GW170817 event using this executable,
take a look at:
https://github.com/gwastro/pycbc/blob/master/examples/multi_inspiral/run.sh
"""

import logging
import time
import argparse
import numpy as np

from pycbc import (
    detector,
    fft,
    init_logging,
    inject,
    opt,
    psd,
    scheme,
    strain,
    vetoes,
    waveform,
    DYN_RANGE_FAC,
    add_common_pycbc_options,
)
from pycbc.events import ranking, coherent as coh, EventManagerCoherent
from pycbc.io.hdf import HFile
from pycbc.filter import MatchedFilterControl
from pycbc.types import zeros, float32, complex64
from pycbc.vetoes import sgchisq


def sky_grid_from_cli(parser, args):
    """Read the sky grid or the single sky position given the CLI arguments."""
    if args.sky_grid is not None and (
        args.ra is not None or args.dec is not None
    ):
        parser.error(
            'Please provide either a sky grid via --sky-grid or a '
            'single sky position via --ra and --dec, not both'
        )
    if args.sky_grid is not None:
        with HFile(args.sky_grid, 'r') as sky_grid_file:
            ra = sky_grid_file['ra'][:]
            dec = sky_grid_file['dec'][:]
    elif args.ra is not None and args.dec is not None:
        ra = np.array([args.ra])
        dec = np.array([args.dec])
    else:
        parser.error(
            'Please specify a sky grid via --sky-grid '
            'or a single position via --ra and --dec'
        )
    return np.array([ra, dec])
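
# A compatible sky-grid file can be produced with h5py (a sketch; the file
# only needs 'ra' and 'dec' datasets holding radians, as read above):
#     import h5py
#     with h5py.File('sky_grid.hdf', 'w') as f:
#         f['ra'] = [0.57]   # right ascensions (radians)
#         f['dec'] = [-0.3]  # declinations (radians)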


def calculate_antenna_pattern(args, sky_pos_indices, sky_positions, t_gps):
    """Calculate the antenna pattern functions for all detectors and sky
    positions.

    Each entry of the returned dictionary is the (F+, Fx) antenna response
    tuple of a detector at a given sky position.
    """
    antenna_pattern = {}
    for ifo in args.instruments:
        curr_det = detector.Detector(ifo)
        antenna_pattern[ifo] = [None] * len(sky_pos_indices)
        for index in sky_pos_indices:
            antenna_pattern[ifo][index] = curr_det.antenna_pattern(
                sky_positions[0][index],
                sky_positions[1][index],
                polarization=0,
                t_gps=t_gps,
            )
    return antenna_pattern

def slide_limiter(args):
    """
    Compute the number of short slides used by the coherent matched filter
    statistic to obtain as many background triggers as possible.

    The number of slides is bounded to avoid counting triggers more than
    once. If the data is not time slid, there is a single slide for the
    zero-lag.
    """
    low, upp = 1, args.segment_length[args.instruments[0]]
    if args.do_shortslides:
        n_ifos = len(args.instruments)
        stride_dur = args.segment_length[args.instruments[0]] / 2
        num_slides = np.int32(
            1 + np.floor(stride_dur / (args.slide_shift * (n_ifos - 1)))
        )
        assert low <= num_slides <= upp, (
            "the combination (slide_shift, segment_dur)"
            f" = ({args.slide_shift:.2f}, {stride_dur * 2:.2f})"
            f" exceeds the allowed upper bound {upp}"
        )
    else:
        num_slides = 1
    return num_slides
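
# Worked example (for orientation, not executed): with a 256 s segment,
# --slide-shift 1 and three instruments, stride_dur = 128 s and
# num_slides = 1 + floor(128 / (1 * 2)) = 65 short slides.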

# The following block of lines sets up the command-line interface (CLI) for the
# pycbc_multi_inspiral executable.
time_init = time.time()
parser = argparse.ArgumentParser(description=__doc__)
add_common_pycbc_options(parser)
parser.add_argument("--output", type=str)
parser.add_argument(
    "--instruments",
    nargs="+",
    type=str,
    required=True,
    help="List of instruments to analyze.",
)
parser.add_argument("--bank-file", type=str)
parser.add_argument(
    "--low-frequency-cutoff",
    type=float,
    help="The low frequency cutoff to use for filtering (Hz).",
)
# Add approximant arg
waveform.bank.add_approximant_arg(parser)
parser.add_argument(
    "--order",
    type=str,
    help="The integer half-PN order at which to generate the approximant.",
)
parser.add_argument(
    "--taper-template",
    choices=["start", "end", "startend"],
    help="For time-domain approximants, taper the start "
    "and/or end of the waveform before FFTing.",
)
parser.add_argument(
    "--cluster-method",
    choices=["template", "window"],
    default="window",
    help="Method to use when clustering triggers. 'window' - "
    "cluster within a fixed time window defined by the "
    "cluster-window option (default); or 'template' - "
    "cluster within windows defined by each template's "
    "chirp length.",
)
parser.add_argument(
    "--cluster-window",
    type=float,
    default=0,
    help="Length of clustering window in seconds.",
)
parser.add_argument(
    "--bank-veto-bank-file",
    type=str,
    help="Path to the bank file used to compute the bank chi-square veto.",
)
parser.add_argument("--chisq-bins", default=0)
# Commenting out unused options: remove if they remain unused
# parser.add_argument("--chisq-threshold", type=float, default=0)
# parser.add_argument("--chisq-delta", type=float, default=0)
parser.add_argument("--autochi-number-points", type=int, default=0,
                    help="The number of points to use, in both directions if"
                         "doing a two-sided auto-chisq, to calculate the"
                         "auto-chisq statistic.")
parser.add_argument("--autochi-stride", type=int, default=0,
                    help="The gap, in sample points, between the points at"
                         "which to calculate auto-chisq.")
parser.add_argument("--autochi-two-phase", action="store_true",
                    default=False,
                    help="If given auto-chisq will be calculated by testing "
                         "against both phases of the SNR time-series. "
                         "If not given, only the phase matching the trigger "
                         "will be used.")
parser.add_argument("--autochi-onesided", action='store',
                    choices=['left','right'],
                    help="Decide whether to calculate auto-chisq using"
                         "points on both sides of the trigger or only on one"
                         "side. If not given points on both sides will be"
                         "used. If given, with either 'left' or 'right',"
                         "only points on that side (right = forward in time,"
                         "left = back in time) will be used.")
parser.add_argument("--autochi-reverse-template", action="store_true",
                    default=False,
                    help="If given, time-reverse the template before"
                         "calculating the auto-chisq statistic. This will"
                         "come at additional computational cost as the SNR"
                         "time-series will need recomputing for the time-"
                         "reversed template.")
parser.add_argument("--autochi-max-valued", action="store_true",
                    default=False,
                    help="If given, store only the maximum value of the auto-"
                         "chisq over all points tested. A disadvantage of this "
                         "is that the mean value will not be known "
                         "analytically.")
parser.add_argument("--autochi-max-valued-dof", action="store", metavar="INT",
                    type=int,
                    help="If using --autochi-max-valued this value denotes "
                         "the pre-calculated mean value that will be stored "
                         "as the auto-chisq degrees-of-freedom value.")
parser.add_argument(
    "--downsample-factor",
    type=int,
    default=1,
    help="Factor that determines the interval between the "
    "initial SNR sampling. If not set (or 1) no sparse "
    "sample is created, and the standard full SNR is "
    "calculated.",
)
parser.add_argument(
    "--upsample-threshold",
    type=float,
    help="The fraction of the SNR threshold to check the "
    "sparse SNR sample.",
)
parser.add_argument(
    "--upsample-method",
    choices=["pruned_fft"],
    default='pruned_fft',
    help="The method to find the SNR points between the sparse SNR sample.",
)
parser.add_argument(
    "--user-tag",
    type=str,
    metavar="TAG",
    help="This is used to identify FULL_DATA jobs for "
    "compatibility with pipedown post-processing. Option "
    "will be removed when no longer needed.",
)
parser.add_argument(
    "--ra",
    type=float,
    help="Right ascension of sky point to search (radians).",
)
parser.add_argument(
    "--dec", type=float, help="Declination of sky point to search (radians)."
)
parser.add_argument(
    "--sky-grid",
    type=str,
    help="Sky grid to search, an HDF file containing two datasets: "
    "ra and dec, both in radians.",
)
parser.add_argument(
    "--coinc-threshold",
    type=float,
    default=0.0,
    help="Triggers with coincident/coherent SNR below this "
    "value will be discarded.",
)
parser.add_argument(
    "--sngl-snr-threshold",
    action="store",
    type=float,
    default=4.0,
    metavar='THRESHOLD',
    help="Single detector SNR threshold for trigger generation required in "
    "at least two detectors (default: 4).",
)
parser.add_argument(
    "--chisq-index",
    action="store",
    type=float,
    default=6.0,
    help="chisq-index (q) for the reweighting of coherent SNR by chi-square "
    "(default: 6).",
)
parser.add_argument(
    "--chisq-nhigh",
    action="store",
    type=float,
    default=2.0,
    help="chisq-nhigh (n) for the reweighting of coherent SNR by chi-square "
    "(default: 2).",
)
parser.add_argument(
    "--do-null-cut",
    action='store_true',
    help="Apply a cut based on null SNR: retained triggers have null SNR "
    "smaller than null-min and coherent SNR smaller than null-step, or null "
    "SNR smaller than (null-grad * coherent SNR + null-min) and coherent SNR "
    "greater than null-step.",
)
parser.add_argument(
    "--null-min",
    type=float,
    default=5.25,
    help="In addition to its usage with the flag --do-null-cut, null-min is "
    "used in SNR reweighting: reweighting happens for triggers with null SNR "
    "greater than (null-min - 1) and coherent SNR smaller than null-step, or "
    "null SNR greater than (null-grad * coherent SNR + null-min - 1) and "
    "coherent SNR greater than null-step (default: 5.25).",
)
parser.add_argument(
    "--null-grad",
    type=float,
    default=0.2,
    help="The gradient of the null SNR cut and/or of the SNR reweighting "
    "threshold when coherent SNR > null-step (default: 0.2).",
)
parser.add_argument(
    "--null-step",
    type=float,
    default=20.0,
    help="The threshold for a second condition set to cut in null SNR "
    "and/or reweight SNR (default: 20).",
)
parser.add_argument(
    "--trigger-time",
    type=int,
    help="Time of the GRB, used to set the antenna patterns.",
)
parser.add_argument(
    "--projection",
    default="standard",
    choices=["standard", "left", "right", "left+right"],
    help="Choice of projection matrix. 'left' and 'right' "
    "correspond to face-away and face-on.",
)
parser.add_argument(
    "--slide-shift",
    type=float,
    default=1.0,
    help="Size of each time slide shift.",
)
parser.add_argument(
    "--do-shortslides",
    action="store_true",
    help="Perform short time slides to collect background triggers.",
)
# Add options groups
strain.insert_strain_option_group_multi_ifo(parser)
strain.StrainSegments.insert_segment_option_group_multi_ifo(parser)
psd.insert_psd_option_group_multi_ifo(parser)
scheme.insert_processing_option_group(parser)
fft.insert_fft_option_group(parser)
opt.insert_optimization_option_group(parser)
inject.insert_injfilterrejector_option_group_multi_ifo(parser)
sgchisq.SingleDetSGChisq.insert_option_group(parser)
args = parser.parse_args()
init_logging(args.verbose)
# Set the number of short slides
num_slides = slide_limiter(args)
# Arrange detectors alphabetically so they are always called in the same order
args.instruments.sort()
# Use class verification methods to check whether input CLI options provided
# by parser to pycbc.strain, pycbc.strain.StrainSegments, pycbc.psd,
# pycbc.scheme, and pycbc.fft modules are sane.
strain.verify_strain_options_multi_ifo(args, parser, args.instruments)
strain.StrainSegments.verify_segment_options_multi_ifo(
    args, parser, args.instruments
)
psd.verify_psd_options_multi_ifo(args, parser, args.instruments)
scheme.verify_processing_options(args, parser)
fft.verify_fft_options(args, parser)
# InjFilterRejector instance: avoids investing computing power on processing
# injections with templates that differ significantly in chirp mass
inj_filter_rejector = inject.InjFilterRejector.from_cli_multi_ifos(
    args, args.instruments
)
# Create a dictionary to store timeseries objects corresponding to the
# individual detector strains.
# strain data is taken from args.gps_start_time to args.gps_end_time
# with the sampling rate equal to args.sample_rate
strain_dict = strain.from_cli_multi_ifos(
    args, args.instruments, inj_filter_rejector, dyn_range_fac=DYN_RANGE_FAC
)
# Create a dictionary of Python slice objects that indicate where the segments
# start and end for each detector timeseries.
strain_segments_dict = strain.StrainSegments.from_cli_multi_ifos(
    args, strain_dict, args.instruments
)
# Context manager to handle the various possible processing schemes
ctx = scheme.from_cli(args)
with ctx:
    fft.from_cli(args)
    # Set some convenience variables: number of IFOs, lower frequency,
    # GRB time, sky positions to search (either a grid or single sky point)
    nifo = len(args.instruments)
    flow = args.low_frequency_cutoff
    t_gps = args.trigger_time
    sky_positions = sky_grid_from_cli(parser, args)
    sky_pos_indices = np.arange(sky_positions.shape[1])
    # The following for loop is used to check whether
    # the sampling rate, flen and tlen agree for all detectors
    # taking the zeroth detector in the list as a reference.
    for count, ifo in enumerate(args.instruments):
        if count == 0:
            sample_rate = strain_dict[ifo].sample_rate
            flen = strain_segments_dict[ifo].freq_len
            tlen = strain_segments_dict[ifo].time_len
            delta_f = strain_segments_dict[ifo].delta_f
        else:
            vname = "Sample rate"
            err_msg = " must be consistent across all ifos."
            assert sample_rate == strain_dict[ifo].sample_rate, vname + err_msg
            vname = "Frequency length"
            assert flen == strain_segments_dict[ifo].freq_len, vname + err_msg
            vname = "Time length"
            assert tlen == strain_segments_dict[ifo].time_len, vname + err_msg
            vname = "delta_f (=segment length inverse)"
            assert delta_f == strain_segments_dict[ifo].delta_f, (
                vname + err_msg
            )
    # segments is a dictionary of frequency domain objects, each one of which
    # is the Fourier transform of the segments in strain_segments_dict
    logging.info("Making frequency-domain data segments")
    segments = {
        ifo: strain_segments_dict[ifo].fourier_segments()
        for ifo in args.instruments
    }
    # Memory cleaning
    del strain_segments_dict
    # Associate PSDs to segments for all IFOs when using the multi-detector CLI
    logging.info("Associating PSDs to them")
    psd.associate_psds_to_multi_ifo_segments(
        args,
        segments,
        strain_dict,
        flen,
        delta_f,
        flow,
        args.instruments,
        dyn_range_factor=DYN_RANGE_FAC,
        precision='single',
    )

    logging.info("Determining time slide shifts and time delays")
    # Create a dictionary of time slide shifts; IFO 0 is unshifted
    slide_ids = np.arange(num_slides)
    time_slides = {
        ifo: args.slide_shift * slide_ids * ifo_idx
        for ifo_idx, ifo in enumerate(args.instruments)
    }
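    # Example: with instruments ['H1', 'L1', 'V1'] and slide_shift = 1 s,
    # slide k leaves H1 unshifted and shifts L1 by k s and V1 by 2 * k s.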
    # Given the time slide shifts wrt IFO 0 in time_slides, create a
    # dictionary of time delay indices evaluated wrt the geocenter, in units
    # of samples, i.e. (time delay from geocenter + time slide) * sample_rate
    time_delay_idx_zerolag = {
        position_index: {
            ifo: detector.Detector(ifo).time_delay_from_earth_center(
                sky_positions[0][position_index],
                sky_positions[1][position_index],
                t_gps,
            )
            for ifo in args.instruments
        }
        for position_index in sky_pos_indices
    }
    time_delay_idx = {
        slide: {
            position_index: {
                ifo: int(
                    round(
                        (
                            time_delay_idx_zerolag[position_index][ifo]
                            + time_slides[ifo][slide]
                        )
                        * sample_rate
                    )
                )
                for ifo in args.instruments
            }
            for position_index in sky_pos_indices
        }
        for slide in slide_ids
    }
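    # Lookup: time_delay_idx[slide][position_index][ifo] gives the integer
    # sample offset applied to that IFO's trigger indices.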
    del time_delay_idx_zerolag

    logging.info("Setting up MatchedFilterControl at each IFO")
    # Prototype container for the output of MatchedFilterControl and
    # waveform.FilterBank (see below).
    # Use tlen of the first IFO as it is the same across IFOs.
    template_mem = zeros(tlen, dtype=complex64)

    # All MatchedFilterControl instances are initialized in the same way.
    # This makes it possible to track where the single-detector SNR time
    # series are greater than args.sngl_snr_threshold. Later,
    # coh.get_coinc_indexes will enforce the requirement that at least two
    # single-detector SNRs are above args.sngl_snr_threshold, rescuing,
    # where necessary, SNR time-series points for detectors below that
    # threshold.
    # NOTE: Do not cluster here for a coherent search (use_cluster=False).
    #       Clustering happens at the end of the template loop.
    matched_filter = {
        ifo: MatchedFilterControl(
            args.low_frequency_cutoff,
            None,
            args.sngl_snr_threshold,
            tlen,
            delta_f,
            complex64,
            segments[ifo],
            template_mem,
            use_cluster=False,
            downsample_factor=args.downsample_factor,
            upsample_threshold=args.upsample_threshold,
            upsample_method=args.upsample_method,
        )
        for ifo in args.instruments
    }

    # Chi-squares
    logging.info("Initializing signal-based vetoes: power, bank, and auto")
    # Directly use existing SingleDetPowerChisq, SingleDetBankVeto, and
    # SingleDetAutoChisq to calculate single detector chi-squares for
    # multiple IFOs
    power_chisq = vetoes.SingleDetPowerChisq(args.chisq_bins)
    bank_chisq = vetoes.SingleDetBankVeto(
        args.bank_veto_bank_file,
        flen,
        delta_f,
        flow,
        complex64,
        phase_order=args.order,
        approximant=args.approximant,
    )
    autochisq = vetoes.SingleDetAutoChisq(
        args.autochi_stride,
        args.autochi_number_points,
        onesided=args.autochi_onesided,
        twophase=args.autochi_two_phase,
        reverse_template=args.autochi_reverse_template,
        take_maximum_value=args.autochi_max_valued,
        maximal_value_dof=args.autochi_max_valued_dof
    )

    # Overwhiten all frequency-domain segments by dividing by the PSD estimate
    logging.info("Overwhitening frequency-domain data segments")
    for ifo in args.instruments:
        for seg in segments[ifo]:
            seg /= seg.psd
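    # Each seg now holds the overwhitened data (the frequency series divided
    # by its PSD estimate), so the matched filters below already include the
    # 1/Sn(f) weighting.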

    logging.info("Setting up event manager")
    # But first build dictionaries to initialize and feed the event manager
    ifo_out_types = {
        'time_index': int,
        'ifo': int,  # IFO is stored as an int internally!
        'snr': complex64,
        'chisq': float32,
        'chisq_dof': int,
        'bank_chisq': float32,
        'bank_chisq_dof': int,
        'auto_chisq': float32,
        'auto_chisq_dof': int,
        'slide_id': int,
    }
    ifo_out_vals = {
        'time_index': None,
        'ifo': None,
        'snr': None,
        'chisq': None,
        'chisq_dof': None,
        'bank_chisq': None,
        'bank_chisq_dof': None,
        'auto_chisq': None,
        'auto_chisq_dof': None,
        'slide_id': None,
    }
    ifo_names = sorted(ifo_out_vals.keys())
    network_out_types = {
        'dec': float32,
        'ra': float32,
        'time_index': int,
        'coherent_snr': float32,
        'null_snr': float32,
        'nifo': int,
        'my_network_chisq': float32,
        'reweighted_snr': float32,
        'slide_id': int,
    }
    network_out_vals = {
        'dec': None,
        'ra': None,
        'time_index': None,
        'coherent_snr': None,
        'null_snr': None,
        'nifo': None,
        'my_network_chisq': None,
        'reweighted_snr': None,
        'slide_id': None,
    }
    network_names = sorted(network_out_vals.keys())
    event_mgr = EventManagerCoherent(
        args,
        args.instruments,
        ifo_names,
        [ifo_out_types[n] for n in ifo_names],
        network_names,
        [network_out_types[n] for n in network_names],
        segments=segments,
        time_slides=time_slides,
        gating_info={det: strain_dict[det].gating_info for det in strain_dict},
    )

    # Template bank: filtering and thinning
    logging.info("Read in template bank")
    bank = waveform.FilterBank(
        args.bank_file,
        flen,
        delta_f,
        complex64,
        low_frequency_cutoff=flow,
        phase_order=args.order,
        taper=args.taper_template,
        approximant=args.approximant,
        out=template_mem,
    )

    # Use inj_filter_rejector to reduce the bank to only those templates that
    # might actually find something
    n_bank = len(bank)
    logging.info("Full template bank size: %d", n_bank)
    for ifo in args.instruments:
        bank.template_thinning(inj_filter_rejector[ifo])
    if len(bank) != n_bank:
        n_bank = len(bank)
        logging.info("Template bank size after thinning: %d", n_bank)

    logging.info("Calculating antenna pattern functions at every sky position")
    antenna_pattern = calculate_antenna_pattern(
        args, sky_pos_indices, sky_positions, t_gps
    )

    logging.info("Starting the filtering...")
    # Loop over templates
    for t_num, template in enumerate(bank):
        # Loop over segments
        for s_num in range(len(segments[args.instruments[0]])):
            stilde = {ifo: segments[ifo][s_num] for ifo in args.instruments}
            # Checks the 'inj_filter_rejector' options to determine whether
            # to filter this template/segment if injections are present
            analyse_segment = True
            for ifo in args.instruments:
                if not inj_filter_rejector[ifo].template_segment_checker(
                    bank, t_num, stilde[ifo]
                ):
                    logging.info(
                        "Skipping segment %d/%d with template %d/%d as no "
                        "detectable injection is present",
                        s_num + 1,
                        len(segments[ifo]),
                        t_num + 1,
                        n_bank,
                    )
                    analyse_segment = False
            # Find how loud the template is in each detector, i.e., its
            # unnormalized matched-filter with itself. This quantity is
            # used to normalize matched-filters with the data.
            sigmasq = {
                ifo: template.sigmasq(segments[ifo][s_num].psd)
                for ifo in args.instruments
            }
            sigma = {ifo: np.sqrt(sigmasq[ifo]) for ifo in args.instruments}
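            # sigma is used below as the per-detector sensitivity weighting
            # passed to coh.get_projection_matrix for the coherent SNR.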
            # Every time s_num is zero, run new_template to increment the
            # template index
            if s_num == 0:
                event_mgr.new_template(tmplt=template.params, sigmasq=sigmasq)
            if not analyse_segment:
                continue
            logging.info(
                "Analyzing segment %d/%d", s_num + 1, len(segments[ifo])
            )
            # The following dicts with IFOs as keys are created to store
            # copies of the matched filtering results computed below.
            # - Complex SNR time series
            snr_dict = dict.fromkeys(args.instruments)
            # - Its normalization
            norm_dict = dict.fromkeys(args.instruments)
            # - The correlation vector frequency series
            #   It is the FFT of the SNR (so inverse FFT it to get the SNR)
            corr_dict = dict.fromkeys(args.instruments)
            # - The trigger indices list (idx_dict will be created out of this)
            idx = dict.fromkeys(args.instruments)
            # - The list of normalized SNR values at the trigger locations
            snr = dict.fromkeys(args.instruments)
            for ifo in args.instruments:
                logging.info(
                    "Filtering template %d/%d, ifo %s", t_num + 1, n_bank, ifo
                )
                # The following lines unpack and store copies of the matched
                # filtering results for the current template, segment, and IFO.
                # No clustering happens in the coherent search until the end.
                snr_ts, norm, corr, ind, snrv = matched_filter[
                    ifo
                ].matched_filter_and_cluster(
                    s_num, template.sigmasq(stilde[ifo].psd), window=0
                )
                snr_dict[ifo] = (
                    snr_ts[matched_filter[ifo].segments[s_num].analyze] * norm
                )
                assert len(snr_dict[ifo]) > 0, f'SNR time series for {ifo} is empty'
                norm_dict[ifo] = norm
                corr_dict[ifo] = corr.copy()
                idx[ifo] = ind.copy()
                snr[ifo] = snrv * norm

            # Move onto next segment if there are no triggers.
            n_trigs = [len(snr[ifo]) for ifo in args.instruments]
            if not any(n_trigs):
                continue

            # Loop over (short) time slides, starting with the zero-lag
            for slide in range(num_slides):
                logging.info("Analyzing slide %d/%d", slide, num_slides)
                # Loop over sky positions
                for position_index in sky_pos_indices:
                    logging.info(
                        "Analyzing sky position %d/%d",
                        position_index + 1,
                        len(sky_pos_indices),
                    )
                    # Adjust the indices of triggers (if there are any)
                    # and store trigger indices list in a dictionary;
                    # when there are no triggers, the dictionary is empty.
                    # Indices are kept only if they do not get time shifted
                    # out of the time we are looking at, i.e., require
                    # idx[ifo] - time_delay_idx[slide][position_index][ifo]
                    # to be in (0, len(snr_dict[ifo]))
                    idx_dict = {
                        ifo: idx[ifo][
                            np.logical_and(
                                idx[ifo]
                                > time_delay_idx[slide][position_index][ifo],
                                idx[ifo]
                                - time_delay_idx[slide][position_index][ifo]
                                < len(snr_dict[ifo]),
                            )
                        ]
                        for ifo in args.instruments
                    }
                    # Find triggers that are coincident (in geocent time) in
                    # multiple IFOs. If a single IFO analysis then just use the
                    # indices from that IFO, i.e., IFO 0; otherwise, this
                    # method finds coincidences and applies the single IFO cut,
                    # namely, triggers must have at least 2 IFO SNRs above
                    # args.sngl_snr_threshold.
                    if nifo > 1:
                        coinc_idx = coh.get_coinc_indexes(
                            idx_dict, time_delay_idx[slide][position_index]
                        )
                    else:
                        coinc_idx = (
                            idx_dict[args.instruments[0]]
                            - time_delay_idx[slide][position_index][
                                args.instruments[0]
                            ]
                        )
                    logging.info(
                        "Found %d coincident triggers", len(coinc_idx)
                    )
                    # Time delay is applied to indices to have them at the IFOs
                    coinc_idx_det_frame = {
                        ifo: (
                            coinc_idx
                            + time_delay_idx[slide][position_index][ifo]
                        )
                        % len(snr_dict[ifo])
                        for ifo in args.instruments
                    }
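                    # (The modulo wraps any index shifted past the end of
                    # the SNR time series back into range.)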
                    # Calculate the coincident and coherent SNR.
                    # First check there is enough data to compute the SNRs.
                    if len(coinc_idx) != 0 and nifo > 1:
                        # Find coinc SNR at trigger times and apply coinc SNR
                        # threshold (which depopulates coinc_idx accordingly)
                        (
                            rho_coinc,
                            coinc_idx,
                            coinc_triggers,
                        ) = coh.coincident_snr(
                            snr_dict,
                            coinc_idx,
                            args.coinc_threshold,
                            time_delay_idx[slide][position_index],
                        )
                        logging.info(
                            "%d triggers above coincident SNR threshold",
                            len(coinc_idx),
                        )
                        if len(coinc_idx) != 0:
                            logging.info(
                                "With max coincident SNR = %.2f",
                                max(rho_coinc),
                            )
                    # If there is only one IFO, just take its triggers
                    # and their SNRs
                    elif len(coinc_idx) != 0 and nifo == 1:
                        coinc_triggers = {
                            args.instruments[0]: snr[args.instruments[0]][
                                coinc_idx_det_frame[args.instruments[0]]
                            ]
                        }
                    else:
                        coinc_triggers = {}
                        logging.info("No coincident triggers were found")
                    # If there are triggers above coinc threshold and more
                    # than 2 IFOs, then calculate the coherent statistics for
                    # them and apply the cut on coherent SNR (with threshold
                    # equal to the coinc SNR one)
                    if len(coinc_idx) != 0 and nifo > 2:
                        logging.info("Calculating their coherent statistics")
                        # Plus and cross antenna pattern dictionaries
                        fp = {
                            ifo: antenna_pattern[ifo][position_index][0]
                            for ifo in args.instruments
                        }
                        fc = {
                            ifo: antenna_pattern[ifo][position_index][1]
                            for ifo in args.instruments
                        }
                        # The cut is applied after the left vs right
                        # comparison to ensure the arrays have equal lengths
                        if args.projection == 'left+right':
                            # Left polarized coherent SNR
                            project_l = coh.get_projection_matrix(
                                fp, fc, sigma, projection='left'
                            )
                            (
                                rho_coh_l,
                                coinc_idx_l,
                                coinc_triggers_l,
                                rho_coinc_l,
                            ) = coh.coherent_snr(
                                coinc_triggers,
                                coinc_idx,
                                0.0,
                                project_l,
                                rho_coinc,
                            )
                            # Right polarized coherent SNR
                            project_r = coh.get_projection_matrix(
                                fp, fc, sigma, projection='right'
                            )
                            (
                                rho_coh_r,
                                coinc_idx_r,
                                coinc_triggers_r,
                                rho_coinc_r,
                            ) = coh.coherent_snr(
                                coinc_triggers,
                                coinc_idx,
                                0.0,
                                project_r,
                                rho_coinc,
                            )
                            # Apply cut to remove points with left or right
                            # polarized coherent SNR below threshold
                            l_above = rho_coh_l > args.coinc_threshold
                            r_above = rho_coh_r > args.coinc_threshold
                            lr_above = l_above & r_above
                            rho_coh_l = rho_coh_l[lr_above]
                            rho_coh_r = rho_coh_r[lr_above]
                            coinc_idx_l = coinc_idx_l[lr_above]
                            coinc_idx_r = coinc_idx_r[lr_above]
                            for ifo in coinc_triggers_l.keys():
                                coinc_triggers_l[ifo] = coinc_triggers_l[ifo][
                                    lr_above
                                ]
                            for ifo in coinc_triggers_r.keys():
                                coinc_triggers_r[ifo] = coinc_triggers_r[ifo][
                                    lr_above
                                ]
                            rho_coinc_l = rho_coinc_l[lr_above]
                            rho_coinc_r = rho_coinc_r[lr_above]
                            # Point by point, track the larger of the two
                            # and store its information
                            max_idx = np.argmax([rho_coh_l, rho_coh_r], axis=0)
                            rho_coh = np.where(
                                max_idx == 0, rho_coh_l, rho_coh_r
                            )
                            coinc_idx = np.where(
                                max_idx == 0, coinc_idx_l, coinc_idx_r
                            )
                            coinc_triggers = {
                                ifo: np.where(
                                    max_idx == 0,
                                    coinc_triggers_l[ifo],
                                    coinc_triggers_r[ifo],
                                )
                                for ifo in coinc_triggers_l
                            }
                            rho_coinc = np.where(
                                max_idx == 0, rho_coinc_l, rho_coinc_r
                            )
                        else:
                            project = coh.get_projection_matrix(
                                fp, fc, sigma, projection=args.projection
                            )
                            (
                                rho_coh,
                                coinc_idx,
                                coinc_triggers,
                                rho_coinc,
                            ) = coh.coherent_snr(
                                coinc_triggers,
                                coinc_idx,
                                args.coinc_threshold,
                                project,
                                rho_coinc,
                            )
                        logging.info(
                            "%d triggers above coherent SNR threshold",
                            len(rho_coh),
                        )
                        if len(coinc_idx) != 0:
                            logging.info(
                                "With max coherent SNR = %.2f", max(rho_coh)
                            )
                            # Calculate the null SNR and apply the null SNR cut
                            (
                                null,
                                rho_coh,
                                rho_coinc,
                                coinc_idx,
                                coinc_triggers,
                            ) = coh.null_snr(
                                rho_coh,
                                rho_coinc,
                                apply_cut=args.do_null_cut,
                                null_min=args.null_min,
                                null_grad=args.null_grad,
                                null_step=args.null_step,
                                snrv=coinc_triggers,
                                index=coinc_idx,
                            )
                            logging.info(
                                "%d triggers above null threshold", len(null)
                            )
                            if len(coinc_idx) != 0:
                                logging.info(
                                    "With max null SNR = %.2f", max(null)
                                )
                    # Now calculate the individual detector chi2 values
                    # and the SNR reweighted by chi2 and by null SNR
                    # (no cut on reweighted SNR is applied).
                    # To do this it is useful to find the indices of the
                    # (surviving) triggers in the detector frame.
                    if len(coinc_idx) != 0:
                        # Update coinc_idx_det_frame to account for the
                        # effect of the cuts applied so far
                        coinc_idx_det_frame = {
                            ifo: (
                                coinc_idx
                                + time_delay_idx[slide][position_index][ifo]
                            )
                            % len(snr_dict[ifo])
                            for ifo in args.instruments
                        }
                        # Build dictionary with per-IFO complex SNR time series
                        # of the most recent set of triggers
                        coherent_ifo_trigs = {
                            ifo: snr_dict[ifo][coinc_idx_det_frame[ifo]]
                            for ifo in args.instruments
                        }
                        # Calculate the powerchi2 values of remaining triggers
                        # (this uses the SNR timeseries before the time delay,
                        # so we undo it; the same holds for normalisation)
                        chisq = {}
                        chisq_dof = {}
                        for ifo in args.instruments:
                            chisq[ifo], chisq_dof[ifo] = power_chisq.values(
                                corr_dict[ifo],
                                coherent_ifo_trigs[ifo] / norm_dict[ifo],
                                norm_dict[ifo],
                                stilde[ifo].psd,
                                coinc_idx_det_frame[ifo]
                                + stilde[ifo].analyze.start,
                                template,
                            )
                        # Calculate network chisq value
                        network_chisq_dict = coh.network_chisq(
                            chisq, chisq_dof, coherent_ifo_trigs
                        )
                        # Calculate chisq reweighted SNR
                        if nifo > 2:
                            reweighted_snr = ranking.newsnr(
                                rho_coh,
                                network_chisq_dict,
                                q=args.chisq_index,
                                n=args.chisq_nhigh,
                            )
                            # Calculate null reweighted SNR
                            reweighted_snr = coh.reweight_snr_by_null(
                                reweighted_snr,
                                null,
                                rho_coh,
                                null_min=args.null_min,
                                null_grad=args.null_grad,
                                null_step=args.null_step,
                            )
                        elif nifo == 2:
                            reweighted_snr = ranking.newsnr(
                                rho_coinc,
                                network_chisq_dict,
                                q=args.chisq_index,
                                n=args.chisq_nhigh,
                            )
                        else:
                            rho_sngl = abs(
                                snr[args.instruments[0]][
                                    coinc_idx_det_frame[args.instruments[0]]
                                ]
                            )
                            reweighted_snr = ranking.newsnr(
                                rho_sngl,
                                network_chisq_dict,
                                q=args.chisq_index,
                                n=args.chisq_nhigh,
                            )
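                        # In all three branches the chisq reweighting follows
                        # ranking.newsnr's convention:
                        # rho_new = rho * ((1 + rchisq**(q/n)) / 2)**(-1/q)
                        # for reduced chisq rchisq > 1, rho_new = rho
                        # otherwise.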
                        # All out vals must be the same length, so single
                        # value entries are repeated once per event
                        num_events = len(reweighted_snr)
                        # Calculate the bankchi2 and autochi2 values of
                        # remaining triggers
                        for ifo in args.instruments:
                            (
                                ifo_out_vals['bank_chisq'],
                                ifo_out_vals['bank_chisq_dof'],
                            ) = bank_chisq.values(
                                template,
                                stilde[ifo].psd,
                                stilde[ifo],
                                coherent_ifo_trigs[ifo] / norm_dict[ifo],
                                norm_dict[ifo],
                                coinc_idx_det_frame[ifo]
                                + stilde[ifo].analyze.start,
                            )
                            (
                                ifo_out_vals['auto_chisq'],
                                ifo_out_vals['auto_chisq_dof'],
                            ) = autochisq.values(
                                snr_dict[ifo] / norm_dict[ifo],
                                coinc_idx_det_frame[ifo],
                                template,
                                stilde[ifo].psd,
                                norm_dict[ifo],
                                stilde=stilde[ifo],
                                low_frequency_cutoff=flow,
                            )
                            ifo_out_vals['chisq'] = chisq[ifo]
                            ifo_out_vals['chisq_dof'] = chisq_dof[ifo]
                            ifo_out_vals['time_index'] = (
                                coinc_idx_det_frame[ifo]
                                + stilde[ifo].cumulative_index
                            )
                            ifo_out_vals['snr'] = coherent_ifo_trigs[ifo]
                            # IFO is stored as an int
                            ifo_out_vals['ifo'] = [
                                event_mgr.ifo_dict[ifo]
                            ] * num_events
                            # Time slide ID
                            ifo_out_vals['slide_id'] = [slide] * num_events
                            event_mgr.add_template_events_to_ifo(
                                ifo,
                                ifo_names,
                                [ifo_out_vals[n] for n in ifo_names],
                            )
                        if nifo > 2:
                            network_out_vals['coherent_snr'] = rho_coh
                            network_out_vals['null_snr'] = null
                        elif nifo == 2:
                            network_out_vals['coherent_snr'] = rho_coinc
                        else:
                            network_out_vals['coherent_snr'] = abs(
                                snr[args.instruments[0]][
                                    coinc_idx_det_frame[args.instruments[0]]
                                ]
                            )
                        network_out_vals['reweighted_snr'] = reweighted_snr
                        network_out_vals['my_network_chisq'] = np.real(
                            network_chisq_dict
                        )
                        network_out_vals['time_index'] = (
                            coinc_idx + stilde[ifo].cumulative_index
                        )
                        network_out_vals['nifo'] = [nifo] * num_events
                        network_out_vals['dec'] = [
                            sky_positions[1][position_index]
                        ] * num_events
                        network_out_vals['ra'] = [
                            sky_positions[0][position_index]
                        ] * num_events
                        network_out_vals['slide_id'] = [slide] * num_events
                        event_mgr.add_template_events_to_network(
                            network_names,
                            [network_out_vals[n] for n in network_names],
                        )
            # We have now left the loops over sky positions and time slides
            # but are still inside the loops over segments and templates,
            # so the surviving triggers can be clustered
            if args.cluster_method == "window":
                cluster_window = int(args.cluster_window * sample_rate)
            elif args.cluster_method == "template":
                cluster_window = int(template.chirp_length * sample_rate)
            # Cluster template events by slide
            for slide in range(num_slides):
                logging.info("Clustering slide %d", slide)
                event_mgr.cluster_template_network_events(
                    'time_index', 'reweighted_snr', cluster_window, slide=slide
                )
        # Left loop over segments
        event_mgr.finalize_template_events()
    # Left loop over templates
    logging.info("Filtering completed")

logging.info("Writing output")
event_mgr.write_events(args.output)

logging.info("Finished")
logging.info("Time to complete analysis: %d", int(time.time() - time_init))
