#!/usr/bin/python3.11
""" Plot variation in PSD
"""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import h5py, numpy, argparse, sys, math
import pycbc.results, pycbc.types, pycbc.version, pycbc.waveform, pycbc.filter

from pycbc.fft.fftw import set_measure_level
set_measure_level(0)

# Command-line interface. The i-th values of --min_mtot, --max_mtot,
# --d_mtot and --approximant together describe the i-th total-mass range.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--version", action='version', version=pycbc.version.git_verbose_msg)
parser.add_argument("--psd-files", nargs='+', required=True, help='HDF file of psds')
parser.add_argument("--output-file", required=True, help='output file name')
parser.add_argument("--min_mtot", nargs="+", required=True,
                    help="Minimum total mass for range", type=float)
parser.add_argument("--max_mtot", nargs="+", required=True,
                    help="Maximum total mass for range (exclusive)", type=float)
parser.add_argument("--d_mtot", nargs="+", required=True,
                    help="Delta total mass for range ", type=float)
parser.add_argument("--approximant", nargs="+", required=True,
                    help="approximant to use for range")
args = parser.parse_args()

# The range options are consumed position-by-position via zip(), which would
# silently drop surplus values; insist on matching lengths instead.
if not (len(args.min_mtot) == len(args.max_mtot) == len(args.d_mtot)
        == len(args.approximant)):
    parser.error('--min_mtot, --max_mtot, --d_mtot and --approximant must '
                 'each be given the same number of values')
# numpy.arange raises on a zero step and yields nothing for a negative one.
if any(dm <= 0 for dm in args.d_mtot):
    parser.error('--d_mtot values must be positive')

# SNR at which the horizon distance is defined (sigma / canonical_snr below).
canonical_snr = 8.0

# Single shared figure; every PSD file / approximant adds its own curve to it.
fig = plt.figure(0)
# Raw string: '\o' is not a valid Python escape sequence and would trigger a
# DeprecationWarning/SyntaxWarning; matplotlib needs the literal backslash.
plt.xlabel(r'Total Mass (M$_{\odot}$)')
plt.ylabel('Inspiral Range (Mpc)')
plt.grid()

# Every (approximant, total mass) point to evaluate, in command-line order.
# This is identical for every PSD, so it is built once. dict.fromkeys drops
# duplicates (e.g. overlapping ranges for the same approximant) so that each
# point receives exactly one range value per PSD, matching the one weight per
# PSD segment used in the averages below.
mass_grid = list(dict.fromkeys(
    (apx, M)
    for mi, mf, dm, apx in zip(args.min_mtot, args.max_mtot,
                               args.d_mtot, args.approximant)
    for M in numpy.arange(mi, mf, dm)))
approximants = list(dict.fromkeys(apx for apx, _ in mass_grid))
# Distinguish approximants of the same detector (which share a colour).
linestyles = ('-', '--', '-.', ':')

for psd_file in args.psd_files:
    with h5py.File(psd_file, 'r') as f:
        # Only the first top-level group (the detector) is read.
        ifo = tuple(f.keys())[0]
        flow = f.attrs['low_frequency_cutoff']
        # Count the PSDs while the file is still open: an h5py keys view is
        # backed by the live group and cannot be queried after closing.
        num_psds = len(f[ifo + '/psds'])
        start, end = f[ifo + '/start_time'][:], f[ifo + '/end_time'][:]
    # Segment durations, used as averaging weights (one per PSD, assumed to
    # be in the same order as the psds/<i> groups -- TODO confirm).
    seglen = numpy.subtract(end, start)

    if num_psds == 0:
        sys.stderr.write('No PSDs found in %s; skipping\n' % psd_file)
        continue

    # ranges[(approximant, total mass)] -> inspiral range for each PSD
    ranges = {key: [] for key in mass_grid}

    for i in range(num_psds):
        name = ifo + '/psds/' + str(i)
        psd = pycbc.types.load_frequencyseries(psd_file, group=name)
        # Sample spacing of a time series whose one-sided spectrum has
        # len(psd) bins separated by psd.delta_f.
        delta_t = 1.0 / ((len(psd) - 1) * 2 * psd.delta_f)
        # Output buffer handed to get_waveform_filter for each template.
        out = pycbc.types.zeros(len(psd), dtype=numpy.complex64)

        for apx, M in mass_grid:
            # Equal-mass binary with total mass M.
            htilde = pycbc.waveform.get_waveform_filter(
                out, mass1=M/2., mass2=M/2., approximant=apx,
                f_lower=flow, delta_f=psd.delta_f, delta_t=delta_t,
                distance=1.0/pycbc.DYN_RANGE_FAC)
            htilde = htilde.astype(numpy.complex64)
            sigma = pycbc.filter.sigma(htilde, psd=psd,
                                       low_frequency_cutoff=flow)
            horizon_distance = sigma / canonical_snr
            # 2.26 is presumably the standard horizon-distance to
            # sky-averaged-range conversion factor -- confirm.
            inspiral_range = horizon_distance / 2.26
            ranges[(apx, M)].append(inspiral_range)

    # One curve per approximant: segment-duration-weighted mean range with
    # the weighted standard deviation across PSDs as the error bar.
    color = pycbc.results.ifo_color(ifo)
    for j, apx in enumerate(approximants):
        masses = sorted(M for key_apx, M in mass_grid if key_apx == apx)
        avg_range, rangerr = [], []
        for M in masses:
            values = numpy.asarray(ranges[(apx, M)])
            mean = numpy.average(values, weights=seglen)
            variance = numpy.average((values - mean) ** 2, weights=seglen)
            avg_range.append(mean)
            rangerr.append(math.sqrt(variance))
        plt.errorbar(masses, avg_range, yerr=rangerr, ecolor=color,
                     fmt='none')
        plt.plot(masses, avg_range, color=color,
                 linestyle=linestyles[j % len(linestyles)],
                 label='%s-%s' % (ifo, apx))

plt.legend(loc="upper left")

# Write the figure together with its title, caption and the exact command
# line that produced it.
_range_caption = ("The canonical sky-averaged inspiral range for a single "
                  "detector at SNR 8 vs total mass:equal mass binary")
pycbc.results.save_fig_with_metadata(
    fig,
    args.output_file,
    title="Inspiral Range",
    caption=_range_caption,
    cmd=' '.join(sys.argv),
    fig_kwds={'dpi': 200},
)
