#!/usr/bin/python3.13 -s
""" plot a an hdf bank file based on background binning
"""
import sys
import argparse
import h5py
import numpy
import matplotlib
matplotlib.use('Agg')
import pylab
import inspect
from itertools import cycle

import pycbc.events, pycbc.pnutils, pycbc.conversions, pycbc.results


class H5BankFile(h5py.File):
    """Convenience class for getting CBC parameters out of an HDF5 bank.

    Every method whose name ends in ``_param`` returns one template
    parameter, either read directly from a dataset of the bank or derived
    from the mass and spin datasets.

    NOTE: the docstring of each ``*_param`` method is used verbatim as the
    axis label for that parameter by the plotting code in this file, so
    those docstrings are deliberately kept to a single short label.
    """

    # Lower frequency cutoff used by the tau0 / tau3 parameters. Assign a
    # value on the instance to override the bank's own 'f_lower' dataset;
    # while this is None the dataset is read instead.
    f_lower = None

    @classmethod
    def get_param_names(cls):
        """Return the sorted names of the CBC parameters the class provides.

        A parameter ``x`` is available when the class has a member called
        ``x_param``.
        """
        suffix = '_param'
        # Strip only the trailing suffix rather than every occurrence.
        return sorted(name[:-len(suffix)]
                      for name, _ in inspect.getmembers(cls)
                      if name.endswith(suffix))

    def __len__(self):
        """Return the length of the bank's 'mass1' dataset."""
        return len(self['mass1'])

    def _get_f_lower(self):
        """Return the lower frequency cutoff for the tau0 / tau3 estimates.

        Uses the ``f_lower`` attribute when one has been assigned, otherwise
        falls back to the bank's 'f_lower' dataset.
        """
        if self.f_lower is not None:
            return self.f_lower
        return self['f_lower'][:]

    def _tau0_tau3(self):
        """Return the result of the tau0 / tau3 conversion for the bank."""
        return pycbc.pnutils.mass1_mass2_to_tau0_tau3(
            self.mass1_param(), self.mass2_param(), self._get_f_lower())

    # Raw strings are used for the labels below so that LaTeX commands such
    # as \odot and \tau are not treated as (invalid) escape sequences.

    def mass1_param(self):
        r'Mass 1 $M_\odot$'
        return self['mass1'][:]

    def mass2_param(self):
        r'Mass 2 $M_\odot$'
        return self['mass2'][:]

    def spin1z_param(self):
        'Spin 1 z-component'
        return self['spin1z'][:]

    def spin2z_param(self):
        'Spin 2 z-component'
        return self['spin2z'][:]

    def chirp_mass_param(self):
        r'Chirp mass $M_\odot$'
        # The conversion returns (mchirp, eta); only the chirp mass is kept.
        return pycbc.pnutils.mass1_mass2_to_mchirp_eta(
            self.mass1_param(), self.mass2_param())[0]

    def total_mass_param(self):
        r'Total mass $M_\odot$'
        return self.mass1_param() + self.mass2_param()

    def mass_ratio_param(self):
        'Mass ratio'
        return self.mass1_param() / self.mass2_param()

    def eta_param(self):
        'Symmetric mass ratio'
        return pycbc.conversions.eta_from_mass1_mass2(
            self.mass1_param(), self.mass2_param())

    def effective_spin_param(self):
        'Effective spin z-component'
        # Mass-weighted average of the two z-component spins.
        mass1 = self.mass1_param()
        mass2 = self.mass2_param()
        weighted_spin = (self.spin1z_param() * mass1
                         + self.spin2z_param() * mass2)
        return weighted_spin / (mass1 + mass2)

    def tau0_param(self):
        r'$\tau_0$'
        return self._tau0_tau3()[0]

    def tau3_param(self):
        r'$\tau_3$'
        return self._tau0_tau3()[1]


# ---------------------------------------------------------------------------
# Command line
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
pycbc.add_common_pycbc_options(parser)
parser.add_argument('--bank-file', help='hdf format template bank file',
                    required=True)
parser.add_argument('--background-bins', nargs='+',
                    help='list of background bin format strings')
parser.add_argument('--f-lower', type=float,
                    help="Lower frequency cutoff for evaluating template "
                         "duration. Should be equal to the lower cutoff "
                         "used in inspiral jobs")
parser.add_argument('--output-file', help='output file', required=True)
parser.add_argument('--x-var', type=str, choices=H5BankFile.get_param_names(),
                    default='mass1',
                    help='Template parameter to plot on the x-axis')
parser.add_argument('--y-var', type=str, choices=H5BankFile.get_param_names(),
                    default='mass2',
                    help='Template parameter to plot on the y-axis')
parser.add_argument('--log-x', action='store_true',
                    help='Make x-axis logarithmic')
parser.add_argument('--log-y', action='store_true',
                    help='Make y-axis logarithmic')
args = parser.parse_args()

pycbc.init_logging(args.verbose)

# ---------------------------------------------------------------------------
# Read the bank
# ---------------------------------------------------------------------------
# Everything needed from the bank is copied into memory inside the ``with``
# block, so the HDF5 file is closed again before any plotting happens.
with H5BankFile(args.bank_file, 'r') as bank:
    # A cutoff given on the command line takes precedence over the values
    # stored in the bank.
    f_lower = args.f_lower or bank['f_lower'][:]
    # The tau0 / tau3 parameters read the cutoff from the bank object, so it
    # has to be attached here; otherwise selecting them fails.
    bank.f_lower = f_lower

    if args.background_bins:
        data = {'mass1': bank['mass1'][:], 'mass2': bank['mass2'][:],
                'spin1z': bank['spin1z'][:], 'spin2z': bank['spin2z'][:],
                'f_lower': f_lower}
        locs_dict = pycbc.events.background_bin_from_string(
            args.background_bins, data)
    else:
        # Without bins, plot the whole bank as a single group.
        locs_dict = {'Template Bank': numpy.arange(0, len(bank), 1)}

    # The parameter method supplies the values; its docstring is the label.
    x_param = getattr(bank, args.x_var + '_param')
    y_param = getattr(bank, args.y_var + '_param')
    x_var, x_var_name = x_param(), x_param.__doc__
    y_var, y_var_name = y_param(), y_param.__doc__

# ---------------------------------------------------------------------------
# Plot
# ---------------------------------------------------------------------------
# Colors repeat if there are more groups than entries in this list.
color = cycle(['red', 'green', 'blue', 'purple'])

fig = pylab.figure()
pylab.grid()
for name, locs in locs_dict.items():
    pylab.scatter(x_var[locs], y_var[locs], label=name, edgecolor='none', s=1,
                  c=next(color))

pylab.legend(loc='upper left', markerscale=5)
pylab.xlabel(x_var_name)
pylab.ylabel(y_var_name)
pylab.xlim(x_var.min(), x_var.max())
pylab.ylim(y_var.min(), y_var.max())
if args.log_x:
    pylab.xscale('log')
if args.log_y:
    pylab.yscale('log')

title = "Template Bank and Bins Used to Compute Background"
caption = """This plot shows the template bank in the {x_var}-{y_var} plane.
Each template is colored by the bin that it is placed in to
compute the search background. Note that the bins may be chosen
in a space higher than two dimensions for spinning templates,
causing apparent overlap in the {x_var}-{y_var} plane."""
caption = caption.format(x_var=args.x_var, y_var=args.y_var)
pycbc.results.save_fig_with_metadata(fig, args.output_file, title=title,
                                     caption=caption, cmd=' '.join(sys.argv))
