#!/usr/bin/python3.13 -s

# Copyright (C) 2017 Christopher M. Biwer, Alexander Harvey Nitz, Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Creates a DAX for a parameter estimation injection study.
"""

import argparse
import logging
import os
import shlex
import numpy
import socket
import sys

from pycbc import results, init_logging, add_common_pycbc_options
from pycbc.results import layout
from pycbc.results import metadata
from pycbc.workflow import configuration
from pycbc.workflow import core
from pycbc.workflow.jobsetup import (PycbcCreateInjectionsExecutable,
                                     PycbcInferenceExecutable)
from pycbc.workflow import inference_followups as inffu
from pycbc.workflow import plotting
from pycbc.workflow import versioning
from pycbc.inject import InjectionSet
from pycbc.io import FieldArray


def config_from_config(cp, section, skip_opts=None):
    """Build a workflow config parser from one section of another parser.

    The options stored in ``section`` are converted to a command-line string
    with ``cp.section_to_cli``. That string is parsed by a throw-away
    argument parser that only carries the workflow configuration options,
    and the parsed result is passed to ``WorkflowConfigParser.from_cli``.

    Parameters
    ----------
    cp : config parser
        Parser providing the ``section_to_cli`` method.
    section : str
        Name of the section to convert.
    skip_opts : list of str, optional
        Forwarded unchanged to ``cp.section_to_cli``.

    Returns
    -------
    The object returned by ``configuration.WorkflowConfigParser.from_cli``.
    """
    # turn the section into the equivalent command-line string
    cli_string = cp.section_to_cli(section, skip_opts=skip_opts)
    # scratch parser that only understands the config file(s)/override options
    scratch_parser = argparse.ArgumentParser()
    configuration.add_workflow_command_line_group(scratch_parser)
    parsed = scratch_parser.parse_args(shlex.split(cli_string))
    return configuration.WorkflowConfigParser.from_cli(parsed)


def read_inference_settings_from_config(cp, section):
    """Return the inference config parser and the number of runs per event.

    Parameters
    ----------
    cp : config parser
        Parser holding ``section``.
    section : str
        Section describing the inference configuration. An optional
        ``nruns`` option in it gives the number of inference runs.

    Returns
    -------
    inference_cp
        Result of ``config_from_config`` for ``section``, with ``nruns``
        passed as an option to skip.
    nruns : int
        Value of the ``nruns`` option, or 1 if the option is absent.
    """
    has_nruns = cp.has_option(section, 'nruns')
    nruns = int(cp.get(section, 'nruns')) if has_nruns else 1
    inference_cp = config_from_config(cp, section, skip_opts=['nruns'])
    return inference_cp, nruns
    

def symlink_path(f, path):
    """Create a symlink to a file inside the directory ``path``.

    This is a best-effort operation: a failure to create the link is logged
    but never raised, so a missing link cannot abort workflow generation.

    Parameters
    ----------
    f : object or None
        Object with a ``storage_path`` attribute (the link target) and a
        ``name`` attribute (the file name given to the link). If None,
        nothing is done.
    path : str
        Directory in which the link is created.

    Returns
    -------
    None
    """
    if f is None:
        return
    link_name = os.path.join(path, f.name)
    try:
        os.symlink(f.storage_path, link_name)
    except FileExistsError:
        # a link (or file) of that name is already there, e.g. because the
        # script was run before in the same directory; nothing to do
        logging.debug("Not creating symlink %s: it already exists",
                      link_name)
    except OSError as err:
        # keep going, but do not hide real problems such as a missing
        # directory or insufficient permissions
        logging.warning("Could not create symlink %s -> %s: %s",
                        link_name, f.storage_path, err)


# build the command-line interface; the description is the module docstring
# without its leading space
parser = argparse.ArgumentParser(description=__doc__[1:])
add_common_pycbc_options(parser)

# source of the injections: exactly one of "create this many" or "use the
# injections in these file(s)" has to be given
injection_source = parser.add_mutually_exclusive_group(required=True)
injection_source.add_argument(
    "--num-injections",
    type=int,
    help="The number of injections to create.",
)
injection_source.add_argument(
    "--injection-file",
    type=str,
    nargs="+",
    help="Analyze injections in the given file(s) instead of "
         "creating them.",
)

# options for the configuration file(s)
configuration.add_workflow_command_line_group(parser)
# options for the workflow itself
core.add_workflow_settings_cli(parser, include_subdax_opts=True)

parser.add_argument(
    "--seed",
    type=int,
    default=0,
    help="Starting to seed to use. This will be incremented "
         "one for each injection analyzed. Default is 0.",
)

# parse the command line
opts = parser.parse_args()

# set up logging before any work is done
init_logging(opts.verbose, default_level=1)

# names of the files and sub-directories used inside the output directory:
# per-injection config file name template and the directory it is saved in
config_file_tmplt = 'inference-{}.ini'
config_file_dir = 'config_files'
# sampler output
samples_file_dir = 'samples_files'
# posterior files
posterior_file_dir = 'posterior_files'
# injection files
injection_file_dir = 'injection_files'

# default the output directory to "<workflow name>_output"
if opts.output_dir is None:
    opts.output_dir = opts.workflow_name + '_output'

# create the output directory and its sub-directories
core.makedir(opts.output_dir)
for subdir in (config_file_dir, posterior_file_dir, injection_file_dir):
    core.makedir('{}/{}'.format(opts.output_dir, subdir))

# the main workflow, and the workflow that is run after everything else
workflow = core.Workflow(opts, name=opts.workflow_name)
finalize_workflow = core.Workflow(opts, name="finalization")

# inference configuration and number of inference runs per injection
inference_cp, nruns = read_inference_settings_from_config(
    workflow.cp, 'workflow-inference')

# remember where we started, then work from inside the output directory
origdir = os.path.abspath(os.curdir)
os.chdir(opts.output_dir)

# If injection file(s) were provided, write every injection they contain to
# its own file in the injections directory; otherwise only record how many
# injections have to be created later.
if opts.injection_file:
    injection_files = core.FileList([])
    n_injections = 0
    for injection_file in opts.injection_file:
        # The working directory has been changed to the output directory, so
        # resolve the path against the directory the script was started in.
        # os.path.join returns an absolute injection_file unchanged, whereas
        # plain string formatting would prepend origdir to it.
        injections = InjectionSet(os.path.join(origdir, injection_file))
        table = injections.table
        static_names = injections._injhandler.static_args
        # static arguments are taken from the first row of the table
        st_args = {k: table[k][0] for k in static_names}
        # every other field is written per injection
        write_params = [k for k in table.fieldnames
                        if k not in static_names]
        for ii in range(len(table)):
            samples = {f: table[ii][f] for f in write_params}
            # files are numbered consecutively across all input files
            outfile = '{}/injection{}.hdf'.format(injection_file_dir,
                                                  n_injections)
            injections.write(
                outfile, FieldArray.from_kwargs(**samples),
                write_params=write_params,
                static_args=st_args)
            injection_files.append(core.resolve_url_to_file(outfile))
            n_injections += 1
else:
    n_injections = opts.num_injections

# diagnostic plots requested for this workflow
diagnostics = inffu.get_diagnostic_plots(workflow)

# a percentile-percentile (pp) test is done if its summary-table executable
# is listed in the [executables] section
do_pp_test = workflow.cp.has_option('executables', 'pp_table_summary')
# results-page sections contributed by the pp test (none by default)
pp_dir = []
if do_pp_test:
    # parameters to test, each wrapped in single quotes
    raw_pp_params = workflow.cp.get('workflow-pp_test', 'pp-params')
    pp_params = ["'"+s+"'" for s in shlex.split(raw_pp_params)]
    # optional injection map
    inj_samples_map = (
        workflow.cp.get('workflow-pp_test', 'injection-samples-map')
        if workflow.cp.has_option('workflow-pp_test', 'injection-samples-map')
        else None)
    pp_section = 'percentile-percentile_test'
    pp_dir = [pp_section]

# sections of the output HTML pages, in the order they are listed here
section_names = (pp_dir
                 + ["detector_sensitivity", "priors", "posteriors"]
                 + diagnostics
                 + ["config_files", "workflow"])
rdir = layout.SectionNumber("results", section_names)

# make the results directories needed by this script
core.makedir(rdir.base)
for section in ("workflow", "config_files"):
    core.makedir(rdir[section])

# files holding the log of the workflow generation: plain text and HTML
log_file_txt = core.File(
    workflow.ifos, "workflow-log", workflow.analysis_time,
    extension=".txt", directory=rdir["workflow"])
log_file_html = core.File(
    workflow.ifos, "WORKFLOW-LOG", workflow.analysis_time,
    extension=".html", directory=rdir["workflow"])

# from here on, also send the log to the text file
init_logging(opts.verbose, default_level=1, to_file=log_file_txt.storage_path)
logging.info("Created log file %s", log_file_txt.storage_path)

# Build one sub-workflow per injection. Each one saves its own inference
# config file, creates the injection unless injection files were given,
# runs the sampler nruns times, and makes the posterior, diagnostic and PSD
# plots with their results-page entries.
posterior_files = core.FileList([])
# running seed; incremented every time a seed is handed out below
seed = opts.seed
sub_workflows = []
# Number of digits used to zero-pad injection numbers. len(str(n)) is used
# rather than ceil(log10(n)) because the latter is one digit short when n is
# an exact power of ten (e.g. n = 10 would give a width of 1).
zpad = len(str(n_injections))
# loop over the injections to be analyzed
for num_inj in range(n_injections):
    label = 'Injection {}'.format(str(num_inj+1).zfill(zpad))
    event = label.lower().replace(' ', '_')

    # create a sub workflow for this event
    # we need to go back to the original directory to do this for all the file
    # references to work correctly
    os.chdir(origdir)
    sub_workflow = core.Workflow(opts, name=event)
    # now go back to the output
    os.chdir(opts.output_dir)

    # if the config file has a fake-strain-seed set, increment to get
    # independent noise realizations: every detector of every injection is
    # given its own seed
    if inference_cp.has_option('data', 'fake-strain-seed'):
        # get the detectors
        detectors = shlex.split(inference_cp.get('data', 'instruments'))
        fake_strain_seeds = {}
        for det in detectors:
            fake_strain_seeds[det] = seed
            seed = seed + 1
        # written as space-separated "detector:seed" pairs
        inference_cp.set('data', 'fake-strain-seed',
                         ' '.join(['{}:{}'.format(d, s)
                                   for d, s in fake_strain_seeds.items()]))

    # write the configuration file to the config files directory
    config_file = sub_workflow.save_config(config_file_tmplt.format(event),
                                           config_file_dir, inference_cp)[0]

    # create sym links to config file for results page
    base = "config_files/{}".format(event)
    layout.single_layout(rdir[base], [config_file])
    symlink_path(config_file, rdir[base])

    # use the pre-split injection file if injection files were provided,
    # otherwise add a node that creates a single injection
    if opts.injection_file:
        injection_file = injection_files[num_inj]
    else:
        # construct Executable for creating injections
        create_injections_exe = PycbcCreateInjectionsExecutable(
            sub_workflow.cp, "create_injections", ifos=sub_workflow.ifos,
            out_dir=injection_file_dir)
        node, injection_file = create_injections_exe.create_node(
            [config_file], seed=seed, tags=opts.tags+[event])
        node.add_opt("--ninjections", 1)
        sub_workflow += node
        seed = seed + 1

    # add the injection file to the inference config file: re-read the file
    # that was just saved, set the option, and write it back in place
    cp = configuration.WorkflowConfigParser([config_file.storage_path])
    cp.set('data', 'injection-file', injection_file.name)
    with open(config_file.storage_path, "w") as fp:
        cp.write(fp)

    # make node(s) for running sampler
    samples_files = []
    inference_exe = PycbcInferenceExecutable(sub_workflow.cp, "inference",
                                             ifos=sub_workflow.ifos,
                                             out_dir=samples_file_dir)
    for nn in range(nruns):
        tags = opts.tags + [event]
        # only tag with the run number when there is more than one run
        if nruns > 1:
            tags.append(str(nn))
        node, samples_file = inference_exe.create_node(
            config_file, seed=seed, tags=tags,
            analysis_time=sub_workflow.analysis_time)
        # declare the injection file as a needed input file
        node.add_input(injection_file)
        # add node to workflow
        sub_workflow += node
        samples_files.append(samples_file)
        # increment the seed
        seed = seed + 1

    # create the posterior file and plots
    posterior_file, summary_files, _, _ = inffu.make_posterior_workflow(
        sub_workflow, samples_files, config_file, event, rdir,
        posterior_file_dir=posterior_file_dir, tags=opts.tags)
    posterior_files.append(posterior_file)

    # create the diagnostic plots
    _ = inffu.make_diagnostic_plots(sub_workflow, diagnostics, samples_files,
                                    event, rdir, tags=opts.tags)

    # files for detector_sensitivity summary subsection; the PSD plot is
    # made from the first samples file only
    base = "detector_sensitivity"
    psd_plot = plotting.make_spectrum_plot(
        sub_workflow, [samples_files[0]], rdir[base],
        tags=opts.tags+[event],
        hdf_group="data")

    # Suffix identifying this injection's entries on the results pages. It
    # is padded by the number of injections; the number of samples files,
    # used previously, is unrelated to the range of num_inj.
    # NOTE(review): presumably the suffix keeps the per-injection entries
    # distinct and ordered - confirm in pycbc.results.layout.
    unique = str(num_inj).zfill(zpad)

    # build the summary page
    layout.two_column_layout(rdir.base, summary_files,
                             unique=unique,
                             title=label, collapse=True)

    # build the psd page
    layout.single_layout(rdir['detector_sensitivity'], [psd_plot],
                         unique=unique,
                         title=label, collapse=True)

    # add the sub workflow to the main workflow
    workflow += sub_workflow
    sub_workflows.append(sub_workflow)

# percentile-percentile (PP) test: summary table, then one PP plot and one
# injection recovery plot per parameter
if do_pp_test:
    # the sub workflow has to be created from the original directory for all
    # the file references to work correctly; go back to the output afterwards
    os.chdir(origdir)
    pp_workflow = core.Workflow(opts, name='pp_test')
    os.chdir(opts.output_dir)

    # summary table over all parameters; it gets a row of its own
    pp_table = inffu.make_inference_pp_table(
        pp_workflow, posterior_files, rdir[pp_section],
        parameters=pp_params, injection_samples_map=inj_samples_map,
        analysis_seg=workflow.analysis_time, tags=opts.tags)[0]
    pp_files = [(pp_table,)]

    # one (pp plot, injection recovery plot) row per parameter
    for pi, param in enumerate(pp_params):
        param_tags = opts.tags + ['PARAM_{}'.format(pi)]
        pp_plot = inffu.make_inference_pp_plot(
            pp_workflow, posterior_files, rdir[pp_section],
            parameters=param, injection_samples_map=inj_samples_map,
            analysis_seg=workflow.analysis_time, tags=param_tags)[0]
        injrec_plot = inffu.make_inference_inj_recovery_plot(
            pp_workflow, posterior_files, rdir[pp_section], param,
            injection_samples_map=inj_samples_map,
            analysis_seg=workflow.analysis_time, tags=param_tags)[0]
        pp_files.append((pp_plot, injrec_plot))

    # add to the results page
    layout.two_column_layout(rdir[pp_section], pp_files)
    # add to the main workflow
    workflow += pp_workflow

# versioning information page
versioning.make_versioning_page(workflow, workflow.cp,
                                rdir['workflow/version'])

# node that makes the HTML pages; it belongs to the finalization workflow
results_page_dir = os.path.join(os.getcwd(), rdir.base)
plotting.make_results_web_page(finalize_workflow, results_page_dir)

# add the finalization workflow and make it depend on every other sub
# workflow: the per-injection ones first, then the pp test if there is one
workflow += finalize_workflow
upstream_workflows = list(sub_workflows)
if do_pp_test:
    upstream_workflows.append(pp_workflow)
for upstream in upstream_workflows:
    workflow.add_subworkflow_dependancy(upstream, finalize_workflow)

# write the dax
workflow.save()

# save the workflow configuration file and list it on the results page
wf_config_dir = rdir["workflow/configuration"]
core.makedir(wf_config_dir)
wf_ini = workflow.save_config("workflow.ini", wf_config_dir, workflow.cp)
layout.single_layout(wf_config_dir, wf_ini)

# shut logging down so the text log is complete, then copy its contents
# into the HTML log file shown on the results page
logging.shutdown()
with open(log_file_txt.storage_path, "r") as log_file:
    log_data = log_file.read()

# values substituted into the HTML snippet, in order of appearance
log_values = (os.getcwd(), opts.workflow_name, socket.gethostname(),
              log_data)
log_str = """
<p>Workflow generation script created workflow in output directory: %s</p>
<p>Workflow name is: %s</p>
<p>Workflow generation script run on host: %s</p>
<pre>%s</pre>
""" % log_values

# metadata stored with the HTML file
kwds = dict(
    title="Workflow Generation Log",
    caption="Log of the workflow script %s" % sys.argv[0],
    cmd=" ".join(sys.argv),
)
results.save_fig_with_metadata(log_str, log_file_html.storage_path, **kwds)
layout.single_layout(rdir["workflow"], [log_file_html])
