#!/usr/bin/python3.13 -s

# Copyright (C) 2015 Tito Dal Canton
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Plot PyCBC's single-detector triggers over the search parameter space.
"""

import logging
import argparse
import numpy as np
import matplotlib
matplotlib.use('agg')
import pylab as pl
from matplotlib.colors import LogNorm
from matplotlib.ticker import LogLocator
import h5py
import sys

import pycbc
import pycbc.pnutils
import pycbc.events
import pycbc.results
import pycbc.io
from pycbc.events import ranking

# Build the command-line interface; the module docstring doubles as the
# program description.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)

# Input and output files, plus optional trigger selection
parser.add_argument(
    '--single-trig-file',
    required=True,
    help='Path to file containing single-detector triggers in HDF5 format. '
         'Required',
)
parser.add_argument(
    '--bank-file',
    required=True,
    help='Path to file containing template bank in HDF5 format. Required',
)
parser.add_argument(
    '--veto-file',
    type=str,
    help='Optional path to file containing veto segments',
)
parser.add_argument(
    '--segment-name',
    default=None,
    type=str,
    help='Optional, name of segment list to use for vetoes',
)
parser.add_argument(
    '--filter-string',
    default=None,
    type=str,
    help='Optional, boolean expression for filtering triggers',
)
parser.add_argument(
    '--output-file',
    type=str,
    required=True,
    help='Destination path for plot',
)

# Quantities shown on the two plot axes
for axis in ('x', 'y'):
    parser.add_argument(
        f'--{axis}-var',
        required=True,
        choices=pycbc.io.SingleDetTriggers.get_param_names(),
        help=f'Parameter to plot on the {axis}-axis. Required',
    )

# The colour scale shows either the trigger density or one of the
# single-detector rankings known to pycbc.events.ranking
ranking_keys = list(ranking.sngls_ranking_function_dict)
parser.add_argument(
    '--z-var',
    required=True,
    choices=['density', *ranking_keys],
    help='Quantity to plot on the color scale. Required',
)
parser.add_argument(
    '--detector',
    required=True,
    help='Detector. Required',
)
parser.add_argument(
    '--grid-size',
    type=int,
    default=80,
    help='Bin resolution (larger = smaller bins)',
)

# Optional logarithmic axes
for axis in ('x', 'y'):
    parser.add_argument(
        f'--log-{axis}',
        action='store_true',
        help=f'Use log scale for {axis}-axis',
    )

# Optional lower/upper limits for each of the three plotted quantities
for axis in ('x', 'y', 'z'):
    for prefix, word in (('min', 'minimum'), ('max', 'maximum')):
        parser.add_argument(
            f'--{prefix}-{axis}',
            type=float,
            help=f'Optional {word} {axis} value',
        )

opts = parser.parse_args()

pycbc.init_logging(opts.verbose)

# A lower limit on the colour-scale quantity can also be handed to the
# trigger reader as a ranking filter, but only when that quantity is a
# ranking rather than the trigger density.
apply_rank_cut = opts.z_var != 'density' and opts.min_z is not None
filter_rank = opts.z_var if apply_rank_cut else None
filter_thresh = opts.min_z if apply_rank_cut else None

# Load the triggers for the requested detector, passing along the optional
# veto, filter-string and ranking-threshold selections.
trig_selection = {
    'bank_file': opts.bank_file,
    'veto_file': opts.veto_file,
    'segment_name': opts.segment_name,
    'filter_func': opts.filter_string,
    'filter_rank': filter_rank,
    'filter_threshold': filter_thresh,
}
trigs = pycbc.io.SingleDetTriggers(
    opts.single_trig_file,
    opts.detector,
    **trig_selection
)

# NOTE: the range cuts below could perhaps be folded into the
# SingleDetTriggers call above; that can be done later if needed.
x = getattr(trigs, opts.x_var)
y = getattr(trigs, opts.y_var)

# Each entry is (values, limit, comparison); a limit of None means the
# corresponding command-line option was not given and no cut is applied.
range_cuts = (
    (x, opts.min_x, np.greater_equal),
    (x, opts.max_x, np.less_equal),
    (y, opts.min_y, np.greater_equal),
    (y, opts.max_y, np.less_equal),
)

# Start by keeping every trigger, then narrow the selection cut by cut
mask = np.ones(len(x), dtype=bool)
for values, limit, within in range_cuts:
    if limit is None:
        continue
    mask = np.logical_and(mask, within(values, limit))

x, y = x[mask], y[mask]

# Title and caption stored alongside the figure as metadata
z_title = opts.z_var.title()
title = (
    f'{z_title} of {opts.detector} triggers '
    f'over {opts.x_var.title()} and {opts.y_var.title()}'
)
fig_caption = (
    f"This plot shows the {opts.z_var} of single detector triggers "
    f"for the {opts.detector} detector. {z_title} is shown on the "
    f"colorbar axis against {opts.x_var} and {opts.y_var} "
    "on the x- and y-axes."
)

# mask is a NumPy boolean array, so use its vectorised any() rather than the
# builtin, which would walk the array one element at a time in Python (and
# has to visit every element precisely when no trigger survives).
if not mask.any():
    # All triggers removed - make a blank plot which says so, save it with
    # the usual metadata and stop here.
    fig = pl.figure()
    ax = fig.gca()
    ax.text(0.5, 0.5, 'no triggers in the range')

    pycbc.results.save_fig_with_metadata(
        fig,
        opts.output_file,
        title=title,
        caption=fig_caption,
        cmd=' '.join(sys.argv),
        fig_kwds={'dpi': 200}
    )
    sys.exit()

def _release_tuple(version):
    """Return the leading numeric release components of a version string.

    Parsing stops at the first dot-separated component that is not purely
    numeric, so '3.10.0' gives (3, 10, 0) and '3.8.0rc1' gives (3, 8, 0).
    """
    release = []
    for piece in version.split('.'):
        digits = ''
        for char in piece:
            if not char.isdigit():
                break
            digits += char
        if digits:
            release.append(int(digits))
        if digits != piece:
            # Reached a pre-release / development suffix: stop parsing
            break
    return tuple(release)


# Keyword arguments shared by every hexbin call below
hexbin_style = {
    'gridsize': opts.grid_size,
    'linewidths': 0.03
}

# In earlier versions mpl will try to take the max over bins with 0 triggers
# and fail, unless we tell it to leave these blank by setting mincnt.
# The versions are compared as tuples of integers: comparing the raw strings
# is lexicographic, so e.g. '3.10.0' would wrongly sort before '3.8.1'.
if _release_tuple(matplotlib.__version__) < (3, 8, 1):
    hexbin_style['mincnt'] = 0

if opts.log_x:
    hexbin_style['xscale'] = 'log'
if opts.log_y:
    hexbin_style['yscale'] = 'log'

# The colour scale is always logarithmic. A missing (or zero) lower limit
# falls back to 1; a missing upper limit is passed on as None.
minz = opts.min_z if opts.min_z else 1
maxz = opts.max_z
hexbin_style['norm'] = LogNorm(vmin=minz, vmax=maxz)

logging.info('Plotting')
fig = pl.figure()
ax = fig.gca()

if opts.z_var == 'density':
    # No per-trigger values are supplied, so the colour of each hexagonal
    # bin reflects how many triggers fall inside it
    hb = ax.hexbin(x, y, **hexbin_style)
    fig.colorbar(hb, ticks=LogLocator(subs=range(10)))
elif opts.z_var in ranking.sngls_ranking_function_dict:
    cb_style = {}
    z = trigs.get_ranking(opts.z_var)

    # Apply the same x/y range cuts that were applied to x and y
    z = z[mask]
    min_z = z.min() if opts.min_z is None else opts.min_z
    max_z = z.max() if opts.max_z is None else opts.max_z
    # Use explicit log-spaced colorbar ticks only when the z range spans
    # more than a factor of 10. This is written as a product rather than
    # the ratio max_z / min_z so that a lower limit of exactly zero cannot
    # raise ZeroDivisionError; a negative lower limit never qualifies,
    # matching what the ratio test gave.
    if min_z >= 0 and max_z > 10 * min_z:
        cb_style['ticks'] = LogLocator(subs=range(10))
    # Each bin is coloured by the largest ranking value it contains
    hb = ax.hexbin(x, y, C=z, reduce_C_function=np.max, **hexbin_style)
    fig.colorbar(hb, **cb_style)
else:
    # Defensive: the --z-var choices should already exclude anything else
    raise RuntimeError('z_var = %s is not recognized!' % (opts.z_var))

ax.set_xlabel(opts.x_var)
ax.set_ylabel(opts.y_var)
ax.set_title(opts.z_var.title() + ' of %s triggers ' % (opts.detector))
pycbc.results.save_fig_with_metadata(fig, opts.output_file, title=title,
                                     caption=fig_caption, cmd=' '.join(sys.argv),
                                     fig_kwds={'dpi': 200})

logging.info('Done')
