#!/usr/bin/python3.13 -s

"""
This script compares the sensitivities (VTs) of two searches having consistent
sets of injections. It reads two HDF files produced by pycbc_page_sensitivity's
--hdf-out option, and plots the ratios of their VTs at various IFARs.
"""

import sys
import argparse
import logging
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

from pycbc import init_logging, add_common_pycbc_options
from pycbc.results import save_fig_with_metadata
from pycbc.io.hdf import HFile

# Command-line interface. The two sets of VT files are mandatory: the rest of
# the script takes their lengths and opens them unconditionally, so a missing
# option is reported here as a usage error rather than failing later.
parser = argparse.ArgumentParser(description=__doc__)
add_common_pycbc_options(parser)
parser.add_argument('--vt-files-one', nargs='+', required=True,
                    help='HDF files containing VT curves, data for '
                         'the numerator (top) of the ratio')
parser.add_argument('--vt-files-two', nargs='+', required=True,
                    help='HDF files containing VT curves, data for '
                         'the denominator (bottom) of the ratio')
parser.add_argument('--desc-one', type=str, required=True,
                    help='Descriptor tag for first set of data '
                         '(short, for use in subscript)')
parser.add_argument('--desc-two', type=str, required=True,
                    help='Descriptor tag for second set of data '
                         '(short, for use in subscript)')
parser.add_argument('--outfile', type=str, required=True,
                    help='Output file to save to')
parser.add_argument('--ifars', type=float, required=True, nargs='+',
                    help='IFAR values to plot VT ratio for. Note that the '
                         'plotted values will be the closest values available '
                         'from the VT files')
parser.add_argument('--log-x', action='store_true',
                    help='Use logarithmic x-axis')
parser.add_argument('--log-y', action='store_true',
                    help='Use logarithmic y-axis')
args = parser.parse_args()

# args.verbose is presumably registered by add_common_pycbc_options
init_logging(args.verbose)

def _sum_vt_files(vt_files, keys, idxs, template):
    """Sum VT values and squared errors over a set of VT files.

    Parameters
    ----------
    vt_files : list of str
        Paths of the HDF files to read.
    keys : list of str
        Bin keys that every file's 'data' group must contain, in this order.
    idxs : list of int
        Positions along each VT curve at which values are extracted.
    template : numpy.ndarray
        Array whose shape and dtype are used for the accumulators.

    Returns
    -------
    vt, errsq_hi, errsq_low : dict
        For each key, the summed VT and the summed squares of the
        'errorhigh' and 'errorlow' values at the requested positions.

    Raises
    ------
    ValueError
        If a file's bin keys differ from ``keys``.
    """
    vt = {k: np.zeros_like(template) for k in keys}
    errsq_hi = {k: np.zeros_like(template) for k in keys}
    errsq_low = {k: np.zeros_like(template) for k in keys}

    for vt_file in vt_files:
        with HFile(vt_file, 'r') as f:
            # Check the input bins
            file_keys = list(f['data'].keys())
            if file_keys != keys:
                raise ValueError(
                    'keys do not match for the given input files - '
                    f'{keys} v {file_keys}')
            for k in keys:
                # Read each curve into memory before selecting points: numpy
                # indexing accepts unsorted or repeated indices, which
                # indexing the HDF dataset directly with a list may not.
                vt[k] += f['data'][k][:][idxs]
                # Variances add over files
                errsq_hi[k] += f['errorhigh'][k][:][idxs] ** 2.
                errsq_low[k] += f['errorlow'][k][:][idxs] ** 2.

    return vt, errsq_hi, errsq_low


# Warn user if different numbers of files in numerator vs denominator
if len(args.vt_files_one) != len(args.vt_files_two):
    logging.warning(
        'WATCH OUT! You gave different numbers of One and Two files!')


# Use the first numerator file as the reference for IFAR values and bins
with HFile(args.vt_files_one[0], 'r') as ftop_init:
    # Read the IFAR grid once rather than once per requested IFAR
    ifar_grid = ftop_init['xvals'][:]
    # Find the index closest to each given IFAR value
    idxs = [np.argmin(np.abs(ifar_grid - ifv)) for ifv in args.ifars]
    plot_ifars = ifar_grid[idxs]
    # Get binning keys for reference
    keys = list(ftop_init['data'].keys())

# Dicts holding total VT and variances for the numerator and denominator.
# The indices found in the reference file are applied to every file.
vt_top, vt_top_errsqhi, vt_top_errsqlow = _sum_vt_files(
    args.vt_files_one, keys, idxs, plot_ifars)
vt_bot, vt_bot_errsqhi, vt_bot_errsqlow = _sum_vt_files(
    args.vt_files_two, keys, idxs, plot_ifars)

# Make the plot pretty
plt.rc('axes.formatter', limits=[-3, 4])
plt.rc('figure', dpi=300)
fig_mi = plt.figure(figsize=(10, 4))
ax_mi = fig_mi.gca()

