#!/usr/bin/python3.13 -s

# Copyright (C) 2017 Vaibhav Tiwari

# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
This script reads the rates priors and posteriors and plots them
"""

__author__ = "Vaibhav Tiwari"
__email__ = "vaibhav.tiwari@ligo.org"
__version__ = "0.0"
__date__ = "31.10.2017"

import logging
import argparse
from numpy import logaddexp, log, newaxis, expm1
import numpy as np
from matplotlib import cm
import pylab

import scipy.stats as ss

import pycbc
import pycbc.version
from pycbc.io.hdf import HFile
from pycbc.population import rates_functions as rf

# Parse command line.
# Every file/label option is required; the three list-valued options
# (--posterior-samples, --population-models, --plot-labels) are consumed
# in lockstep further down, so their lengths are validated here.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument('--prior-samples', required=True,
                    help="File containing rate prior samples")
parser.add_argument('--posterior-samples', nargs='+', required=True,
                    help="Files(s) containing samples for the rate posterior.")
parser.add_argument('--population-models', nargs='+', required=True,
                    help="Models to which posteriors belong ('lnm', 'imf', 'bns').")
parser.add_argument('--output-rates', dest='rate_file', required=True,
                    help="File saving all the rate posteriors into one.")
parser.add_argument('--rates-figure', required=True,
                    help="Name of file to draw rates prior and posterior.")
parser.add_argument('--pastro-figure', required=True,
                    help="Name of file to save p_astro figure.")
parser.add_argument('--plot-labels', nargs='+', required=True,
                    help="Labels for the population models.")

opts = parser.parse_args()
pycbc.init_logging(opts.verbose)

# Report bad input through argparse rather than `assert`: assertions are
# removed when Python runs with -O, and parser.error() prints the usage
# message and exits with status 2 instead of dumping a traceback.
if len(opts.posterior_samples) != len(opts.population_models):
    parser.error("Unequal number of posterior files and population models!")

# The plotting loops zip the labels with the models, so a missing label
# would silently drop a model from both figures. Surplus labels are
# harmless and remain accepted.
if len(opts.plot_labels) < len(opts.population_models):
    parser.error("Fewer plot labels than population models!")

# Save rate posteriors in one file.
# For each (posterior file, model) pair, copy the '<model>/Rf' and
# '<model>/Lf' datasets into a group named after the model in the
# combined output file, gzip-compressed.
with HFile(opts.rate_file, 'w') as out:
    for fname, model in zip(opts.posterior_samples, opts.population_models):

        # Context manager guarantees the input file is closed even when a
        # dataset is missing or the read fails part-way through.
        with HFile(fname, "r") as f:
            Rf = f[model + '/Rf'][:]
            Lf = f[model + '/Lf'][:]

        pl = out.create_group(model)
        pl.create_dataset('Rf', data=Rf, compression='gzip')
        pl.create_dataset('Lf', data=Lf, compression='gzip')

# Make prior/posterior plot -- estimate p_astro.
#
# For every (posterior file, model, label) triple this:
#   1. fits the prior and posterior 'Rf' samples with rf.fit and draws the
#      resulting skew-normal densities (evaluated in log R) on one figure,
#      prior dashed and posterior solid, sharing one colour per model;
#   2. computes one p_astro value per entry of 'data/log_fg_bg_ratio' by
#      averaging over the posterior samples, and appends the values
#      (sorted in descending order) to `p_astro`.
#
# `p_astro` and `log_fg_ratios` are left at module level because the
# p_astro figure below reads them; `log_fg_ratios` ends up holding the
# array from the last posterior file processed.
p_astro = []
pylab.figure()
color = iter(cm.rainbow(np.linspace(0, 1, len(opts.population_models))))
mods = zip(opts.posterior_samples, opts.population_models, opts.plot_labels)

# Context managers close both HDF files even if a dataset is missing or
# the fit raises.
with HFile(opts.prior_samples, "r") as f:
    for fpost, model, lbl in mods:

        c = next(color)

        # Pull everything needed out of the posterior file up front so the
        # handle is released before the fitting/plotting work starts.
        with HFile(fpost, "r") as fpo:
            Rfpr, Rfpo = f[model + '/Rf'][:], fpo[model + '/Rf'][:]
            # NOTE(review): presumably foreground/background count samples
            # -- confirm against the script that writes these files.
            Lfpo, Lbpo = fpo[model + '/Lf'][:], fpo[model + '/Lb'][:]
            log_fg_ratios = fpo['data/log_fg_bg_ratio'][:]

        # The three values returned by rf.fit are used below as the shape,
        # location and scale of a skew-normal in log R.
        prior_alpha, prior_mu, prior_sigma = rf.fit(Rfpr)
        post_alpha, post_mu, post_sigma = rf.fit(Rfpo)

        # Both curves are drawn over the log-range spanned by the prior
        # samples.
        log_R = np.log(Rfpr)
        xs = np.linspace(log_R.min(), log_R.max(), 200)
        pylab.plot(np.exp(xs),
                   ss.skewnorm.pdf(xs, prior_alpha, prior_mu, prior_sigma),
                   '--', label=lbl + ' Prior', color=c)
        # Fixed: this call previously referenced the undefined name
        # `post_m`, raising NameError on the first model.
        pylab.plot(np.exp(xs),
                   ss.skewnorm.pdf(xs, post_alpha, post_mu, post_sigma),
                   label=lbl + ' Posterior', color=c)

        # Rows index posterior samples, columns index the entries of
        # log_fg_ratios. In log space, each element is
        #   log[ Lf * ratio / (Lf * ratio + Lb) ],
        # and logaddexp.reduce over axis 0 minus log(N) gives the log of
        # the mean of that fraction over the N posterior samples.
        log_lf = log(Lfpo[:, newaxis])
        log_lb = log(Lbpo[:, newaxis])
        log_fg_term = log_lf + log_fg_ratios[newaxis, :]
        log_pastros = logaddexp.reduce(
            log_fg_term - logaddexp(log_fg_term, log_lb), axis=0
        ) - log(Lfpo.shape[0])

        # 1 + expm1(x) == exp(x); values are stored sorted, largest first.
        p_astro.append(1 + expm1(np.sort(log_pastros)[::-1]))

# Finish and save the rates figure.
pylab.xscale('log')
pylab.xlabel(r'$R$ ($\mathrm{Gpc}^{-3} \, \mathrm{yr}^{-1}$)')
pylab.ylabel(r'$RP(R)$')
pylab.legend(loc='best')
pylab.savefig(opts.rates_figure)

# Draw 1 - p_astro for every population model on a fresh figure, one
# colour per model, then save the figure and echo the p_astro arrays.
#
# NOTE(review): each array in `p_astro` was sorted in descending order
# when it was built, while `log_fg_ratios` is plotted as read from the
# last posterior file -- confirm that the stored ratios share that order
# and are identical across posterior files.
pylab.figure()
palette = cm.rainbow(np.linspace(0, 1, len(opts.population_models)))
for model_pastro, model_label, model_colour in zip(p_astro,
                                                   opts.plot_labels,
                                                   palette):
    pylab.plot(log_fg_ratios, 1 - model_pastro, '.',
               label=model_label, color=model_colour)

pylab.yscale('log')
pylab.xlabel(r'$\log p(x\mid f)/p(x\mid b)$')
pylab.ylabel(r'$1-p_\mathrm{astro}$')
pylab.legend(loc='best')
pylab.savefig(opts.pastro_figure)
print(p_astro)
