#!/usr/bin/python3.13 -s

# Copyright (C) 2015 Andrew R. Williamson
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Make workflow for the archival, targeted, coherent inspiral pipeline.
"""

import sys
import os
import argparse
import logging
import matplotlib
matplotlib.use('agg')

from ligo.segments import segment, segmentlistdict

from pycbc import init_logging, add_common_pycbc_options
import pycbc.workflow as _workflow
from pycbc.workflow.core import configparser_value_to_file
from pycbc.results.pygrb_plotting_utils import make_grb_segments_plot

# Name handed to the Workflow constructor and echoed in the log
workflow_name = "pygrb_offline"

# Build the command line interface: the common PyCBC options are registered
# first, followed by the two workflow option groups, then argv is parsed
parser = argparse.ArgumentParser()
for register_options in (add_common_pycbc_options,
                         _workflow.add_workflow_command_line_group,
                         _workflow.add_workflow_settings_cli):
    register_options(parser)
args = parser.parse_args()

init_logging(args.verbose, default_level=1)

# Instantiate the workflow and the containers that later stages fill in
wflow = _workflow.Workflow(args, workflow_name)
all_files = _workflow.FileList([])
tags = []
# Working directory at this point, recorded before the chdir further down
init_dir = os.getcwd()

logging.info("Generating %s workflow", workflow_name)

# Setup run directory
# The run directory is <base_dir>/GRB<trigger-name>, where base_dir is the
# [workflow] output-directory option when present and the current working
# directory otherwise.
if wflow.cp.has_option("workflow", "output-directory"):
    base_dir = wflow.cp.get("workflow", "output-directory")
else:
    base_dir = os.getcwd()
triggername = str(wflow.cp.get("workflow", "trigger-name"))
run_dir = os.path.join(base_dir, "GRB%s" % triggername)
logging.info("Workflow will be generated in %s", run_dir)
# exist_ok=True replaces the separate os.path.exists() test, which left a
# window for the directory to appear between the check and the creation
os.makedirs(run_dir, exist_ok=True)
# Every relative path built below is resolved against the run directory
os.chdir(run_dir)

# SEGMENTS
# Initial analysis window: symmetric about the trigger time, with half-width
# given by [workflow-exttrig_segments] max-duration
triggertime = int(wflow.cp.get("workflow", "trigger-time"))
max_duration = int(wflow.cp.get("workflow-exttrig_segments", "max-duration"))
start = triggertime - max_duration
end = triggertime + max_duration
wflow.cp = _workflow.set_grb_start_end(wflow.cp, start, end)

# Retrieve science segments
# The segment file is produced under <run dir>/segments
curr_dir = os.getcwd()
seg_dir = os.path.join(curr_dir, "segments")
sciSegsFile = _workflow.get_segments_file(
    wflow, 'science', 'segments-science', seg_dir)

# One entry per workflow IFO, looked up under the '<ifo>:science' key
sciSegs = {
    ifo: sciSegsFile.segment_dict[ifo + ':science']
    for ifo in wflow.ifos
}

# This block of code and PyCBCMultiInspiralExecutable.get_valid_times()
# must be consistent
# Derive the min-before / min-after values of [workflow-exttrig_segments]
# from the [inspiral] options; which formula applies depends on which of
# segment-start-pad / analyse-segment-end is present.
if wflow.cp.has_option("inspiral", "segment-start-pad"):
    pad_data = int(wflow.cp.get("inspiral", "pad-data"))
    start_pad = int(wflow.cp.get("inspiral", "segment-start-pad"))
    end_pad = int(wflow.cp.get("inspiral", "segment-end-pad"))
    min_before = start_pad + pad_data
    min_after = end_pad + pad_data
elif wflow.cp.has_option("inspiral", "analyse-segment-end"):
    safety = 1
    # True division: deadtime and spec_len are floats in this branch
    deadtime = int(wflow.cp.get("inspiral", "segment-length")) / 2
    spec_len = int(wflow.cp.get("inspiral", "inverse-spec-length")) / 2
    min_before = deadtime - spec_len - safety
    min_after = spec_len + safety
else:
    # True division: deadtime is a float in this branch
    deadtime = int(wflow.cp.get("inspiral", "segment-length")) / 4
    min_before = min_after = deadtime

# NOTE(review): in the two division branches the stored strings look like
# "64.0" -- confirm that whatever reads these options accepts that form
wflow.cp.set("workflow-exttrig_segments", "min-before", str(min_before))
wflow.cp.set("workflow-exttrig_segments", "min-after", str(min_after))

# Do checks for no/single IFO case
# Each failing check still produces the segments plot before exiting
single_ifo = wflow.cp.has_option("workflow", "allow-single-ifo-search")
num_sci_ifos = len(sciSegs)

if num_sci_ifos == 0:
    plot_met = make_grb_segments_plot(wflow, segmentlistdict(), triggertime,
                                      triggername, seg_dir)
    logging.warning("No science segments available.")
    sys.exit()

if num_sci_ifos < 2 and not single_ifo:
    plot_met = make_grb_segments_plot(wflow, segmentlistdict(sciSegs),
                                      triggertime, triggername, seg_dir)
    # Exactly one IFO is present when this branch is reached
    (only_ifo,) = sciSegs
    msg = ("Science segments exist only for %s. "
           "If you wish to enable single IFO running add the option "
           "'allow-single-ifo-search' to the [workflow] section of your "
           "configuration file.") % only_ifo
    logging.warning(msg)
    sys.exit()

# Enough IFOs: work out the on-source, off-source and buffer segments
onSrc, offSrc, bufferSeg = _workflow.generate_triggered_segment(
    wflow, seg_dir, sciSegs)

sciSegs = segmentlistdict(sciSegs)
if onSrc is None:
    # No on-source segment was returned: plot with offSrc passed as the
    # fail criterion, then stop
    plot_met = make_grb_segments_plot(wflow, sciSegs, triggertime,
                                      triggername, seg_dir,
                                      fail_criterion=offSrc)
    logging.info("Making segment plot and exiting.")
    sys.exit()

# Plot using the first segment of the first off-source IFO as coherent segment
coherent_ifo = tuple(offSrc.keys())[0]
plot_met = make_grb_segments_plot(wflow, sciSegs, triggertime, triggername,
                                  seg_dir,
                                  coherent_seg=offSrc[coherent_ifo][0])
segs_plot = _workflow.File(plot_met[0], plot_met[1], plot_met[2],
                           file_url=plot_met[3])
segs_plot.add_pfn(segs_plot.cache_entry.path, site="local")
all_files.append(segs_plot)
# From here on the science segments are the off-source segments
sciSegs = offSrc

# Tag used for the matched-filter jobs: "sngl" for one IFO, "coherent" for
# more than one
num_ifos = len(sciSegs)
if num_ifos > 1:
    mf_tag = "coherent"
elif num_ifos == 1:
    logging.info("Generating a single IFO search.")
    mf_tag = "sngl"

# Update analysis time after coherent segment calculation
# The first segment of the first IFO provides the new start and end
ifo = tuple(sciSegs.keys())[0]
first_seg = sciSegs[ifo][0]
wflow.cp = _workflow.set_grb_start_end(wflow.cp,
                                       int(first_seg[0]),
                                       int(first_seg[1]))

# Total padding: [inspiral] pad-data plus, when gating is enabled,
# [condition_strain] pad-data
padding = int(wflow.cp.get("inspiral", "pad-data"))
if wflow.cp.has_option("workflow-condition_strain", "do-gating"):
    gate_pad = int(wflow.cp.get("condition_strain", "pad-data"))
    padding = padding + gate_pad

# Edges of the first segment of the reference IFO
seg_start = int(sciSegs[ifo][0][0])
seg_end = int(sciSegs[ifo][0][1])

# Same three-way choice as used for min-before/min-after earlier in the
# script; deadtime, spec_len and safety are the values computed there
if wflow.cp.has_option("inspiral", "segment-start-pad"):
    start_pad = int(wflow.cp.get("inspiral", "segment-start-pad"))
    end_pad = int(wflow.cp.get("inspiral", "segment-end-pad"))
    analysis_start = seg_start + start_pad + padding
    analysis_end = seg_end - padding - end_pad
elif wflow.cp.has_option("inspiral", "analyse-segment-end"):
    analysis_start = seg_start + deadtime - spec_len + padding - safety
    analysis_end = seg_end - spec_len - padding - safety
else:
    analysis_start = seg_start + deadtime + padding
    analysis_end = seg_end - deadtime - padding
wflow.analysis_time = segment(analysis_start, analysis_end)

ext_file = None

# DATAFIND
# Only the file list and the (possibly updated) science segments are kept
# from the four values returned by the datafind setup
df_dir = os.path.join(curr_dir, "datafind")
datafind_files, _, sciSegs, _ = _workflow.setup_datafind_workflow(
    wflow, sciSegs, df_dir, sciSegsFile)

if wflow.cp.has_option("workflow-condition_strain", "do-gating"):
    # Trim gate_pad off both ends of the reference segment and use the
    # result as the first segment of every IFO
    ref_seg = sciSegs[ifo][0]
    new_seg = segment(ref_seg[0] + gate_pad, ref_seg[1] - gate_pad)
    for iifo in sciSegs:
        sciSegs[iifo][0] = new_seg
    trimmed_seg = sciSegs[ifo][0]
    wflow.cp = _workflow.set_grb_start_end(wflow.cp,
                                           int(trimmed_seg[0]),
                                           int(trimmed_seg[1]))

# IFOs still present after datafind, in sorted order
ifos = sorted(sciSegs.keys())
ifo = ifos[0]
wflow.ifos = ifos
datafind_veto_files = _workflow.FileList([])

# GATING
# Optional stage, enabled by [workflow-condition_strain] do-gating. Gating
# nodes are either added to the workflow or executed immediately, depending
# on gating-method; afterwards datafind_files holds one .lcf cache of gated
# frames per IFO instead of the original datafind output.
if wflow.cp.has_option("workflow-condition_strain", "do-gating"):
    logging.info("Creating gating jobs.")
    wflow.cp = _workflow.set_grb_start_end(wflow.cp, int(sciSegs[ifo][0][0]),
                                           int(sciSegs[ifo][0][1]))
    gating_nodes, gated_files = _workflow.make_gating_node(wflow,
                                                           datafind_files,
                                                           outdir=df_dir)
    gating_method = wflow.cp.get("workflow-condition_strain",
                                 "gating-method")
    for gating_node in gating_nodes:
        if gating_method == "IN_WORKFLOW":
            wflow.add_node(gating_node)
        elif gating_method == "AT_RUNTIME":
            logging.info("Executing gating node...")
            wflow.execute_node(gating_node)
        else:
            msg = "[workflow-condition_strain] option 'gating-method' can "
            msg += "only have one of the values 'IN_WORKFLOW' or 'AT_RUNTIME'."
            msg += " You have provided the value %s." % gating_method
            logging.error(msg)
            sys.exit()
    datafind_files = _workflow.FileList([])
    for ifo in ifos:
        # Gated frames belonging to this IFO only
        gated_frames = _workflow.FileList([gated_frame for gated_frame in
                                           gated_files
                                           if gated_frame.ifo == ifo])
        # TODO: Remove .lcf cache here
        gated_cache = _workflow.File(
                ifo, "gated",
                segment(int(wflow.cp.get("workflow", "start-time")),
                        int(wflow.cp.get("workflow", "end-time"))),
                extension="lcf", directory=df_dir)
        gated_cache.add_pfn(gated_cache.cache_entry.path, site="local")
        # Build the cache entries first, then write them under a context
        # manager so the file handle is always flushed and closed
        gated_cache_entries = gated_frames.convert_to_lal_cache()
        with open(gated_cache.storage_path, "w") as cache_fp:
            gated_cache_entries.tofile(cache_fp)
        datafind_files.append(gated_cache)

# Seed the veto-file list with the current datafind files
datafind_veto_files.extend(datafind_files)

# From here on `ifos` is the sorted IFO names joined into a single string,
# and `ifo` is the first of them
ifo_list = sorted(sciSegs.keys())
ifo, ifos = ifo_list[0], ''.join(ifo_list)
wflow.ifos = ifos

# Config file consistency check for IPN GRBs
# The two IPN options must be given together or not at all
has_search_pts = wflow.cp.has_option("workflow-inspiral",
                                     "ipn-search-points")
has_sim_pts = wflow.cp.has_option("workflow-injections", "ipn-sim-points")

if has_search_pts and has_sim_pts:
    # IPN GRB: the injections use the trigger time as the IPN GPS time
    wflow.cp.set("injections", "ipn-gps-time",
                 wflow.cp.get("workflow", "trigger-time"))
    IPN = True
elif not (has_search_pts or has_sim_pts):
    IPN = False
else:
    msg = ("You have provided only one of 'ipn-search-points' under "
           "[workflow-inspiral] and 'ipn-sim-points' under "
           "[workflow-injections] in your configuration files. If this is an "
           "IPN GRB please provide both, otherwise provide neither.")
    logging.error(msg)
    sys.exit()

# Get bank_veto_bank.xml if running bank veto
if wflow.cp.has_option('workflow-inspiral', 'bank-veto-bank-file'):
    bank_veto_file = _workflow.FileList([
        configparser_value_to_file(wflow.cp, 'workflow-inspiral',
                                   'bank-veto-bank-file')
    ])
    datafind_veto_files.extend(bank_veto_file)

# For IPN GRBs the search sky-points file is also passed on with the
# datafind/veto files
if IPN:
    file_attrs = dict(
        ifos=wflow.ifos,
        segs=wflow.analysis_time,
        exe_name="IPN_SKY_POINTS",
        tags=["SEARCH"],
    )
    search_pts_file = configparser_value_to_file(
        wflow.cp, 'workflow-inspiral', 'ipn-search-points',
        file_attrs=file_attrs)
    datafind_veto_files.append(search_pts_file)

all_files.extend(datafind_veto_files)

# TEMPLATE BANK AND SPLIT BANK
bank_files = _workflow.setup_tmpltbank_workflow(
    wflow, sciSegs, datafind_files, df_dir)

# Exactly one bank is supported
if len(bank_files) > 1:
    raise NotImplementedError("Multiple banks not supported")
full_bank_file = bank_files[0]
all_files.append(full_bank_file)

# Split the bank for the matched-filter jobs; the whole FileList is passed
# because setup_splittable_workflow requires a FileList as input
splitbank_files = _workflow.setup_splittable_workflow(
    wflow, bank_files, df_dir, tags=["inspiral"])
all_files.extend(splitbank_files)

# INJECTIONS
# Defaults handed to post-processing when the configuration has no
# [workflow-injections] section
injs = None
inj_tags = []
inj_files = None
inj_caches = None
inj_insp_files = None
inj_insp_caches = None
if wflow.cp.has_section("workflow-injections"):
    inj_dir = os.path.join(curr_dir, "injections")
    inj_caches = _workflow.FileList([])
    inj_insp_caches = _workflow.FileList([])

    # Given the stretch of time this workflow will analyse, and the onsource
    # window with its buffer, generate the configuration file with the prior
    # for the injections times and add it to the config parser
    inj_method = wflow.cp.get("workflow-injections",
                              "injections-method")
    if inj_method == "IN_WORKFLOW" and \
            wflow.cp.has_option("workflow-injections", "tc-prior-at-runtime"):
        tc_path = os.path.join(base_dir, "tc_prior.ini")
        _workflow.generate_tc_prior(wflow, tc_path, bufferSeg)

    # Generate injection files
    if IPN:
        # TODO: we used to pass this file to setup_injection_workflow as
        # exttrig_file = sim_pts_file to then use it in lalapps_inspinj.
        # The code below picks it up but does not pass it: get
        # setup_injection_workflow or pycbc_create_injections handle it
        # directly via configparser_value_to_file
        file_attrs = {
            'ifos': wflow.ifos,
            'segs': wflow.analysis_time,
            'exe_name': "IPN_SKY_POINTS",
            'tags': ["SIM"]
        }
        sim_pts_file = configparser_value_to_file(wflow.cp,
                                                  'workflow-inspiral',
                                                  'ipn-sim-points',
                                                  file_attrs=file_attrs)
        all_files.append(sim_pts_file)
    inj_files, inj_tags = _workflow.setup_injection_workflow(wflow, inj_dir)
    all_files.extend(inj_files)
    injs = inj_files

    # Either split template bank for injections jobs or use same split banks
    # as for standard matched filter jobs
    if wflow.cp.has_section("workflow-splittable-injections"):
        inj_splitbank_files = _workflow.setup_splittable_workflow(
                wflow, bank_files, inj_dir, tags=["injections"])
        for inj_split in inj_splitbank_files:
            # Pieces of the description of the form BANK<digits>
            split_str = [s for s in inj_split.tagged_description.split("_")
                         if ("BANK" in s and s[-1].isdigit())]
            if len(split_str) != 0:
                # Append the tag string and the bank number to the description
                inj_split.tagged_description += "%s_%d" % (
                       inj_split.tag_str,
                       int(split_str[0].replace("BANK", "")))
        all_files.extend(inj_splitbank_files)
    else:
        inj_splitbank_files = _workflow.FileList([])
        inj_splitbank_files.extend(splitbank_files)

    # Split the injection files
    if wflow.cp.has_section("workflow-splittable-split_inspinj"):
        inj_split_files = _workflow.FileList([])
        for inj_file, inj_tag in zip(inj_files, inj_tags):
            # setup_splittable_workflow is given a one-element FileList
            single_inj_file = _workflow.FileList([inj_file])
            inj_splits = _workflow.setup_splittable_workflow(
                    wflow, single_inj_file, inj_dir,
                    tags=["split_inspinj", inj_tag])
            for inj_split in inj_splits:
                split_str = [s for s in
                             inj_split.tagged_description.split("_")
                             if ("SPLIT" in s and s[-1].isdigit())]
                if len(split_str) != 0:
                    # Rewrite SPLIT<n> as SPLIT_<n> in the description
                    new = inj_split.tagged_description.replace(
                            split_str[0],
                            "SPLIT_%s" % split_str[0].replace("SPLIT", ""))
                    inj_split.tagged_description = new
            inj_split_files.extend(inj_splits)
        all_files.extend(inj_split_files)
        injs = inj_split_files

    # Generate injection matched filter workflow
    inj_insp_files = _workflow.setup_matchedfltr_workflow(
            wflow, sciSegs, datafind_veto_files, inj_splitbank_files,
            inj_dir, injs, tags=[mf_tag + "_injections"])
    for inj_insp_file in inj_insp_files:
        split_str = [s for s in inj_insp_file.name.split("_")
                     if ("SPLIT" in s and s[-1].isdigit())]
        if len(split_str) != 0:
            # SPLIT<n> in the file name becomes a trailing _<n>
            num = split_str[0].replace("SPLIT", "_")
            inj_insp_file.tagged_description += num

    # Make cache files (needed for post-processing)
    # Two caches per injection tag: one listing the injection files and one
    # listing the corresponding matched-filter outputs
    for inj_tag in inj_tags:
        files = _workflow.FileList([inj for inj in injs
                                    if inj_tag in inj.tag_str])
        inj_cache = _workflow.File(ifos, "injections", sciSegs[ifo][0],
                                   extension="lcf", directory=inj_dir,
                                   tags=[inj_tag])
        inj_cache.add_pfn(inj_cache.cache_entry.path, site="local")
        inj_caches.append(inj_cache)
        inj_cache_entries = files.convert_to_lal_cache()
        # Context manager guarantees the cache file is flushed and closed
        with open(inj_cache.storage_path, "w") as cache_fp:
            inj_cache_entries.tofile(cache_fp)

        files = _workflow.FileList([insp for insp in inj_insp_files
                                    if inj_tag in insp.tag_str])
        inj_insp_cache = _workflow.File(ifos, "inspiral_injections",
                                        sciSegs[ifo][0], extension="lcf",
                                        directory=inj_dir, tags=[inj_tag])
        inj_insp_cache.add_pfn(inj_insp_cache.cache_entry.path, site="local")
        inj_insp_caches.append(inj_insp_cache)
        inj_insp_cache_entries = files.convert_to_lal_cache()
        with open(inj_insp_cache.storage_path, "w") as cache_fp:
            inj_insp_cache_entries.tofile(cache_fp)

    all_files.extend(inj_caches)
    all_files.extend(inj_insp_files)
    all_files.extend(inj_insp_caches)

# MAIN MATCHED FILTERING
# Matched-filter jobs without injections, run over the split template bank
insp_dir = os.path.join(curr_dir, "inspiral")
inspiral_files = _workflow.setup_matchedfltr_workflow(
        wflow, sciSegs,
        datafind_veto_files, splitbank_files, insp_dir,
        tags=[mf_tag + "_no_injections"])
all_files.extend(inspiral_files)
# TODO: Remove .lcf caches here?
inspiral_cache = _workflow.File(ifos, "inspiral", sciSegs[ifo][0],
                                extension="lcf", directory=insp_dir)
inspiral_cache.add_pfn(inspiral_cache.cache_entry.path, site="local")
all_files.append(inspiral_cache)
inspiral_cache_entries = inspiral_files.convert_to_lal_cache()
# Context manager guarantees the cache file is flushed and closed
with open(inspiral_cache.storage_path, "w") as cache_fp:
    inspiral_cache_entries.tofile(cache_fp)

# TODO: LONG TIME SLIDES

# POST-PROCESSING
pp_dir = os.path.join(curr_dir, "post_processing")
# exist_ok=True: the run directory is allowed to exist already, so a
# post_processing directory left by an earlier run must not abort the script
os.makedirs(pp_dir, exist_ok=True)
post_proc_method = wflow.cp.get_opt_tags("workflow-postproc",
                                         "postproc-method", tags)
pp_files = _workflow.FileList([])
results_files = _workflow.FileList([])
# Any postproc-method other than PYGRB_OFFLINE leaves both lists empty
if post_proc_method == "PYGRB_OFFLINE":
    trig_comb_files, clustered_files, inj_find_files =\
        _workflow.setup_pygrb_pp_workflow(wflow, pp_dir, seg_dir,
                                          sciSegs[ifo][0], full_bank_file,
                                          inspiral_files, injs,
                                          inj_insp_files, inj_tags)
    # The results stage is only set up when its configuration section exists
    sec_name = 'workflow-pygrb_pp_workflow'
    if not wflow.cp.has_section(sec_name):
        msg = 'No {0} section found in configuration file.'.format(sec_name)
        logging.info(msg)
    else:
        logging.info('Entering results module')
        results_files = _workflow.setup_pygrb_results_workflow(wflow, pp_dir,
                                                               clustered_files,
                                                               inj_find_files,
                                                               full_bank_file,
                                                               seg_dir)
        logging.info('Leaving results module')

all_files.extend(pp_files)
all_files.extend(results_files)

# COMPILE WORKFLOW AND WRITE DAX
wflow.save()
logging.info("Written dax.")
