#!/usr/bin/python3.13 -s

import numpy, argparse
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from pycbc import init_logging, add_common_pycbc_options
from pycbc.io import HFile
import logging

# Command line interface. The file has no module docstring, so an explicit
# description is given here rather than relying on __doc__ (which is None).
parser = argparse.ArgumentParser(
    description='Plot KDE values read from a template file (and optionally '
                'a signal file), either against a single parameter or as '
                'the colour of a two-parameter scatter plot.')
add_common_pycbc_options(parser)
parser.add_argument('--signal-file',
                    help='File containing a "data_kde" dataset for signals. '
                         'Needed when --which-kde is signal_kde or ratio_kde')
parser.add_argument('--template-file', required=True,
                    help='File containing a "data_kde" dataset for templates '
                         'and the datasets named by --param')
parser.add_argument('--param', nargs='+', required=True,
                    help='Specify one parameter name for a kde_vs_param plot, or '
                         'two parameter names for a param_vs_param plot. Param '
                         'names must exist as datasets in the input files')
parser.add_argument('--log-axis', nargs='+', choices=['True', 'False'], required=True,
                    help='For each parameter, specify True for a log axis and False '
                         'for a linear axis')
# Both of the following were optional, but every code path downstream needs
# them (a missing value ended in a KeyError or RuntimeError), so let argparse
# report the omission up front.
parser.add_argument('--plot-type', choices=['kde_vs_param', 'param_vs_param'],
                    required=True,
                    help='Type of plot to make')
parser.add_argument('--plot-order', choices=['file', 'increasing', 'decreasing'],
                    default='file',
                    help='Choose the order to plot KDE values in: "file", "increasing"'
                         ' or "decreasing"')
parser.add_argument('--which-kde', choices=['signal_kde', 'template_kde', 'ratio_kde'],
                    required=True,
                    help='Which KDE values to plot')
parser.add_argument('--plot-dir', required=True,
                    help='Directory in which to save the plot')
args = parser.parse_args()

init_logging(args.verbose)

# The number of parameter names must match the plot type: one for
# kde_vs_param, two for anything else. parser.error() exits the program.
n_params = len(args.param)
wants_single_param = args.plot_type == 'kde_vs_param'
if wants_single_param and n_params != 1:
    parser.error('For kde_vs_param, give exactly one parameter name')
if not wants_single_param and n_params != 2:
    parser.error('For param_vs_param, give exactly two parameter names')

# Every parameter needs its own log/linear axis flag
if len(args.log_axis) != n_params:
    parser.error('Must specify either log (True) or non-log (False) for each parameter')

# The signal and ratio KDEs cannot be formed without the signal file, so
# fail with a clear message before opening anything rather than hitting a
# NameError on an unbound signal_kde later on.
if args.which_kde in ('signal_kde', 'ratio_kde') and not args.signal_file:
    parser.error('--signal-file is required when --which-kde is '
                 'signal_kde or ratio_kde')

# Read everything needed into memory and close the files straight away
signal_kde = None
if args.signal_file:
    with HFile(args.signal_file, 'r') as signal_data:
        signal_kde = signal_data['data_kde'][:]
with HFile(args.template_file, 'r') as template_data:
    template_kde = template_data['data_kde'][:]
    # Parameter values are always taken from the template file
    param_arrays = [template_data[param][:] for param in args.param]

# Select the requested KDE values; the ratio is only computed when it is
# actually asked for.
if args.which_kde == 'signal_kde':
    kde_values = signal_kde
elif args.which_kde == 'template_kde':
    kde_values = template_kde
elif args.which_kde == 'ratio_kde':
    kde_values = signal_kde / template_kde
else:
    parser.error('--which-kde must be one of signal_kde, template_kde '
                 'or ratio_kde')

# Build the index array that puts the KDE values in the requested plotting
# order, then apply that same ordering to every parameter array.
if args.plot_order in ('increasing', 'decreasing'):
    idx = numpy.argsort(kde_values)
    if args.plot_order == 'decreasing':
        idx = idx[::-1]
else:
    # "file" order: keep the points as they were read
    idx = numpy.arange(len(kde_values))

param_arrays = [values[idx] for values in param_arrays]
kde_values = kde_values[idx]

# The second parameter only exists when two parameter names were given
param0 = param_arrays[0]
if len(param_arrays) > 1:
    param1 = param_arrays[1]
else:
    param1 = None

# Reject an unexpected plot type before creating any figure
if args.plot_type not in ('kde_vs_param', 'param_vs_param'):
    raise RuntimeError('Unknown plot type!', args.plot_type)

fig, ax = plt.subplots(1, figsize=(12, 7), constrained_layout=True)

if args.plot_type == 'kde_vs_param':
    # KDE value on the x axis (always logarithmic), parameter on the y axis
    im = ax.scatter(kde_values, param0, marker=".", c="r", s=5)
    x_label, y_label = args.which_kde, args.param[0]
    x_is_log = True
    y_is_log = args.log_axis[0] == 'True'
    # Join with an explicit separator so the file is written inside plot_dir
    plot_loc = f'{args.plot_dir}/{args.which_kde}_vs_{args.param[0]}.png'
else:
    # One parameter per axis, points coloured by KDE value on a log scale
    im = ax.scatter(param0, param1, marker=".", c=kde_values, cmap='turbo',
                    s=5, norm=LogNorm())
    cbar = fig.colorbar(im, ax=ax, pad=0.01)
    cbar.ax.set_ylabel(args.which_kde, rotation=270, fontsize=15, labelpad=15)
    cbar.ax.tick_params(labelsize=15)
    x_label, y_label = args.param[0], args.param[1]
    x_is_log = args.log_axis[0] == 'True'
    y_is_log = args.log_axis[1] == 'True'
    plot_loc = f'{args.plot_dir}/{args.which_kde}_{args.plot_order}_{args.param[0]}_vs_{args.param[1]}.png'

ax.set_xlabel(x_label, fontsize=15)
ax.set_ylabel(y_label, fontsize=15)
ax.set_xscale('log' if x_is_log else 'linear')
ax.set_yscale('log' if y_is_log else 'linear')
# Set the tick label size via tick_params: set_xticklabels()/set_yticklabels()
# take the label texts themselves, so passing a name string there would
# replace the numeric tick labels with its individual characters.
ax.tick_params(axis='both', which='both', labelsize=13)

fig.savefig(plot_loc)
