#!/usr/bin/python3.13 -s
""" Plot variation in PSD
"""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy
import argparse
import sys
import math

import pycbc.results
import pycbc.types
import pycbc.waveform
import pycbc.filter
from pycbc.io.hdf import HFile

# Configure the FFTW backend before any transform is planned.
# NOTE(review): level 0 is presumably the cheapest planning mode (no
# measurement pass) -- confirm against pycbc.fft.fftw.
from pycbc.fft.fftw import set_measure_level
set_measure_level(0)

# Command-line interface.  Every option below is dereferenced unconditionally
# later in the script, so each one is marked as required: a missing option is
# reported as a usage error up front instead of a TypeError part-way through.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--psd-files", nargs='+', required=True,
                    help='HDF file of psds')
parser.add_argument("--output-file", required=True,
                    help='output file name')
parser.add_argument("--min_mtot", nargs="+", type=float, required=True,
                    help="Minimum total mass for range")
parser.add_argument("--max_mtot", nargs="+", type=float, required=True,
                    help="Maximum total mass for range")
parser.add_argument("--d_mtot", nargs="+", type=float, required=True,
                    help="Delta total mass for range ")
parser.add_argument("--approximant", nargs="+", required=True,
                    help="approximant to use for range")
args = parser.parse_args()

# The four per-range options are consumed in lockstep (one mass range per
# approximant).  zip() would silently drop unmatched trailing values, so
# reject inconsistent input explicitly.
per_range_opts = {
    "--min_mtot": args.min_mtot,
    "--max_mtot": args.max_mtot,
    "--d_mtot": args.d_mtot,
    "--approximant": args.approximant,
}
if len({len(values) for values in per_range_opts.values()}) > 1:
    parser.error("%s must all be given the same number of values"
                 % ", ".join(per_range_opts))

pycbc.init_logging(args.verbose)

# SNR threshold used to convert a template's sigma into a horizon distance.
canonical_snr = 8.0

# Single shared figure; every PSD file adds its curve(s) to it.
fig = plt.figure(0)
# Raw string: "\o" is not a valid escape sequence, so the non-raw literal
# triggers a SyntaxWarning on Python >= 3.12.  The resulting text is identical.
plt.xlabel(r'Total Mass (M$_{\odot}$)')
plt.ylabel('Inspiral Range (Mpc)')
plt.grid()

# For every PSD file: compute the inspiral range of equal-mass binaries over
# each requested (mass range, approximant) pair for every stored PSD, then
# plot the duration-weighted mean range with its weighted standard deviation.
for psd_file in args.psd_files:
    # Pull everything needed out of the file while it is open.  The number of
    # PSDs is taken as a plain int here: a keys() view of the group cannot be
    # measured once the file has been closed.
    with HFile(psd_file, 'r') as f:
        ifo = tuple(f.keys())[0]
        flow = f.attrs['low_frequency_cutoff']
        num_psds = len(f[ifo + '/psds'])
        start, end = f[ifo + '/start_time'][:], f[ifo + '/end_time'][:]

    # Duration of each PSD's segment, used as the averaging weight below.
    seglen = numpy.subtract(end, start)

    # One (approximant, mass grid) pair per requested range, and one
    # (num_psds x num_masses) table of ranges for each pair.  Keeping the
    # tables separate stops different approximants / mass ranges from being
    # mixed into the same statistics.
    configs = [(apx, numpy.arange(mi, mf, dm))
               for mi, mf, dm, apx in zip(args.min_mtot, args.max_mtot,
                                          args.d_mtot, args.approximant)]
    range_tables = [numpy.zeros((num_psds, len(masses)))
                    for _, masses in configs]

    for i in range(num_psds):
        name = ifo + '/psds/' + str(i)
        psd = pycbc.types.load_frequencyseries(psd_file, group=name)
        # Sample spacing of the time series matching this frequency series.
        delta_t = 1.0 / ((len(psd) - 1) * 2 * psd.delta_f)
        # Scratch buffer the waveform filter is generated into.
        out = pycbc.types.zeros(len(psd), dtype=numpy.complex64)

        for (apx, masses), table in zip(configs, range_tables):
            for j, M in enumerate(masses):
                # Equal-mass binary of total mass M.
                # NOTE(review): the distance is presumably chosen so that
                # sigma comes out directly in Mpc -- confirm.
                htilde = pycbc.waveform.get_waveform_filter(
                    out,
                    mass1=M / 2., mass2=M / 2., approximant=apx,
                    f_lower=flow, delta_f=psd.delta_f,
                    delta_t=delta_t,
                    distance=1.0 / pycbc.DYN_RANGE_FAC)
                htilde = htilde.astype(numpy.complex64)
                sigma = pycbc.filter.sigma(htilde, psd=psd,
                                           low_frequency_cutoff=flow)
                horizon_distance = sigma / canonical_snr
                # 2.26: horizon distance -> inspiral range conversion factor.
                table[i, j] = horizon_distance / 2.26

    color = pycbc.results.ifo_color(ifo)
    for (apx, masses), table in zip(configs, range_tables):
        # Weighted mean and standard deviation over the PSDs (axis 0), one
        # value per mass bin.
        avg_range = numpy.average(table, axis=0, weights=seglen)
        variance = numpy.average((table - avg_range) ** 2, axis=0,
                                 weights=seglen)
        rangerr = numpy.sqrt(variance)

        label = '%s-%s' % (ifo, apx)
        plt.errorbar(masses, avg_range, yerr=rangerr, ecolor=color,
                     label=label, fmt='none')
        plt.plot(masses, avg_range, color=color)

# Legend identifies each curve by detector and approximant.
plt.legend(loc="upper left")

# Save the figure along with its title, caption and the command line that
# produced it.
plot_caption = ("The canonical sky-averaged inspiral range for a single "
                "detector at SNR 8 vs total mass:equal mass binary")
command_line = ' '.join(sys.argv)
pycbc.results.save_fig_with_metadata(
    fig,
    args.output_file,
    title="Inspiral Range",
    caption=plot_caption,
    cmd=command_line,
    fig_kwds={'dpi': 200},
)
