#!/usr/bin/python3.13 -s
# Copyright (C) 2018--2019 Gareth S. Cabourn Davies
# Copyright (C) 2020 Josh Willis
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""
Plot ratios of VTs calculated at various IFARs using pycbc_page_sensitivity
(run with --hdf-out option) for two comparable analyses
"""

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import argparse
from matplotlib.pyplot import cm
from math import ceil

from pycbc import add_common_pycbc_options, init_logging
from pycbc.io.hdf import HFile

# Command-line interface: the module docstring doubles as the --help text.
parser = argparse.ArgumentParser(description=__doc__)
add_common_pycbc_options(parser)

# The two input files are declared identically apart from their ordinal.
for flag, ordinal in (('--vt-file-one', 'first'),
                      ('--vt-file-two', 'second')):
    parser.add_argument(
        flag,
        required=True,
        help=f"HDF file containing VT curves, {ordinal} set of data"
             " for comparison",
    )

parser.add_argument(
    '--outfile',
    type=str,
    required=True,
    help='Output file to save to',
)
parser.add_argument(
    '--plot-title',
    type=str,
    required=True,
    help='(Possibly) quoted string to be title of plot',
)

# One boolean switch per axis to request logarithmic scaling.
for axis in ('x', 'y'):
    parser.add_argument(
        f'--log-{axis}',
        action='store_true',
        default=False,
        help=f'Use logarithmic {axis}-axis',
    )

args = parser.parse_args()

init_logging(args.verbose)

# Load in the two datasets. The files are left open here because the
# per-bin curves and their errors are read from them further down.
f1 = HFile(args.vt_file_one)
f2 = HFile(args.vt_file_two)

# A point-by-point ratio only makes sense if both files were evaluated on
# exactly the same x values.
x1 = f1['xvals'][:]
x2 = f2['xvals'][:]
if not np.array_equal(x1, x2):
    raise RuntimeError("IFAR values are not the same between the two files")

xvals = x1

# Sanitise the input so that the files have the same binning parameter and
# bins. The names are copied into a plain list so that they do not depend on
# the HDF file object, and the check is a real exception rather than an
# ``assert`` so that it is still enforced when Python runs with -O.
keys = list(f1['data'].keys())
if set(keys) != set(f2['data'].keys()):
    raise RuntimeError("keys do not match for the two files")

# Without any bins there is nothing to plot; fail with a clear message
# instead of an obscure error from the plotting code.
if not keys:
    raise RuntimeError("No bins found in the 'data' group of the input files")

nkeys = len(keys)
# Lay the bins out on a two-column grid with enough rows to hold them all.
# squeeze=False keeps ``axs`` two-dimensional even when there is only one
# row, so the [row, column] indexing below works for any number of bins.
nrows = int(ceil(nkeys/2.0))
fig, axs = plt.subplots(nrows, 2, sharex=True, sharey=True, squeeze=False)

stitle = fig.suptitle(args.plot_title)

# Copy the bin names into a plain list now: they are needed for the legend
# after the input files have been closed.
bin_names = list(keys)
colors = cm.rainbow(np.linspace(0, 1, nkeys))
alpha = 0.6

# loop through each chirp mass bin:
lines = []

for n, (key, c) in enumerate(zip(bin_names, colors)):
    data1 = f1['data'][key][:]
    data2 = f2['data'][key][:]

    errhigh1 = f1['errorhigh'][key][:]
    errlow1 = f1['errorlow'][key][:]

    errhigh2 = f2['errorhigh'][key][:]
    errlow2 = f2['errorlow'][key][:]

    # Ratio of the first file's values to the second's. The relative errors
    # of numerator and denominator are added in quadrature and scaled by the
    # ratio; lower and upper errors are propagated separately.
    ys = np.divide(data1, data2)
    yerr_errlow = np.multiply(np.sqrt(np.divide(errlow1, data1)**2 +
                                      np.divide(errlow2, data2)**2), ys)
    yerr_errhigh = np.multiply(np.sqrt(np.divide(errhigh1, data1)**2 +
                                       np.divide(errhigh2, data2)**2), ys)

    # Fill the grid column by column: the first ``nrows`` bins go down the
    # left-hand column, the remainder down the right-hand one.
    col, row = divmod(n, nrows)
    ax = axs[row, col]
    ax.axhline(y=1.0, alpha=0.5, ls='--', color='k')
    ax.plot(xvals, ys, c='k')
    l, = ax.plot(xvals, ys, c=c)
    lines.append(l)
    ax.fill_between(xvals, ys-yerr_errlow, ys+yerr_errhigh,
                    facecolor=c, edgecolor=c, alpha=alpha)
    if args.log_x:
        ax.set_xscale('log')

    if args.log_y:
        ax.set_yscale('log')

    print("Bin {0}: minimum ratio {1}, maximum ratio {2}".format(key,
                                                         ys.min(), ys.max()))

# Everything needed has been read into memory, so release the input files.
f1.close()
f2.close()

# Label every axes, then strip the labels from the inner ones so that only
# the outer edge of the grid carries them.
for ax in axs.flat:
    ax.set(xlabel='Inverse False Alarm Rate (years)',
           ylabel='VT ratio')
    ax.label_outer()

# Positioning the legend seems to be a nightmare
# First, grab the center, rightmost of the axes. The row is derived from the
# grid size so that this also works when there is only a single row.
ax = axs[nrows // 2, -1]
lgd = ax.legend(handles=lines, labels=bin_names, loc="center left",
                borderaxespad=0.1, title="Chirp mass bins",
                bbox_to_anchor=(1.05, 0.5))

# The legend and the title lie outside the axes, so they are passed
# explicitly to be included in the tight bounding box.
fig.savefig(args.outfile, bbox_extra_artists=(lgd, stitle),
            bbox_inches='tight')
plt.close(fig)
