#!/usr/bin/python3.13 -s

# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import argparse
import itertools
import logging
import sys

import numpy

from matplotlib import use
use('agg')
from matplotlib import pyplot as plt

import pycbc
from pycbc import (distributions, results, waveform)
from pycbc.inference.option_utils import ParseParametersArg
from pycbc.distributions.utils import prior_from_config
from pycbc.workflow import WorkflowConfigParser
from pycbc.results.scatter_histograms import create_multidim_plot


def cartesian(arrays):
    """Return the cartesian product of a list of iterables as a numpy array.

    Each combination taking one item from every iterable in ``arrays`` is
    converted to its own numpy array; the combinations are then gathered,
    in the order ``itertools.product`` yields them, into the returned array.
    """
    combinations = map(numpy.array, itertools.product(*arrays))
    return numpy.array(list(combinations))

# command line usage: options for locating the prior configuration, choosing
# what to plot, and naming the output file
parser = argparse.ArgumentParser(
    usage="pycbc_inference_plot_prior [--options]",
    description="Plots prior distributions.")
# common PyCBC options; opts.verbose, read when logging is set up below, is
# expected to be added here
pycbc.add_common_pycbc_options(parser)
# add input options
parser.add_argument("--config-files", type=str, nargs="+", required=True,
                    help="A file parsable by "
                         "pycbc.workflow.WorkflowConfigParser.")
# ParseParametersArg is a custom argparse action; the script later reads
# opts.parameters_labels alongside opts.parameters, presumably set by it
parser.add_argument("--parameters", nargs="+", action=ParseParametersArg,
                    metavar="PARAM[:LABEL]",
                    help="Only plot the given parameters. May also provide a "
                         "label for each parameter. If provided, the "
                         "parameters can be any of the parameters in the "
                         "[variable_params] section, derived parameters from "
                         "them, or any function of them. Syntax for "
                         "functions is python; any math functions in the "
                         "numpy library may be used. Can optionally also "
                         "specify a LABEL for each parameter. If no LABEL is "
                         "provided, PARAM will be used as the LABEL. If "
                         "LABEL is the same as a parameter in "
                         "pycbc.waveform.parameters, the label "
                         "property of that parameter will be used. If not "
                         "provided, will plot all of the parameters in the "
                         "[variable_params] section of the config file.")
parser.add_argument("--prior-section", type=str, default="prior",
                    help="Name of section with distribution configurations. "
                         "Default is 'prior'.")
# add output options
parser.add_argument("--nsamples", type=int, default=10000,
                    help="The number of samples to draw from the prior for "
                         "plotting. Default is 10000.")
parser.add_argument("--output-file", type=str, required=True,
                    help="Path to output plot.")

# read the options given on the command line
opts = parser.parse_args()

# configure logging from the verbosity option
pycbc.init_logging(opts.verbose)

# load the configuration file(s)
logging.info("Reading configuration files")
config = WorkflowConfigParser(opts.config_files)

# construct the prior from the requested section of the configuration
logging.info("Loading prior")
prior = prior_from_config(config, prior_section=opts.prior_section)

# draw the requested number of samples to plot
logging.info("Drawing samples from prior")
samples = prior.rvs(opts.nsamples)

# decide which parameters to plot and how to label them
if opts.parameters is not None:
    # the user chose the parameters (and their labels) on the command line
    parameters = opts.parameters
    labels = opts.parameters_labels
else:
    # nothing requested: plot every variable argument of the prior
    parameters = prior.variable_args
    labels = {}
    for name in parameters:
        # use the label from the waveform module when one exists, otherwise
        # fall back to the bare parameter name
        try:
            label = getattr(waveform.parameters, name).label
        except AttributeError:
            label = name
        labels[name] = label

logging.info("Plotting")
# keyword options for the multi-dimensional plot: density and contours with
# marginal histograms, no scatter points
plot_settings = {
    "labels": labels,
    "plot_marginal": True,
    "marginal_percentiles": [5, 50, 95],
    "plot_scatter": False,
    "plot_density": True,
    "plot_contours": True,
    "contour_percentiles": [50, 90],
}
# only the figure is needed; the second returned value is not used
fig, _ = create_multidim_plot(parameters, samples, **plot_settings)

# set DPI
fig.set_dpi(200)

# save figure with meta-data
title = "Prior Distribution"
caption = (
    "This plot shows the probability density function (PDF) from the \n"
    "prior distributions."
)
results.save_fig_with_metadata(
    fig,
    opts.output_file,
    cmd=" ".join(sys.argv),
    title=title,
    caption=caption,
)
plt.close()

# exit
logging.info("Done")
