#!/usr/bin/python3.11

# Copyright (C) 2014 Alex Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Find multi-detector gravitational wave triggers and calculate the
coherent SNRs and related statistics.
"""

import logging
import time
from collections import defaultdict
import argparse
import numpy as np
import h5py
from pycbc import (
    detector, fft, init_logging, inject, opt, psd, scheme, strain, vetoes,
    waveform, DYN_RANGE_FAC
)
from pycbc.events import ranking, coherent as coh, EventManagerCoherent
from pycbc.filter import MatchedFilterControl
from pycbc.types import TimeSeries, zeros, float32, complex64
from pycbc.types import MultiDetOptionAction
from pycbc.vetoes import sgchisq
# Record the start time so the total run time can be logged at the end.
time_init = time.time()

# Command line interface: general matched-filtering options.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="print extra debugging information",
                    default=False)
parser.add_argument("--output", type=str,
                    help="Path of the file the events are written to.")
parser.add_argument("--instruments", nargs="+", type=str, required=True,
                    help="List of instruments to analyze.")
parser.add_argument("--bank-file", type=str,
                    help="Path to the template bank file.")
parser.add_argument("--snr-threshold", type=float,
                    help="SNR threshold for trigger generation")
parser.add_argument("--newsnr-threshold", type=float, metavar='THRESHOLD',
                    help="Cut triggers with NewSNR less than THRESHOLD")
parser.add_argument("--low-frequency-cutoff", type=float,
                    help="The low frequency cutoff to use for filtering "
                         "(Hz)")
# The approximant option is provided by the waveform bank module
waveform.bank.add_approximant_arg(parser)
parser.add_argument("--order", type=str,
                    help="The integer half-PN order at which to generate the "
                         "approximant.")
parser.add_argument("--taper-template",
                    choices=["start", "end", "startend"],
                    help="For time-domain approximants, taper the start "
                         "and/or end of the waveform before FFTing.")
parser.add_argument("--cluster-method", choices=["template", "window"],
                    default="window",
                    help="Method to use when clustering triggers. 'window' - "
                         "cluster within a fixed time window defined by the "
                         "cluster-window option (default); or 'template' - "
                         "cluster within windows defined by each template's "
                         "chirp length.")
parser.add_argument("--cluster-window", type=float, default=0,
                    help="Length of clustering window in seconds.")
parser.add_argument("--bank-veto-bank-file", type=str, help="Path to the "
                    "bank file used to compute the bank chi-square veto.")
parser.add_argument("--chisq-bins", default=0)
# Commenting out options which are not yet implemented
# parser.add_argument("--chisq-threshold", type=float, default=0)
# parser.add_argument("--chisq-delta", type=float, default=0)
parser.add_argument("--autochi-number-points", type=int, default=0)
parser.add_argument("--autochi-stride", type=int, default=0)
parser.add_argument("--autochi-onesided", action='store_true',
                    default=False)
parser.add_argument("--downsample-factor", type=int, default=1,
                    help="Factor that determines the interval between the "
                         "initial SNR sampling. If not set (or 1) no sparse "
                         "sample is created, and the standard full SNR is "
                         "calculated.")
parser.add_argument("--upsample-threshold", type=float,
                    help="The fraction of the SNR threshold to check the "
                         "sparse SNR sample.")
parser.add_argument("--upsample-method", choices=["pruned_fft"],
                    default='pruned_fft',
                    help="The method to find the SNR points between the "
                         "sparse SNR sample.")
parser.add_argument("--user-tag", type=str, metavar="TAG",
                    help="This is used to identify FULL_DATA jobs for "
                         "compatibility with pipedown post-processing. Option "
                         "will be removed when no longer needed.")
# Options specific to the coherent analysis, declared as
# (flag, keyword arguments) pairs and registered in the order listed.
coherent_options = [
    ("--ra", dict(type=float, help="Right ascension, in radians")),
    ("--dec", dict(type=float, help="Declination, in radians")),
    ("--sky-grid", dict(
        type=str,
        help="Sky-grid (hdf file) containing two datasets : "
             "ra and dec, both in radians")),
    ("--coinc-threshold", dict(
        type=float, default=0.0,
        help="Triggers with coincident/coherent snr below this "
             "value will be discarded.")),
    ("--do-null-cut", dict(
        action='store_true',
        help="Apply a cut based on null SNR.")),
    ("--null-min", dict(
        type=float, default=5.25,
        help="Triggers with null_snr above this value will be"
             " discarded.")),
    ("--null-grad", dict(
        type=float, default=0.2,
        help="The gradient of the line defining the null cut "
             "after the null step.")),
    ("--null-step", dict(
        type=float, default=20.,
        help="Triggers with coherent snr above null_step will "
             "be cut according to the null_grad and null_min.")),
    ("--trigger-time", dict(
        type=int,
        help="Time of the GRB, used to set the antenna patterns.")),
    ("--projection", dict(
        default="standard",
        choices=["standard", "left", "right", "left+right"],
        help="Choice of projection matrix. 'left' and 'right' "
             "correspond to face-away and face-on")),
    ("--num-slides", dict(
        type=int, default=0,
        help="Number of time slides to perform.")),
    ("--slide-shift", dict(
        type=float, default=1.,
        help="Size of each time slide shift.")),
]
for option_flag, option_kwargs in coherent_options:
    parser.add_argument(option_flag, **option_kwargs)
# Let each PyCBC sub-module register its own option group on the parser,
# in the order listed.
for add_option_group in (
        strain.insert_strain_option_group_multi_ifo,
        strain.StrainSegments.insert_segment_option_group_multi_ifo,
        psd.insert_psd_option_group_multi_ifo,
        scheme.insert_processing_option_group,
        fft.insert_fft_option_group,
        opt.insert_optimization_option_group,
        inject.insert_injfilterrejector_option_group_multi_ifo,
        sgchisq.SingleDetSGChisq.insert_option_group,
        ):
    add_option_group(parser)

args = parser.parse_args()
init_logging(args.verbose)

# Short-hand for the GRB trigger time
t_gps = args.trigger_time

# Sort the ifos alphabetically (in place) so that they are always
# processed in the same order.
args.instruments.sort()

# Validate the parsed options of each sub-module
strain.verify_strain_options_multi_ifo(args, parser, args.instruments)
strain.StrainSegments.verify_segment_options_multi_ifo(
    args, parser, args.instruments)
psd.verify_psd_options_multi_ifo(args, parser, args.instruments)
scheme.verify_processing_options(args, parser)
fft.verify_fft_options(args, parser)

# Build the injection filter rejectors, the strain data and its
# segmentation for every ifo, then the processing scheme context.
inj_filter_rejector = inject.InjFilterRejector.from_cli_multi_ifos(
    args, args.instruments)
strain_dict = strain.from_cli_multi_ifos(
    args, args.instruments, inj_filter_rejector,
    dyn_range_fac=DYN_RANGE_FAC)
strain_segments_dict = strain.StrainSegments.from_cli_multi_ifos(
    args, strain_dict, args.instruments)
ctx = scheme.from_cli(args)
with ctx:
    fft.from_cli(args)
    # Set some often used variables for easy access
    flow = args.low_frequency_cutoff
    flow_dict = defaultdict(lambda : flow)
    # The first ifo provides the reference sample rate and segment
    # dimensions; every other ifo must match them exactly.
    for count, ifo in enumerate(args.instruments):
        if count == 0:
            sample_rate = strain_dict[ifo].sample_rate
            sample_rate_dict = defaultdict(lambda : sample_rate)
            flen = strain_segments_dict[ifo].freq_len
            flen_dict = defaultdict(lambda : flen)
            tlen = strain_segments_dict[ifo].time_len
            tlen_dict = defaultdict(lambda : tlen)
            delta_f = strain_segments_dict[ifo].delta_f
            delta_f_dict = defaultdict(lambda : delta_f)
            continue
        # Explicit comparison rather than assert: asserts are removed
        # when Python runs with -O, and a bare except would also hide
        # unrelated errors (e.g. a missing attribute).
        if (sample_rate != strain_dict[ifo].sample_rate
                or flen != strain_segments_dict[ifo].freq_len
                or tlen != strain_segments_dict[ifo].time_len
                or delta_f != strain_segments_dict[ifo].delta_f):
            err_msg = "Sample rate, frequency length and time length "
            err_msg += "must all be consistent across ifos."
            raise ValueError(err_msg)
    logging.info("Making frequency-domain data segments")
    # Fourier-transformed analysis segments, keyed by ifo
    segments = {}
    for ifo in args.instruments:
        segments[ifo] = strain_segments_dict[ifo].fourier_segments()
    # The segmentation objects are not needed once the segments exist
    del strain_segments_dict
    psd.associate_psds_to_multi_ifo_segments(
        args, segments, strain_dict, flen, delta_f, flow,
        args.instruments, dyn_range_factor=DYN_RANGE_FAC,
        precision='single')
    # The same matched-filter parameters are currently used for all ifos,
    # and tlen is identical across ifos, so a single template memory
    # buffer is allocated here and shared by every filter and by the bank.
    template_mem = zeros(tlen, dtype=complex64)
    
    # Read the sky grid or the single sky position. Exactly one of the two
    # must be supplied: a sky grid file, or both --ra and --dec.
    have_sky_position = args.ra is not None and args.dec is not None
    if args.sky_grid is not None and have_sky_position:
        parser.error('Give either a sky grid or a sky position, not both')

    if args.sky_grid is not None:
        # np.array() copies the datasets into memory, so the file can be
        # closed as soon as both have been read.
        with h5py.File(args.sky_grid, 'r') as sky_grid:
            ra = np.array(sky_grid['ra'])
            dec = np.array(sky_grid['dec'])
    elif have_sky_position:
        ra = np.array([args.ra])
        dec = np.array([args.dec])
    else:
        # Without this check ra/dec would be undefined below
        parser.error(
            'Give either a sky grid (--sky-grid) or a sky position '
            '(both --ra and --dec)')

    # Row 0 holds the right ascensions, row 1 the declinations
    sky_positions = np.array([ra, dec])
    num_sky_positions = sky_positions.shape[1]
    positions_array = np.arange(num_sky_positions)

    # Calculate time delays to each detector for each sky position and
    # apply time slide shifts.
    # Slide 0 applies no shift; the n-th ifo (alphabetical order) is
    # shifted by slide_shift * slide * n.
    slide_ids = np.arange(1 + args.num_slides)
    time_slides = {
        ifo: args.slide_shift * slide_ids * n_ifo
        for n_ifo, ifo in enumerate(args.instruments)}
    # The delay from the Earth's centre depends only on the ifo and the
    # sky position, not on the slide, so evaluate it once per
    # (ifo, position) with a single Detector object per ifo.
    geocent_delays = {}
    for ifo in args.instruments:
        det = detector.Detector(ifo)
        geocent_delays[ifo] = [
            det.time_delay_from_earth_center(
                sky_positions[0][position_index],
                sky_positions[1][position_index], t_gps)
            for position_index in positions_array]
    # Total offset (geocentric delay plus slide shift) converted to a
    # whole number of samples: time_delay_idx[slide][position][ifo]
    time_delay_idx = {
        slide: {
            position_index: {
                ifo: int(round(
                    (geocent_delays[ifo][position_index]
                     + time_slides[ifo][slide])
                    * sample_rate
                    ))
                for ifo in args.instruments
            } for position_index in positions_array
        } for slide in slide_ids
    }
    
    # One matched-filter engine per ifo. Clustering is disabled here for
    # the coherent search; it happens at the end of the template loop.
    # FIXME: The single detector SNR threshold should not necessarily be
    #        applied to every IFO (usually only 2 most sensitive in
    #        network)
    filter_options = dict(
        use_cluster=False,
        downsample_factor=args.downsample_factor,
        upsample_threshold=args.upsample_threshold,
        upsample_method=args.upsample_method,
        cluster_function='symmetric')
    matched_filter = {}
    for ifo in args.instruments:
        matched_filter[ifo] = MatchedFilterControl(
            args.low_frequency_cutoff, None, args.snr_threshold, tlen,
            delta_f, complex64, segments[ifo], template_mem,
            **filter_options)
    logging.info("Initializing signal-based vetoes.")
    # A single instance of each single-detector veto class is created and
    # used for every ifo.
    # Power chi-square
    power_chisq = vetoes.SingleDetPowerChisq(args.chisq_bins)
    # Bank chi-square
    bank_veto_options = dict(
        phase_order=args.order, approximant=args.approximant)
    bank_chisq = vetoes.SingleDetBankVeto(
        args.bank_veto_bank_file, flen, delta_f, flow, complex64,
        **bank_veto_options)
    # Auto chi-square
    autochisq = vetoes.SingleDetAutoChisq(
        args.autochi_stride, args.autochi_number_points,
        onesided=args.autochi_onesided)
    logging.info("Overwhitening frequency-domain data segments")
    # Divide every segment by its own PSD, in place
    for ifo in args.instruments:
        for data_segment in segments[ifo]:
            data_segment /= data_segment.psd
    # Column names and types of the per-ifo events
    ifo_out_types = {
        'time_index': int,
        'ifo': int, # IFO is stored as an int internally!
        'snr': complex64,
        'chisq': float32,
        'chisq_dof': int,
        'bank_chisq': float32,
        'bank_chisq_dof': int,
        'cont_chisq': float32,
        'slide_id': int
        }
    # Per-ifo column values, filled in for each batch of events. They are
    # derived from the type table so the two cannot drift apart, and every
    # placeholder is None (slide_id previously held the type `int`).
    ifo_out_vals = dict.fromkeys(ifo_out_types)
    ifo_names = sorted(ifo_out_vals)
    # Column names and types of the network events
    network_out_types = {
        'dec': float32,
        'ra': float32,
        'time_index': int,
        'coherent_snr': float32,
        'null_snr': float32,
        'nifo': int,
        'my_network_chisq': float32,
        'reweighted_snr': float32,
        'slide_id': int
        }
    # Network column values, same convention as ifo_out_vals
    network_out_vals = dict.fromkeys(network_out_types)
    network_names = sorted(network_out_vals)
    # Column names are passed in sorted order together with their types
    event_mgr = EventManagerCoherent(
        args, args.instruments, ifo_names,
        [ifo_out_types[n] for n in ifo_names], network_names,
        [network_out_types[n] for n in network_names])
    logging.info("Read in template bank")
    # Templates are generated into the shared template_mem buffer
    bank = waveform.FilterBank(
        args.bank_file, flen, delta_f, complex64,
        low_frequency_cutoff=flow, phase_order=args.order,
        taper=args.taper_template, approximant=args.approximant,
        out=template_mem)
    n_bank = len(bank)
    nfilters = 0
    logging.info("Full template bank size: %d", n_bank)
    # Use injfilterrejector to reduce the bank to only those templates
    # that might actually find something
    for ifo in args.instruments:
        bank.template_thinning(inj_filter_rejector[ifo])
    # Only report (and record) the new size if thinning removed templates
    thinned_size = len(bank)
    if thinned_size != n_bank:
        n_bank = thinned_size
        logging.info("Template bank size after thinning: %d", n_bank)
   
    # Antenna patterns: ap[ifo][position_index] is the value returned by
    # Detector.antenna_pattern for that sky position at the trigger time
    # with zero polarization angle (indexed below as [0] = plus,
    # [1] = cross). One Detector object is built per ifo and reused for
    # every sky position.
    ap = {}
    for ifo in args.instruments:
        det = detector.Detector(ifo)
        ap[ifo] = [
            det.antenna_pattern(
                sky_positions[0][position_index],
                sky_positions[1][position_index],
                polarization=0, t_gps=t_gps)
            for position_index in positions_array]
    # Same data as a list ordered like args.instruments
    antenna_patterns = [ap[ifo] for ifo in args.instruments]

    # Loop over templates
    for t_num, template in enumerate(bank):
        # Loop over segments (all ifos are indexed with the same s_num)
        for s_num, stilde in enumerate(segments[args.instruments[0]]):
            # Replace the first-ifo segment by a dictionary holding this
            # segment for every ifo
            stilde = {}
            for ifo in args.instruments:
                stilde[ifo] = segments[ifo][s_num]
            # Ask the 'inj_filter_rejector' of every ifo whether this
            # template/segment pair should be filtered when injections
            # are present. A single refusal skips the segment.
            analyse_segment = True
            for ifo in args.instruments:
                wanted = inj_filter_rejector[ifo].template_segment_checker(
                    bank, t_num, stilde[ifo])
                if wanted:
                    continue
                logging.info(
                    "Skipping segment %d/%d with template %d/%d as no "
                    "detectable injection is present",
                    s_num + 1, len(segments[ifo]), t_num + 1, n_bank)
                analyse_segment = False
            # Template normalisation (sigma squared) and its square root
            # for each ifo, using the PSD of this segment
            sigmasq = {}
            sigma = {}
            for ifo in args.instruments:
                sigmasq[ifo] = template.sigmasq(segments[ifo][s_num].psd)
                sigma[ifo] = np.sqrt(sigmasq[ifo])
            # The event manager is told about a new template on the first
            # segment only, before any decision to skip the segment.
            if s_num == 0:
                event_mgr.new_template(
                    tmplt=template.params, sigmasq=sigmasq)
            if not analyse_segment:
                continue
            logging.info(
                "Analyzing segment %d/%d", s_num + 1, len(segments[ifo]))
            # Six independent per-ifo containers, all initialised to None
            snr_dict, norm_dict, corr_dict, idx, snrv_dict, snr = (
                dict.fromkeys(args.instruments) for _ in range(6))
            # FIXME: 2 lines that can be moved outside the loops
            #        We do not really use ifo_list, and we reassign
            #        nifo identically a few lines below
            ifo_list = list(args.instruments)
            nifo = len(ifo_list)
            for ifo in args.instruments:
                logging.info(
                    "Filtering template %d/%d, ifo %s", t_num + 1, n_bank, ifo)
                # No clustering in the coherent search until the end.
                # The correlation vector is the FFT of the snr (so inverse
                # FFT it to get the snr). 
                snr_ts, norm, corr, ind, snrv = \
                     matched_filter[ifo].matched_filter_and_cluster(
                         s_num, template.sigmasq(stilde[ifo].psd), window=0)
                snr_dict[ifo] = (
                        snr_ts[matched_filter[ifo].segments[s_num].analyze]
                        * norm)
                norm_dict[ifo] = norm
                corr_dict[ifo] = corr.copy()
                idx[ifo] = ind.copy()
                snrv_dict[ifo] = snrv.copy()
                snr[ifo] = snrv * norm
            
            # FIXME: wrong comment?
            # Move onto next segment if there are no triggers.
            if len(ifo_list)==0: continue
            # Loop through slides, starting with the zero-lag (slide 0,
            # for which the slide shift is zero)
            for slide in range(args.num_slides + 1):
                logging.info(
                        "Analyzing slide %d/%d", slide, args.num_slides)
                for position_index in positions_array:
                    logging.info(
                        "Analyzing sky position %d/%d", position_index+1, len(positions_array))
                    # Save the indexes of triggers (if we have any)
                    # Even if we have none, need to keep an empty dictionary.
                    # Only do this if idx doesn't get time shifted out of the
                    # time we are looking at, i.e., require
                    # idx[ifo] - time_delay_idx[slide][position_index][ifo] to be in
                    # (0, len(snr_dict[ifo]))
                    idx_dict = {
                    ifo: idx[ifo][
                        np.logical_and(
                            idx[ifo] > time_delay_idx[slide][position_index][ifo],
                            idx[ifo] - time_delay_idx[slide][position_index][ifo]
                                < len(snr_dict[ifo]))
                        ]
                    for ifo in args.instruments
                    }
                    
                    # Find triggers that are coincident (in geocent time) in
                    # multiple ifos. If a single ifo analysis then just use the
                    # indexes from that ifo, shifted back by its time delay.
                    if nifo > 1:
                        coinc_idx = coh.get_coinc_indexes(
                            idx_dict, time_delay_idx[slide][position_index])
                    else:
                        coinc_idx = (
                            idx_dict[args.instruments[0]]
                            - time_delay_idx[slide][position_index][args.instruments[0]]
                            )
                    logging.info("Found %d coincident triggers", len(coinc_idx))
                    for ifo in args.instruments:
                        # Raise an error if this segment has no data, i.e.
                        # the SNR series of this ifo has zero length
                        # FIXME: raise this sooner?
                        if len(snr_dict[ifo])==0:
                            raise RuntimeError(
                                'The SNR triggers dictionary is empty. This '
                                'should not be possible.')
                    # Time delay is applied to the coincident indices to get
                    # the index in each detector's own SNR series; the result
                    # is wrapped modulo the length of that series.
                    coinc_idx_det_frame = {
                        ifo: (coinc_idx + time_delay_idx[slide][position_index][ifo]) % len(snr_dict[ifo])
                        for ifo in args.instruments}
                    # Calculate the coincident snr. Check we have
                    # data before we try to compute it.
                    if len(coinc_idx) != 0 and nifo > 1:
                        # Find coinc snr at trigger times and apply coinc snr
                        # threshold
                        rho_coinc, coinc_idx, coinc_triggers = \
                            coh.coincident_snr(
                                snr_dict, coinc_idx, args.coinc_threshold,
                                time_delay_idx[slide][position_index])
                        logging.info(
                            "%d coincident triggers above coincident SNR threshold",
                            len(coinc_idx))
                        if len(coinc_idx) != 0:
                            logging.info(
                                "Max coincident SNR = %.2f", max(rho_coinc))
                    # If there is only one ifo, then coinc_triggers is just the
                    # triggers from the ifo
                    elif len(coinc_idx) != 0 and nifo == 1:
                        coinc_triggers = {
                            args.instruments[0]: snr[args.instruments[0]][
                                coinc_idx_det_frame[args.instruments[0]]
                                ]
                            }
                    # No coincident indices at all for this slide/position
                    else:
                        coinc_triggers = {}
                        logging.info("No triggers above coincident SNR threshold")
                    # If we have triggers above coinc threshold and more than 2
                    # ifos, then calculate the coherent statistics
                    if len(coinc_idx) != 0 and nifo > 2:
                        # Plus and cross antenna responses of each ifo for
                        # this sky position. They are needed by every choice
                        # of projection, so they are computed before
                        # branching (previously they were only defined in
                        # the 'left+right' branch).
                        fp = {ifo: ap[ifo][position_index][0] for ifo in args.instruments}
                        fc = {ifo: ap[ifo][position_index][1] for ifo in args.instruments}
                        if args.projection=='left+right':
                            # Left polarized coherent SNR
                            project_l = coh.get_projection_matrix(
                                fp, fc, sigma, projection='left')
                            (rho_coh_l, coinc_idx_l, coinc_triggers_l,
                                    rho_coinc_l) = \
                                coh.coherent_snr(
                                    coinc_triggers, coinc_idx,
                                    args.coinc_threshold, project_l, rho_coinc)
                            # Right polarized coherent SNR
                            project_r = coh.get_projection_matrix(
                                fp, fc, sigma, projection='right')
                            (rho_coh_r, coinc_idx_r, coinc_triggers_r,
                                    rho_coinc_r) = \
                                coh.coherent_snr(
                                    coinc_triggers, coinc_idx,
                                    args.coinc_threshold, project_r, rho_coinc)
                            # Point by point, track the larger of the two and store it
                            max_idx = np.argmax([rho_coh_l, rho_coh_r], axis=0)
                            rho_coh = np.where(
                                max_idx==0, rho_coh_l, rho_coh_r)
                            coinc_idx = np.where(
                                max_idx==0, coinc_idx_l, coinc_idx_r)
                            coinc_triggers = {
                                ifo: np.where(
                                    max_idx==0, coinc_triggers_l[ifo],
                                    coinc_triggers_r[ifo])
                                for ifo in coinc_triggers_l}
                            rho_coinc = np.where(
                                max_idx==0, rho_coinc_l, rho_coinc_r)
                        else:
                            # Single projection: 'standard', 'left' or 'right'
                            project = coh.get_projection_matrix(
                                fp, fc, sigma, projection=args.projection)
                            rho_coh, coinc_idx, coinc_triggers, rho_coinc = \
                                coh.coherent_snr(
                                    coinc_triggers, coinc_idx,
                                    args.coinc_threshold, project, rho_coinc)
                        logging.info(
                            "%d triggers above coherent threshold", len(rho_coh))
                        if len(coinc_idx) != 0:
                            logging.info("Max coherent SNR = %.2f", max(rho_coh))
                            # Find the null snr
                            (null, rho_coh, rho_coinc, coinc_idx,
                                    coinc_triggers) =\
                                coh.null_snr(
                                    rho_coh, rho_coinc, snrv=coinc_triggers,
                                    index=coinc_idx)
                            if len(coinc_idx) != 0:
                                logging.info("Max null SNR = %.2f", max(null))
                            logging.info(
                                "%d triggers above null threshold", len(null))
                    # We are now going to find the individual detector chi2
                    # values. To do this it is useful to find the indexes of
                    # the triggers in the detector frame.
                    if len(coinc_idx) != 0:
                        # coinc_idx_det_frame is redefined to account for the
                        # cuts to coinc_idx above
                        coinc_idx_det_frame = {
                            ifo: (coinc_idx + time_delay_idx[slide][position_index][ifo]) % len(snr_dict[ifo])
                            for ifo in args.instruments}
                        # SNR time series of each ifo sampled at the surviving
                        # trigger indexes
                        coherent_ifo_trigs = {
                            ifo: snr_dict[ifo][coinc_idx_det_frame[ifo]]
                            for ifo in args.instruments}
                        # Calculate the power chi2 values for the coinc
                        # indexes (this uses the snr timeseries before the time
                        # delay, so we need to undo it. Same for normalisation)
                        chisq = {}
                        chisq_dof = {}
                        for ifo in args.instruments:
                            # The SNR is passed un-normalised, and the indexes
                            # are offset by the start of the analysed part of
                            # the segment.
                            chisq[ifo], chisq_dof[ifo] = power_chisq.values(
                                corr_dict[ifo],
                                coherent_ifo_trigs[ifo] / norm_dict[ifo],
                                norm_dict[ifo], stilde[ifo].psd,
                                coinc_idx_det_frame[ifo]
                                + stilde[ifo].analyze.start,
                                template)
                        # Calculate network chisq value
                        network_chisq_dict = coh.network_chisq(
                            chisq, chisq_dof, coherent_ifo_trigs)
                        # Calculate chisq reweighted SNR. The SNR that is
                        # reweighted depends on the number of ifos: coherent
                        # SNR for more than two, coincident SNR for two, and
                        # the single-detector SNR magnitude for one.
                        if nifo > 2:
                            reweighted_snr = ranking.newsnr(
                                rho_coh, network_chisq_dict)
                            # Calculate null reweighted SNR
                            reweighted_snr = coh.reweight_snr_by_null(
                                reweighted_snr, null, rho_coh)
                        elif nifo == 2:
                            reweighted_snr = ranking.newsnr(
                                rho_coinc, network_chisq_dict)
                        else:
                            # NOTE(review): snr[ifo] is snrv * norm from the
                            # matched filter, indexed here by detector-frame
                            # sample index - confirm this indexing is valid.
                            rho_sngl = abs(
                                snr[args.instruments[0]][
                                    coinc_idx_det_frame[args.instruments[0]]
                                    ]
                                )
                            reweighted_snr = ranking.newsnr(
                                rho_sngl, network_chisq_dict)
                        # Need all out vals to be the same length. This means
                        # the entries that are single values need to be
                        # repeated once per event.
                        num_events = len(reweighted_snr)
                        # the output will only be possible if
                        # len(networkchi2) == num_events
                        # Fill the per-ifo columns and hand them to the event
                        # manager, one ifo at a time. ifo_out_vals is reused
                        # (overwritten) for each ifo.
                        for ifo in args.instruments:
                            (ifo_out_vals['bank_chisq'],
                                    ifo_out_vals['bank_chisq_dof']) =\
                                bank_chisq.values(
                                    template, stilde[ifo].psd, stilde[ifo],
                                    coherent_ifo_trigs[ifo] / norm_dict[ifo],
                                    norm_dict[ifo],
                                    coinc_idx_det_frame[ifo]
                                    + stilde[ifo].analyze.start)
                            ifo_out_vals['cont_chisq'] = autochisq.values(
                                snr_dict[ifo] / norm_dict[ifo],
                                coinc_idx_det_frame[ifo], template,
                                stilde[ifo].psd, norm_dict[ifo],
                                stilde=stilde[ifo], low_frequency_cutoff=flow)
                            ifo_out_vals['chisq'] = chisq[ifo]
                            ifo_out_vals['chisq_dof'] = chisq_dof[ifo]
                            # Detector-frame index offset by the segment's
                            # cumulative index
                            ifo_out_vals['time_index'] = (
                                coinc_idx_det_frame[ifo]
                                + stilde[ifo].cumulative_index
                                )
                            ifo_out_vals['snr'] = coherent_ifo_trigs[ifo]
                            # IFO is stored as an int
                            ifo_out_vals['ifo'] = (
                                [event_mgr.ifo_dict[ifo]] * num_events
                                )
                            # Time slide ID
                            ifo_out_vals['slide_id'] = [slide] * num_events
                            event_mgr.add_template_events_to_ifo(
                                ifo, ifo_names,
                                [ifo_out_vals[n] for n in ifo_names])
                        # Network columns. The 'coherent_snr' column holds the
                        # coherent SNR for more than two ifos, the coincident
                        # SNR for two, and the single-detector SNR magnitude
                        # for one. 'null_snr' is only set for more than two.
                        if nifo>2:
                            network_out_vals['coherent_snr'] = rho_coh
                            network_out_vals['null_snr'] = null
                        elif nifo==2:
                            network_out_vals['coherent_snr'] = rho_coinc
                        else:
                            network_out_vals['coherent_snr'] = (
                                abs(snr[args.instruments[0]][
                                    coinc_idx_det_frame[args.instruments[0]]
                                    ])
                                )
                        network_out_vals['reweighted_snr'] = reweighted_snr
                        network_out_vals['my_network_chisq'] = (
                            np.real(network_chisq_dict))
                        # Here ifo is the last instrument of the loop above
                        network_out_vals['time_index'] = (
                            coinc_idx + stilde[ifo].cumulative_index)
                        # Scalar values are repeated once per event
                        network_out_vals['nifo'] = [nifo] * num_events
                        network_out_vals['dec'] = [sky_positions[1][position_index]] * num_events
                        network_out_vals['ra'] = [sky_positions[0][position_index]] * num_events
                        network_out_vals['slide_id'] = [slide] * num_events
                        event_mgr.add_template_events_to_network(
                            network_names,
                            [network_out_vals[n] for n in network_names])
            if args.cluster_method == "window":
                cluster_window = int(args.cluster_window * sample_rate)
            elif args.cluster_method == "template":
                cluster_window = int(template.chirp_length * sample_rate)
            # Cluster template events by slide
            for slide in range(args.num_slides + 1):
                logging.info("Clustering slide %d", slide)
                event_mgr.cluster_template_network_events(
                    'time_index', 'reweighted_snr', cluster_window, slide=slide)
        event_mgr.finalize_template_events()
event_mgr.write_events(args.output)
logging.info("Finished")
logging.info("Time to complete analysis: %d", int(time.time() - time_init))
