#!/usr/bin/python3.13 -s

# Copyright (C) 2018 Ian Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Plot the "Q-scan" time-frequency representation of single-detector strain data.

See https://iopscience.iop.org/article/10.1088/0264-9381/21/20/024 for
information on the Q-transform and its parameters.
"""

import sys
import argparse
import logging
import numpy

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm

import pycbc.strain
import pycbc.results

# https://stackoverflow.com/questions/9978880/python-argument-parser-list-of-list-or-tuple-of-tuples
def t_window(s):
    """Parse one ``START,END`` command-line token into ``[start, end]``.

    Intended for use as an argparse ``type=`` callable.

    Parameters
    ----------
    s : str
        Text of the form ``"start,end"``; both fields must parse as floats.

    Returns
    -------
    list of float
        The two parsed values, ``[start, end]``. A list (not a tuple) is
        returned so callers can adjust the bounds in place.

    Raises
    ------
    argparse.ArgumentTypeError
        If ``s`` is not exactly two comma-separated floats.
    """
    try:
        start, end = map(float, s.split(','))
    except ValueError as err:
        # ValueError covers both a non-numeric field and a wrong number of
        # fields. Anything else (e.g. KeyboardInterrupt) is no longer masked
        # as a usage error.
        raise argparse.ArgumentTypeError(
            "Input must be start,end start,end") from err
    return [start, end]

# Build the command-line interface. Options are registered in the order they
# are listed here; pycbc's common options go first and its strain option
# group is inserted last.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)

# ---- Output file and time selection ----
parser.add_argument(
    "--output-file",
    required=True,
    help="Output plot",
)
parser.add_argument(
    "--center-time",
    type=float,
    help="Center plot on the given GPS time. If omitted, use center of "
         "interval set via --gps-start-time and --gps-end-time",
)
parser.add_argument(
    "--time-windows",
    required=True,
    type=t_window,
    nargs="+",
    metavar="START,END",
    help="Use these set of times for time windows. Each window produces a "
         "different panel. Should be provided as start1,end1 start2,end2 ... "
         "where start/end times are relative to --center-time and both "
         "positive",
)
parser.add_argument(
    "--low-frequency-cutoff",
    type=float,
    help="The low frequency cutoff to use for fake strain "
         "generation",
)

# ---- Q-transform settings ----
parser.add_argument(
    "--qtransform-delta-t",
    type=float,
    default=0.001,
    help="The time resolution to interpolate to (optional)",
)
parser.add_argument(
    "--qtransform-delta-f",
    type=float,
    help="Frequency resolution to interpolate to (optional)",
)
parser.add_argument(
    "--qtransform-logfsteps",
    type=int,
    default=200,
    help="Do a log interpolation (incompatible with --qtransform-delta-f "
         "option) and set the number of steps to take",
)
parser.add_argument(
    "--qtransform-frange-lower",
    type=float,
    help="Lower frequency (in Hz) at which to compute qtransform. "
         "Optional, default=20 Hz",
)
parser.add_argument(
    "--qtransform-frange-upper",
    type=float,
    help="Upper frequency (in Hz) at which to compute qtransform. "
         "Optional, default=Half of Nyquist",
)
parser.add_argument(
    "--qtransform-qrange-lower",
    type=float,
    default=4,
    help="Lower limit of the range of q to consider, default=%(default)f",
)
parser.add_argument(
    "--qtransform-qrange-upper",
    type=float,
    default=64,
    help="Upper limit of the range of q to consider, default=%(default)f",
)
parser.add_argument(
    "--qtransform-mismatch",
    type=float,
    default=0.2,
    help="Mismatch between frequency tiles, default=%(default)f",
)

# ---- Plot appearance ----
# The two --linear-* switches are stored inverted (log_y / log_colorbar), so
# the log scale is what you get when the switch is absent.
parser.add_argument(
    "--linear-y-axis",
    action="store_false",
    dest="log_y",
    default=True,
    help="Use a linear y-axis. By default a log axis is used.",
)
parser.add_argument(
    "--linear-colorbar",
    action="store_false",
    dest="log_colorbar",
    default=True,
    help="Use a linear colorbar scale.",
)
parser.add_argument(
    "--plot-title",
    help="If given, use this as the plot title",
)
parser.add_argument(
    "--plot-caption",
    help="If given, use this as the plot caption",
)
parser.add_argument(
    "--colormap",
    default="viridis",
    choices=plt.colormaps(),
    help="Colormap to use (default %(default)s)",
)

# ---- Optional waveform track overlay ----
# Both mass options share one help text, as do both spin options.
for mass_flag in ("--mass1", "--mass2"):
    parser.add_argument(
        mass_flag,
        type=float,
        help="Provide masses to plot the waveform track on top of the "
             "qscan.",
    )
for spin_flag in ("--spin1z", "--spin2z"):
    parser.add_argument(
        spin_flag,
        type=float,
        default=0,
        help="If masses provided, provide spins to plot the waveform track "
             "on top of the qscan (default=0).",
    )

pycbc.strain.insert_strain_option_group(parser)
opts = parser.parse_args()

pycbc.init_logging(opts.verbose, default_level=1)

# Resolve the GPS time the plot is centred on. An explicit --center-time wins;
# otherwise use the midpoint of the --gps-start-time/--gps-end-time interval.
if opts.center_time is None:
    center_time = (opts.gps_start_time + opts.gps_end_time) / 2.
else:
    center_time = opts.center_time

# A centre time of exactly -1 produces a placeholder figure titled
# 'No data for this detector' and a successful exit, without loading strain.
if center_time == -1.0:
    fig, axes = plt.subplots(1,1)
    # Invisible overlay axes that only carry the axis labels.
    fig.add_subplot(111, frameon=False)
    plt.tick_params(labelcolor='none', top=False, bottom=False, left=False,
                    right=False)
    plt.grid(False)
    # Use the resolved centre time: opts.center_time is None whenever the
    # centre was derived from the GPS interval, and None cannot be formatted.
    plt.xlabel('Time from {:.3f} (s)'.format(center_time))
    plt.ylabel('Frequency (Hz)')
    title = 'No data for this detector'
    pycbc.results.save_fig_with_metadata(
            fig, opts.output_file, cmd=' '.join(sys.argv),
            fig_kwds={'dpi': 150}, title=title, caption=opts.plot_caption)
    sys.exit(0)

# The inspiral track overlay needs both component masses, so supplying only
# one of them is a usage error.
mass_given = [m is not None for m in (opts.mass1, opts.mass2)]
if any(mass_given) and not all(mass_given):
    parser.error('Must provide either both --mass1 and --mass2 or neither')
plot_chirp = all(mass_given)

strain = pycbc.strain.from_cli(opts, pycbc.DYN_RANGE_FAC)

# Frequency range of the Q-transform: either both bounds come from the command
# line, or neither does and the range defaults to 20 Hz .. sample_rate / 4.
frange_opts = (opts.qtransform_frange_lower, opts.qtransform_frange_upper)
n_unset = sum(bound is None for bound in frange_opts)
if n_unset == 1:
    parser.error('Must provide either both --qtransform-frange-upper and '
                 '--qtransform-frange-lower or neither option.')
curr_frange = (20, opts.sample_rate / 4.) if n_unset == 2 else frange_opts

# Keep the full whitened series when the centre time lies within 2 s of either
# end of the loaded data; otherwise ask whiten() to remove corrupted samples.
near_start = (center_time - strain.start_time) < 2
near_end = (strain.end_time - center_time) < 2
rem_corrupted = not (near_start or near_end)

strain = strain.whiten(4, 4, remove_corrupted=rem_corrupted)

# ---------------------------------------------------------------------------
# Compute one Q-transform per requested time window, draw each in its own
# panel with a shared colorbar, optionally overlay an inspiral track, and
# save the figure.
#
# Everything below uses the resolved ``center_time`` rather than
# ``opts.center_time``; the latter is None when --center-time is omitted and
# the centre is derived from the GPS interval instead.
# ---------------------------------------------------------------------------
wins = opts.time_windows
fig, axes = plt.subplots(len(wins), 1)
# plt.subplots hands back a bare Axes for one panel and an array otherwise;
# normalise to a flat list so the loops below need no special-casing.
panel_axes = axes.ravel().tolist() if len(wins) > 1 else [axes]
times_all = []
freqs_all = []
qvals_all = []

qrange = (opts.qtransform_qrange_lower, opts.qtransform_qrange_upper)

for curr_win in wins:
    # Catch the case that not enough data is available: shrink the window (in
    # place, so the plotting loop sees the adjusted limits) to end 0.01 s
    # inside the available strain.
    if center_time - curr_win[0] < strain.start_time:
        curr_win[0] = float(center_time - strain.start_time - 0.01)
    if center_time + curr_win[1] > strain.end_time:
        curr_win[1] = float(strain.end_time - center_time - 0.01)
    strain_zoom = strain.time_slice(center_time - curr_win[0],
                                    center_time + curr_win[1])

    # NOTE(review): --qtransform-logfsteps defaults to 200, so whenever
    # --qtransform-delta-f is given both delta_f and logfsteps are passed,
    # although the option help calls them incompatible. Confirm how
    # qtransform() handles that combination.
    times, freqs, qvals = strain_zoom.qtransform(
            delta_t=opts.qtransform_delta_t, delta_f=opts.qtransform_delta_f,
            logfsteps=opts.qtransform_logfsteps, frange=curr_frange,
            qrange=qrange, mismatch=opts.qtransform_mismatch)
    times_all.append(times)
    freqs_all.append(freqs)
    qvals_all.append(qvals)

# Largest value over all panels: the common upper limit of the log colour
# scale.
max_qval = max(qvals.max() for qvals in qvals_all)

for ax, curr_win, times, freqs, qvals in zip(panel_axes, wins, times_all,
                                             freqs_all, qvals_all):
    # A fresh norm per panel, all with the same limits. With a linear
    # colorbar no norm is passed.
    norm = LogNorm(vmin=1, vmax=max_qval) if opts.log_colorbar else None

    im = ax.pcolormesh(times - center_time, freqs, qvals, norm=norm,
                       cmap=opts.colormap)
    ax.set_xlim(-curr_win[0], curr_win[1])
    ax.set_ylim(curr_frange[0], curr_frange[1])
    if opts.log_y:
        ax.set_yscale('log')

# Invisible overlay axes carrying the axis labels shared by all panels.
# https://stackoverflow.com/questions/6963035/pyplot-axes-labels-for-subplots
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False,
                right=False)
plt.grid(False)
plt.xlabel('Time from {:.3f} (s)'.format(center_time))
plt.ylabel('Frequency (Hz)')

# One colorbar for all panels, built from the last panel's image.
# https://stackoverflow.com/questions/13784201/matplotlib-2-subplots-1-colorbar
cb = fig.colorbar(im, ax=(panel_axes if len(wins) > 1 else axes))
cb.set_label('Normalized power')

if plot_chirp:
    from pycbc.pnutils import get_inspiral_tf
    f_low = 20.
    # Total mass below 4 selects 'SPAtmplt', otherwise 'SEOBNRv4_ROM'.
    approximant = ('SPAtmplt' if opts.mass1 + opts.mass2 < 4
                   else 'SEOBNRv4_ROM')
    track_t, track_f = get_inspiral_tf(center_time, opts.mass1,
                                       opts.mass2, opts.spin1z, opts.spin2z,
                                       f_low, approximant=approximant)
    for ax in panel_axes:
        ax.plot(track_t - center_time, track_f, 'r-', lw=1.5)

if opts.channel_name is not None:
    plt.suptitle(opts.channel_name)

# Fill in the default title and caption unless the user supplied them.
if opts.plot_title is None:
    opts.plot_title = 'Q-transform plot around {:.3f}'.format(center_time)
if opts.plot_caption is None:
    opts.plot_caption = ("This shows the Q-transform as a function of time and "
                         "frequency.")
    if opts.channel_name is not None:
        opts.plot_caption += (' The strain channel is ' + opts.channel_name
                              + '.')
    if plot_chirp:
        # The second spin is labelled spin2z, matching the value formatted
        # into that slot.
        chirp_caption = (' The red curve is the time-frequency curve of a'
                         ' quadrupole-order quasicircular aligned-spin'
                         ' inspiral with mass1={:.3f}, mass2={:.3f},'
                         ' spin1z={:.3f}, spin2z={:.3f}.')
        opts.plot_caption += chirp_caption.format(opts.mass1, opts.mass2,
                                                  opts.spin1z, opts.spin2z)

pycbc.results.save_fig_with_metadata(
        fig, opts.output_file, cmd=' '.join(sys.argv), fig_kwds={'dpi': 150},
        title=opts.plot_title, caption=opts.plot_caption)
