#!/usr/bin/python3.13 -s
# Copyright (C) 2023 Gareth S. Cabourn Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""
Plots various parameters against one
another for pycbc banks in an hdf file.
"""


import os
import sys
import h5py
import numpy as np
import argparse
import logging
from textwrap import wrap

import pycbc
from pycbc.results.plot import (add_style_opt_to_parser, set_style_from_cli)
from pycbc.io import FieldArray, HFile
from pycbc.inference import option_utils
from pycbc.tmpltbank import bank_conversions as bconv
from pycbc.pnutils import get_imr_duration
from pycbc.results.scatter_histograms import create_multidim_plot
from pycbc.results import metadata

# Parameter names offered by pycbc.tmpltbank.bank_conversions
conversion_options = bconv.conversion_options
# Parameter names which are read from the --fits-file rather than the bank
_fit_parameters = ['count_above_thresh', 'count_in_template',
                   'fit_coeff', 'median_sigma']
# Everything this script will accept beyond the datasets in the bank itself
parameter_options = conversion_options + _fit_parameters


# Build the command line interface. The module docstring doubles as the
# description shown by --help.
parser = argparse.ArgumentParser(usage='pycbc_plot_bank_corner [--options]',
                    description=__doc__)
pycbc.add_common_pycbc_options(parser)

# Input / output files
parser.add_argument("--bank-file",
    required=True,
    help="The bank file to read in and plot")
parser.add_argument("--output-plot-file",
    required=True,
    help="Name of the output file")

# Which parameters to show, each optionally with a display label
parser.add_argument("--parameters",
    nargs="+",
    default=[],
    action=option_utils.ParseParametersArg,
    metavar="PARAM[:LABEL]",
    help="Only plot the given parameters. May also provide a "
         "label for each parameter. Choose from a combination of " +
         ", ".join(parameter_options) + ", though some "
         "may not be buildable from bank parameters. Can optionally "
         "also specify a LABEL for each parameter. If no LABEL is "
         "provided, PARAM will used as the LABEL. If LABEL "
         "is the same as a parameter in "
         "pycbc.waveform.parameters, the label "
         "property of that parameter will be used. If not "
         "provided, will plot all of the parameters in the "
         "bank.")
parser.add_argument('--plot-histogram',
    action='store_true',
    help="Plot 1D histograms of parameters on the "
         "diagonal axes.")

# Axis range overrides, given as PARAM:VAL pairs
parser.add_argument('--mins',
    nargs='+',
    metavar='PARAM:VAL',
    default=[],
    help="Specify minimum parameter values to plot. This "
         "should be done by specifying the parameter name "
         "followed by the value. Parameter names must be "
         "the same as the PARAM argument in --parameters "
         "(or, if no parameters are provided, the same as "
         "the parameter name specified in the variable "
         "args in the input bank. If none provided, "
         "the smallest parameter value in the posterior "
         "will be used.")
parser.add_argument('--maxs',
    nargs='+',
    metavar='PARAM:VAL',
    default=[],
    help="Same as mins, but for the maximum values to "
         "plot.")

# Optional colouring of the scatter points
parser.add_argument("--color-parameter",
    nargs=1,
    action=option_utils.ParseParametersArg,
    metavar='PARAM[:LABEL]',
    help="Color scatter points according to the parameter given. "
         "May optionally provide a label in the same way as for "
         "--parameters. Default=No scatter point coloring.")
parser.add_argument('--dpi',
    type=int,
    default=200,
    help="Set the DPI of the plot. Default is 200.")
# NOTE: the explicit '+' before ', '.join(...) is essential. Without it
# the adjacent string literals are merged at compile time and the whole
# sentence becomes the *separator* passed to join, garbling the help text.
parser.add_argument('--fits-file',
    help="Provide a fits file to plot parameters from. Required if any of " +
         ', '.join(_fit_parameters) + " are given as parameters.")
parser.add_argument('--title',
    help="A title for the plot. If not given, the files supplied "
         "and the number of templates will be used")

# Plot style options are registered last, then the command line is parsed
add_style_opt_to_parser(parser)
args = parser.parse_args()

pycbc.init_logging(args.verbose)
set_style_from_cli(args)

# Axis limits requested through --mins / --maxs; both are looked up by
# parameter name further down when the plotting ranges are filled in
mins, maxs = option_utils.plot_ranges_from_cli(args)

logging.info("Reading in the bank")
bank = {}
# Stays None unless at least one numeric dataset is found in the bank
banklen = None
with HFile(args.bank_file, 'r') as bankf:
    for p in bankf.keys():
        dset = bankf[p]
        # Only datasets can be plotted, skip anything else (e.g. groups)
        if not isinstance(dset, h5py.Dataset):
            continue
        # Ignore things which aren't numbers, which cannot
        # be histogrammed - e.g. approximant
        if not np.issubdtype(dset.dtype, np.number):
            continue
        bank[p] = dset[:]
        banklen = dset.size

if banklen is None:
    # Fail here with a clear message rather than with a NameError on
    # banklen later in the script
    raise RuntimeError(
        "No numeric datasets were found in the bank file %s, so there "
        "is nothing to plot" % args.bank_file
    )

if args.fits_file is not None:
    # Add fit parameters to the bank object
    with HFile(args.fits_file, 'r') as fits_f:
        for p in _fit_parameters:
            fit_dset = fits_f[p]
            if fit_dset.size != banklen:
                raise RuntimeError(
                    "Fits parameter %s is not the same size (%d) as the "
                    "bank (%d), this looks like the fits file does not "
                    "correspond to this bank" % (p, fit_dset.size, banklen)
                )
            param = fit_dset[:].astype(float)
            # Non-positive values (including the '-1' sentinel) mean that
            # the fit is invalid: counts are zeroed, anything else is set
            # to NaN
            param[param <= 0] = 0 if 'count' in p and 'log' not in p else np.nan
            bank[p] = param

logging.info("Got %d templates from the bank", banklen)

# With no explicit --parameters, fall back to every bank entry that
# takes more than one distinct value; each is labelled with its own name
if not args.parameters:
    args.parameters = [key for key in bank
                       if len(np.unique(bank[key])) > 1]
    args.parameters_labels = {key: key for key in bank}

# Color parameter and its label; both stay False when no coloring is wanted
cpar = False
if args.color_parameter:
    cpar = args.color_parameter[0]
cpar_label = False
if cpar:
    cpar_label = args.color_parameter_labels[cpar]

# Collect every known option whose name appears (as a substring) in one
# of the plotted parameters or in the color parameter
required_params = []
for option in parameter_options:
    if any(option in par for par in args.parameters):
        required_params.append(option)
    if cpar and option in cpar:
        required_params.append(option)

# 'duration' is a substring of the '*_duration' options, so it can be
# picked up by accident. Drop it when every plotted parameter mentioning
# 'duration' actually mentions '_duration'
has_prefixed_duration = any(name.endswith('_duration')
                            for name in required_params)
if 'duration' in required_params and has_prefixed_duration:
    n_any_duration = sum('duration' in par for par in args.parameters)
    n_prefixed_duration = sum('_duration' in par for par in args.parameters)
    if n_any_duration == n_prefixed_duration:
        # It looks like duration has been added where it didnt need to be
        required_params.remove('duration')

# Plotting any duration means these keys must be available as well
duration_required_keys = ['mass1', 'mass2', 'spin1z',
                          'spin2z', 'f_lower']

if any('duration' in par for par in args.parameters):
    required_params.extend(duration_required_keys)

# Collapse any repeated entries
required_params = np.unique(required_params)

logging.info("Required parameters to get from bank: %s",
             ', '.join(required_params))

# Work out which required parameters are not yet present in the bank,
# then build each of them in turn
not_in_bank = [name for name in required_params if name not in bank]
for name in not_in_bank:
    if name not in parameter_options:
        raise KeyError(f"Parameter {name} not in bank, fits or conversion "
                       "options, choose from bank parameters or " +
                       ', '.join(parameter_options))
    needs_fits_file = name in _fit_parameters and args.fits_file is None
    if needs_fits_file:
        parser.error(f"Parameter {name} needs a --fits-file, but none is given")

    logging.info("Converting %s", name)
    # Request the property for every template index in the bank
    bank[name] = bconv.get_bank_property(name, bank, np.arange(banklen))

# Histogram outlines need to contrast with the chosen plot style
hist_color = 'white' if args.mpl_style == 'dark_background' else 'black'

# Every entry collected so far must have one value per template
assert all(len(column) == banklen for column in bank.values())

# Repackage the bank as a FieldArray, which is what the plotting code takes
bank_fa = FieldArray.from_arrays(bank.values(), names=list(bank))

# Every plotted parameter, plus the color parameter when there is one,
# needs a lower and an upper plotting limit
required_minmax = list(args.parameters)
if cpar:
    required_minmax.append(cpar)

for param in required_minmax:
    values = bank_fa[param]
    # NaNs are ignored, as is the infinity on the side being measured
    lowest = np.nanmin(values[values != -np.inf])
    highest = np.nanmax(values[values != np.inf])
    padding = 0.05 * (highest - lowest)

    # Limits given on the command line win; otherwise pad the data range
    # by 5% on each side
    if param not in mins:
        mins[param] = lowest - padding
    if param not in maxs:
        maxs[param] = highest + padding

# Values used to color the scatter points, None when not coloring
if cpar:
    zvals = bank_fa[cpar]
else:
    zvals = None

logging.info("Generating corner plot")

# Colorbar limits follow the color parameter's plotting range; they are
# passed as zero when no coloring was requested
colorbar_min = mins[cpar] if cpar else 0
colorbar_max = maxs[cpar] if cpar else 0

plot_options = dict(
    labels=args.parameters_labels,
    plot_marginal=args.plot_histogram,
    plot_scatter=True,
    plot_contours=False,
    scatter_cmap="viridis",
    marginal_title=False,
    marginal_percentiles=[],
    fill_color='g',
    zvals=zvals,
    show_colorbar=cpar is not False,
    cbar_label=cpar_label,
    vmin=colorbar_min,
    vmax=colorbar_max,
    hist_color=hist_color,
    mins=mins,
    maxs=maxs,
)
fig, axis_dict = create_multidim_plot(args.parameters, bank_fa,
                                      **plot_options)

# Use the title given on the command line, or else build one from the
# input file name(s) and the number of templates
if args.title is not None:
    figure_title = args.title
else:
    figure_title = os.path.basename(args.bank_file)
    if args.fits_file is not None:
        figure_title += f", {os.path.basename(args.fits_file)}"
    figure_title += f" - {banklen}\u00a0templates"
fig.suptitle('\n'.join(wrap(figure_title, 60)))

for axis_entry in axis_dict.values():
    # The axes object is the first item of each entry
    ax = axis_entry[0]
    # Some may be long labels - tilt the label a little to fit together
    ax.set_xlabel(ax.get_xlabel(), rotation=10)
    ax.set_ylabel(ax.get_ylabel(), rotation=80)
    ax.grid(zorder=-30)

# Link the axes after creation so that the off-diagonal panels in each
# row share a y axis and those in each column share an x axis.
# Each axis_dict value is indexed as (axes, row, column) below.
for idx in range(len(args.parameters)):
    # Off-diagonal axes on row idx
    row_axes = [entry[0] for entry in axis_dict.values()
                if entry[1] == idx and entry[2] != idx]
    # Chain each axes to the next one in the list, so that sharey is
    # never called twice on the same axes
    for this_ax, next_ax in zip(row_axes, row_axes[1:]):
        this_ax.sharey(next_ax)

    # Off-diagonal axes on column idx
    column_axes = [entry[0] for entry in axis_dict.values()
                   if entry[2] == idx and entry[1] != idx]
    # Same chaining as above, this time for the x axis
    for this_ax, next_ax in zip(column_axes, column_axes[1:]):
        this_ax.sharex(next_ax)

logging.info("Plot generated")
fig.set_dpi(args.dpi)

# Write the figure out, recording the command line, a title and a caption
# alongside it
caption = "Template bank as a corner plot with scatter points for each waveform."
command_line = " ".join(sys.argv)

metadata.save_fig_with_metadata(
    fig,
    args.output_plot_file,
    cmd=command_line,
    title="Template bank corner plot",
    caption=caption,
    fig_kwds=dict(bbox_inches='tight'),
)

# All finished
logging.info("Done")