ax_mi.grid(True, zorder=1)

# Tick labels: the part of each bin key after '\in', re-opened with '$'
# (keys are presumably of the form '$param \in [low, high)$' - confirm
# against the files being read)
labels = ['$ ' + key.split('\\in')[-1] for key in keys]

# Name of the splitting parameter: the part of the first key before '\in'
x_param = r'$' + keys[0].split('\\in')[0].strip('$').strip() + r'$'

# x position of each bin: the number between '[' and the first ','
xpos = np.array([float(label.split('[')[1].split(',')[0])
                 for label in labels])

# Offset different IFARs by 1/20th of the mean distance between parameters.
# With a single bin there is no spacing to measure (np.diff returns an empty
# array whose mean is nan, and no exception is raised), so test the length
# explicitly and fall back to a fixed offset.
if len(xpos) > 1:
    if args.log_x:
        spacing = np.diff(np.log(xpos)).mean()
    else:
        spacing = np.diff(xpos).mean()
    xpos_add_dx = 0.05 * np.ones_like(xpos) * spacing
else:
    xpos_add_dx = 0.05

# Set the x ticks to be the positions given in the labels
plt.xticks(xpos, labels, rotation='horizontal')

# Colours are cycled through if there are more IFARs than entries
colors = ['#7b85d4', '#f37738', '#83c995', '#d7369e', '#c4c9d8', '#859795']

def _per_bin(per_key, position):
    """Collect, in bin-key order, the entry at ``position`` for every bin."""
    return np.array([per_key[key][position] for key in keys])


# Midpoint of the IFAR index range, so the offsets straddle each bin position
centre = float(len(args.ifars) - 1) / 2.

# One set of error-barred VT ratio points per requested IFAR
for j, ifar in enumerate(plot_ifars):
    top = _per_bin(vt_top, j)
    top_errsqhi = _per_bin(vt_top_errsqhi, j)
    top_errsqlow = _per_bin(vt_top_errsqlow, j)

    bot = _per_bin(vt_bot, j)
    bot_errsqhi = _per_bin(vt_bot_errsqhi, j)
    bot_errsqlow = _per_bin(vt_bot_errsqlow, j)

    ratio = top / bot
    # Fractional errors of numerator and denominator are added in
    # quadrature, then scaled by the ratio
    frac_sq_low = top_errsqlow / (top**2.) + bot_errsqlow / (bot**2.)
    frac_sq_hi = top_errsqhi / (top**2.) + bot_errsqhi / (bot**2.)
    ratio_err_low = frac_sq_low ** 0.5 * ratio
    ratio_err_hi = frac_sq_hi ** 0.5 * ratio

    # Shift this IFAR's points sideways; on a log axis the shift is applied
    # to log(x) so that it is multiplicative
    shift = xpos_add_dx * (j - centre)
    if args.log_x:
        x_points = np.exp(np.log(xpos) + shift)
    else:
        x_points = xpos + shift

    ax_mi.errorbar(
        x_points, ratio, yerr=[ratio_err_low, ratio_err_hi],
        fmt='o', markersize=7, linewidth=5,
        label='IFAR = %d yr' % ifar,
        capsize=5, capthick=2, mec='k',
        color=colors[j % len(colors)])

# Apply the requested axis scales, then restore the bin ticks and labels
if args.log_x:
    plt.xscale('log')
if args.log_y:
    plt.yscale('log')
plt.xticks(xpos, labels, rotation='horizontal')

# Get the limits of the x axis, and draw a black line at a ratio of one in
# order to highlight equal comparison
xlimits = plt.xlim()
plt.plot(xlimits, [1, 1], 'k', lw=2, zorder=0)
plt.xlim(xlimits)  # reassert the x limits so that the plot doesn't expand

# Legend sits above the axes, with one column per requested IFAR
ax_mi.legend(bbox_to_anchor=(0.5, 1.01), ncol=len(args.ifars),
             loc='lower center')
ax_mi.get_legend().get_title().set_fontsize('14')
ax_mi.get_legend().get_frame().set_alpha(0.7)
ax_mi.set_xlabel(x_param, size='large')
# The y label is a fraction of the two descriptor tags. Raw strings keep the
# backslashes literal, so there are no invalid '\m' escape sequences and no
# line continuation inside the literal.
y_label = (r'$\frac{VT(\mathrm{' + args.desc_one + r'})}'
           r'{VT(\mathrm{' + args.desc_two + r'})}$')
ax_mi.set_ylabel(y_label, size='large')
plt.tight_layout()

title = f'VT sensitivity comparison between {args.desc_one} and ' \
        f'{args.desc_two}'
save_fig_with_metadata(fig_mi, args.outfile, cmd=' '.join(sys.argv),
                       title=title)

plt.close()

