#!/usr/bin/python3.13 -s

# Copyright (C) 2021 Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Determine efficiency and exclusion distances of a PyGRB run."""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import json
import matplotlib.pyplot as plt
from matplotlib import rc
import numpy as np
import scipy
from scipy import stats

import pycbc.version
from pycbc import init_logging
from pycbc.detector import Detector
from pycbc.results import save_fig_with_metadata
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.conversions import mchirp_from_mass1_mass2
from pycbc.io.hdf import HFile

# Select Matplotlib's Agg backend before any figure is created
plt.switch_backend('Agg')
# NOTE(review): rc() is given a group name but no keyword settings, so it
# presumably leaves the "image" rcParams untouched -- confirm the intent
rc("image")

# Script metadata; version and date are read from the pycbc.version module
__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_efficiency"


def efficiency_with_errs(found_bestnr, num_injections, num_mc_injs=0):
    """Return the recovered fraction of injections and its error bars.

    The last entry of both input arrays is dropped before any calculation
    (the caller stores an all-bins total there).

    Parameters
    ----------
    found_bestnr : numpy.ndarray
        Number of recovered injections per bin, plus a trailing total.
    num_injections : numpy.ndarray
        Number of injections performed per bin, plus a trailing total.
    num_mc_injs : int, optional
        Number of Monte-Carlo iterations the counts were accumulated over.
        When non-zero, the counts are divided by it before the error bars
        are evaluated.  The returned fraction is unaffected.

    Returns
    -------
    err_low_mc, err_high_mc, fraction : numpy.ndarray
        Lower error bar, upper error bar and recovered fraction per bin.
        Negative error bars are set to zero (a warning is logged).

    Raises
    ------
    TypeError
        If num_mc_injs is not an int.
    """
    if not isinstance(num_mc_injs, int):
        raise TypeError("The parameter num_mc_injs is the number of "
                        "Monte-Carlo injections.  It must be an integer.")

    # Drop the trailing totals
    n_found = found_bestnr[:-1]
    n_total = num_injections[:-1]
    fraction = n_found / n_total

    # Average the counts over the Monte-Carlo iterations
    if num_mc_injs:
        n_found = n_found / num_mc_injs
        n_total = n_total / num_mc_injs

    # Both interval edges share a common term and denominator and differ
    # only by the sign of the square-root term
    shared_term = n_total * (2 * n_found + 1)
    denominator = 2 * n_total * (n_total + 1)
    root_term = (4 * n_total * n_found * (n_total - n_found)
                 + n_total**2)**0.5
    err_low_mc = fraction - (shared_term - root_term)/denominator
    err_high_mc = (shared_term + root_term)/denominator - fraction

    # Error bars cannot be negative: set any such value to zero
    negative_low = err_low_mc < 0
    if negative_low.any():
        logging.warning("Negative lower error bar(s) detected. "
                        "Setting to zero.")
        err_low_mc[negative_low] = 0
    negative_high = err_high_mc < 0
    if negative_high.any():
        logging.warning("Negative upper error bar(s) detected. "
                        "Setting to zero.")
        err_high_mc[negative_high] = 0

    return err_low_mc, err_high_mc, fraction


# =============================================================================
# Main script starts here
# =============================================================================
# Start from the common PyGRB plotting parser and add the options that are
# specific to this executable
parser = ppu.pygrb_initialize_plot_parser(description=__doc__)

# Input files
parser.add_argument(
    "-F", "--trig-file", action="store", required=True,
    help="Location of off-source trigger file.")
parser.add_argument(
    "--onsource-file", action="store",
    help="Location of on-source trigger file (or a "
         "background trigger file to be treated as such).")

# Output files
parser.add_argument(
    "--background-output-file", required=True,
    help="Detection efficiency output file.")
parser.add_argument(
    "--onsource-output-file", required=True,
    help="Exclusion distance output file.")
parser.add_argument(
    "--exclusion-dist-output-file", required=True,
    help="JSON file containing exclusion distances.")

# Analysis settings
parser.add_argument(
    "-g", "--glitch-check-factor", action="store", type=float, default=1.0,
    help="When deciding "
         "exclusion efficiencies this value is multiplied "
         "to the offsource around the injection trigger to "
         "determine if it is just a loud glitch.")
parser.add_argument(
    "--found-missed-file", action="store", type=str, required=True,
    help="Location of found/missed "
         "injections and trigger file")
parser.add_argument(
    "--injection-set-name", action="store", type=str, default="",
    help="Name of the injection set to be "
         "used in the plot title.")
parser.add_argument(
    "--trial-name", action="store", type=str, required=True,
    help="Name of trial used "
         "for this run (i.e. ONSOURCE, OFFTRIAL)")
parser.add_argument(
    "-C", "--cluster-window", action="store", type=float, default=0.1,
    help="The cluster window used "
         "to cluster triggers in time.")
parser.add_argument(
    "--bank-file", action="store", type=str, required=True,
    help="Location of the full template bank used.")

# Option groups shared with other PyGRB executables
ppu.pygrb_add_injmc_opts(parser)
ppu.pygrb_add_bestnr_cut_opt(parser)
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# The template bank is deliberately not opened here: it is read inside a
# context manager at the point where the template masses are needed, so no
# HDF5 handle is left open (and unused) for the rest of the run.

# Store options used multiple times in local variables
outdir = os.path.split(os.path.abspath(opts.background_output_file))[0]
trig_file = opts.trig_file
onsource_file = opts.onsource_file
found_missed_file = opts.found_missed_file
inj_set_name = opts.injection_set_name
wf_err = opts.waveform_error
# Per-detector calibration amplitude errors
cal_errs = {
    'G1': opts.g1_cal_error,
    'H1': opts.h1_cal_error,
    'K1': opts.k1_cal_error,
    'L1': opts.l1_cal_error,
    'V1': opts.v1_cal_error,
}
# Per-detector DC calibration errors
cal_dc_errs = {
    'G1': opts.g1_dc_cal_error,
    'H1': opts.h1_dc_cal_error,
    'K1': opts.k1_dc_cal_error,
    'L1': opts.l1_dc_cal_error,
    'V1': opts.v1_dc_cal_error,
}
# pycbc_multi_inspiral already applies sngl, coinc and null SNR cuts
new_snr_thresh = opts.newsnr_threshold
upper_dist = opts.upper_inj_dist
lower_dist = opts.lower_inj_dist
num_bins = opts.num_bins
# Same option value as wf_err; this is the name used further down
wav_err = opts.waveform_error
cluster_window = opts.cluster_window
glitch_check_fac = opts.glitch_check_factor
num_mc_injs = opts.num_mc_injections
# Initialize random number generator
np.random.seed(opts.seed)
logging.info("Setting random seed to %d.", opts.seed)

# Set output directory (exist_ok avoids the check-then-create race)
logging.info("Setting output directory.")
os.makedirs(outdir, exist_ok=True)

# Determine the interferometers involved and the vetoes to apply
ifos, vetoes = ppu.extract_ifos_and_vetoes(
    trig_file, opts.veto_files, opts.veto_category
)

# Read the off-source triggers (the reweighted SNR cut is applied on load),
# then the time slides and the segment dictionary
logging.info("Loading triggers.")
trigs = ppu.load_triggers(
    trig_file, ifos, vetoes, rw_snr_threshold=opts.newsnr_threshold
)
logging.info("%d offsource triggers surviving reweighted SNR cut.",
             len(trigs['network/event_id']))
logging.info("Loading timeslides.")
slide_dict = ppu.load_time_slides(trig_file)
logging.info("Loading segments.")
segment_dict = ppu.load_segment_dict(trig_file)

# Build the trials of every slide and count them
logging.info("Constructing trials.")
trial_dict = ppu.construct_trials(
    opts.seg_files, segment_dict, ifos, slide_dict, vetoes
)
total_trials = sum(len(trial_dict[slide_id]) for slide_id in slide_dict)
logging.info("%d trials generated.", total_trials)

# Basic trigger properties, stored as dictionaries keyed by slide
trig_time, trig_snr, trig_bestnr = ppu.extract_basic_trig_properties(
    trial_dict, trigs, slide_dict, segment_dict, opts
)

# Loudest BestNR within each trial of each slide; a trial that contains no
# trigger keeps the value zero
time_veto_max_bestnr = {
    slide_id: np.zeros(len(trial_dict[slide_id])) for slide_id in slide_dict
}
for slide_id in slide_dict:
    for j, trial in enumerate(trial_dict[slide_id]):
        trial_cut = (trial[0] <= trig_time[slide_id]) \
            & (trig_time[slide_id] < trial[1])
        if trial_cut.any():
            time_veto_max_bestnr[slide_id][j] = \
                max(trig_bestnr[slide_id][trial_cut])

logging.info("SNR and bestNR maxima calculated.")

# Collect the offsource triggers of all slides as (BestNR, trigger) pairs,
# ordered from loudest to quietest
sorted_trigs = ppu.sort_trigs(trial_dict, trigs, slide_dict, segment_dict)
offsource_trigs = [
    pair
    for slide_id in slide_dict
    for pair in zip(trig_bestnr[slide_id], sorted_trigs[slide_id])
]
offsource_trigs.sort(key=lambda pair: pair[0])
offsource_trigs.reverse()

# ==========================
# Loudest background statistic
# NOTE: no output file is written in this section.
# Only the first and third values returned by max_median_stat are used
# below (presumably the overall maximum BestNR and the per-trial maxima
# gathered over all slides -- confirm against ppu.max_median_stat).
# ==========================
max_bestnr, _, full_time_veto_max_bestnr = ppu.max_median_stat(
    slide_dict, time_veto_max_bestnr, trig_bestnr, total_trials
)

# ==========================
# Template chirp masses, computed from the component masses in the bank
# ==========================
logging.info('Reading template chirp masses')
with HFile(opts.bank_file, 'r') as bank_file:
    bank_mass1 = bank_file['mass1'][:]
    bank_mass2 = bank_file['mass2'][:]
    template_mchirps = mchirp_from_mass1_mass2(bank_mass1, bank_mass2)

# =======================
# Load on source triggers
# =======================
if onsource_file:

    logging.info("Processing onsource.")
    # Read the on-source triggers (the reweighted SNR cut is applied on load)
    on_trigs = ppu.load_triggers(
        onsource_file, ifos, vetoes,
        rw_snr_threshold=opts.newsnr_threshold
    )

    # Chirp mass of the template associated with each on-source trigger
    on_mchirp = template_mchirps[on_trigs['network/template_id']]

    # Loudest on-source event by BestNR; it stays at zero when there are
    # no triggers
    loud_on_bestnr = 0
    if on_trigs and on_trigs['network/reweighted_snr'].size > 0:
        loud_on_bestnr_idx = np.argmax(on_trigs['network/reweighted_snr'])
        loud_on_bestnr = np.max(on_trigs['network/reweighted_snr'])

    # Plain float for output
    loud_on_bestnr = float(loud_on_bestnr)

    # A loudest BestNR of zero means there is no event at all: FAP is 1
    if loud_on_bestnr == 0:
        loud_on_bestnr_idx = None
        loud_on_fap = 1

    logging.info("Onsource analysed.")

    if loud_on_bestnr_idx is not None:
        # FAP of the loudest on-source event: fraction of trials whose
        # loudest trigger is louder than it
        per_slide_max = [time_veto_max_bestnr[slide_id]
                         for slide_id in slide_dict]
        num_trials_louder = sum(sum(slide_max > loud_on_bestnr)
                                for slide_max in per_slide_max)
        tot_off_snr = np.concatenate([np.array([])] + per_slide_max)
        loud_on_fap = num_trials_louder/total_trials

else:
    # No on-source file: the reference is the median over all trials of the
    # loudest BestNR in each trial
    tot_off_snr = np.concatenate(
        [np.array([])]
        + [time_veto_max_bestnr[slide_id] for slide_id in slide_dict]
    )
    med_snr = np.median(tot_off_snr)

# =======================
# Post-process injections
# =======================

# First letter of each interferometer name
sites = [ifo[0] for ifo in ifos]

# injs contains the information about found/missed injections AND triggers.
# Triggers and injections are discarded if at vetoed times and/or below the
# reweighted SNR threshold
injs = ppu.load_triggers(
    found_missed_file, ifos, vetoes,
    rw_snr_threshold=opts.newsnr_threshold
)

logging.info("Missed/found injections/triggers loaded.")

# Chirp mass is not stored in the trigger files: look it up in the bank
found_trig_mchirp = template_mchirps[injs['network/template_id']]


# Classify the injection triggers:
# 1) found louder than the loudest background (zero FAP)
zero_fap = np.zeros(len(injs['network/end_time_gc']), dtype=bool)
zero_fap_cut = injs['network/reweighted_snr'][:] > max_bestnr
zero_fap = zero_fap | zero_fap_cut

# 2) found (bestnr > 0) but not louder than background (non-zero FAP)
nonzero_fap = ~zero_fap & (injs['network/reweighted_snr'] != 0)

# 3) missed after being recovered (i.e., vetoed): not used here

# Non-zero FAP triggers (g_ifar): the statistic is the fraction of trials
# whose loudest trigger is louder than the injection trigger
g_ifar = {'bestnr': injs['network/reweighted_snr'][nonzero_fap]}
louder_trial_counts = [
    (full_time_veto_max_bestnr > inj_bestnr).sum()
    for _mchirp, inj_bestnr in zip(found_trig_mchirp[nonzero_fap],
                                   g_ifar['bestnr'])
]
g_ifar['stat'] = np.array(louder_trial_counts, dtype=float) / total_trials

# Sigma values of each detector.
# If the sigmasqs are not populated, calibration errors can still be built,
# but only in the 1-detector case
inj_sigma = {ifo: injs[f'{ifo}/sigmasq'][:] for ifo in ifos}
for ifo in ifos:
    if (inj_sigma[ifo] == 0).any():
        logging.info("%s: sigmasq not set for at least one trigger.", ifo)
    if not (inj_sigma[ifo] != 0).any():
        logging.info("%s: sigmasq not set for any trigger.", ifo)
        if len(ifos) == 1:
            logging.info("Since this is a single ifo analysis, sigmasq "
                         "will be set to unity for all triggers in order "
                         "to build the calibration errors.")
            inj_sigma[ifo][:] = 1.

# Antenna response of every detector at the position and time of each
# found injection
f_resp = {
    ifo: ppu.get_antenna_responses(Detector(ifo),
                                   injs['found/ra'][:],
                                   injs['found/dec'][:],
                                   injs['found/tc'][:])
    for ifo in ifos
}

# Sigma times antenna response: one row per detector
inj_sigma_mult = (np.asarray(list(inj_sigma.values())) *
                  np.asarray(list(f_resp.values())))

# Sum over detectors; the additions are done in place on the first row
inj_sigma_tot = inj_sigma_mult[0, :]
for detector_row in inj_sigma_mult[1:len(ifos)]:
    inj_sigma_tot += detector_row

# Mean relative weight of each detector
inj_sigma_mean = {
    ifo: ((inj_sigma[ifo]*f_resp[ifo])/inj_sigma_tot).mean() for ifo in ifos
}

logging.info("%d found injections analysed.", len(injs['found/tc']))

# Process missed injections (injs['missed'])
logging.info("%d missed injections analysed.", len(injs['missed/tc']))

# Extra "long distance" injections for the efficiency calculations: as many
# as the found and missed ones together, drawn uniformly between upper_dist
# and upper_dist + (upper_dist - lower_dist)
total_injs = len(injs['found/distance']) + len(injs['missed/distance'])
long_inj = {
    'dist': stats.uniform.rvs(size=total_injs) * (upper_dist-lower_dist)
            + upper_dist
}

logging.info("%d long distance injections created.", total_injs)

# Distance bins as (lower edge, upper edge) pairs; the lower edges run from
# lower_dist up to, but excluding, upper_dist + (upper_dist - lower_dist)
bin_width = (upper_dist-lower_dist)/num_bins
bin_lower_edges = np.arange(lower_dist,
                            upper_dist + (upper_dist-lower_dist),
                            bin_width)
dist_bins = list(zip(bin_lower_edges, bin_lower_edges + bin_width))

# Counters per distance bin, with one extra trailing slot for the total
# over all bins
num_dist_bins_plus_one = len(dist_bins) + 1
num_injections = {}
found_max_bestnr = {}
found_on_bestnr = {}
for key in ['mc', 'no_mc']:
    for counters in (num_injections, found_max_bestnr, found_on_bestnr):
        counters[key] = np.zeros(num_dist_bins_plus_one)

# FAP of every found injection: zero unless it falls in the non-zero FAP set
inj_fap = np.zeros(len(injs['found/distance']))
inj_fap[nonzero_fap] = g_ifar['stat']

# Amplitude calibration error: quadrature sum of the contributions from
# each detector
cal_error = sum(cal_errs[ifo]**2 * inj_sigma_mean[ifo]**2
                for ifo in ifos)**0.5

max_dc_cal_error = max(cal_dc_errs.values())

# Calibration phase uncertainties are neglected
logging.info("Calibration amplitude uncertainty calculated.")

# Now create the numbers for the efficiency plots; these include calibration
# and waveform errors. These are incorporated by running over each injection
# num_mc_injs times, where each time we draw a random value of distance

# Distribute injections
# NOTE: the three sets are filled one after the other (found, missed, long);
# this is an MC, so the order of these calls affects the numbers drawn
found_inj_dist_mc = ppu.mc_cal_wf_errs(
    num_mc_injs, injs['found/distance'],
    cal_error, wav_err, max_dc_cal_error)
missed_inj_dist_mc = ppu.mc_cal_wf_errs(
    num_mc_injs, injs['missed/distance'],
    cal_error, wav_err, max_dc_cal_error)
long_inj['dist_mc'] = ppu.mc_cal_wf_errs(
    num_mc_injs, long_inj['dist'],
    cal_error, wav_err, max_dc_cal_error)

logging.info("MC injection set distributed with %d iterations.",
             num_mc_injs)

# An injection is more significant than the on-source when its FAP does not
# exceed that of the loudest on-source event (0.5 without an on-source file)
more_sig_than_onsource = inj_fap <= (loud_on_fap if onsource_file else 0.5)

distance_count = np.zeros(len(dist_bins))

num_inj_trigs = len(injs['network/event_id'])

# Injection triggers louder than the loudest background
found_trig_max_bestnr = np.full(num_inj_trigs, max_bestnr, dtype=float)
max_bestnr_cut = (injs['network/reweighted_snr'] > found_trig_max_bestnr)

# Injection triggers louder than the on-source reference: the loudest
# on-source BestNR, or the median background value without on-source file
found_trig_loud_on_bestnr = np.full(
    num_inj_trigs, loud_on_bestnr if onsource_file else med_snr, dtype=float)
on_bestnr_cut = injs['network/reweighted_snr'] > found_trig_loud_on_bestnr

# Check whether injection is found for the purposes of exclusion
# distance calculation.
# Found: if louder than all on source
# Missed: if not louder than loudest on source
found_excl = (injs['network/reweighted_snr'] != 0) & \
    more_sig_than_onsource & on_bestnr_cut

# If not missed, double check bestnr against nearby triggers: the injection
# passes when no zero-lag trigger within cluster_window of it, scaled by
# glitch_check_fac, is louder
candidate_times = injs['found/tc'][found_excl]
candidate_bestnrs = injs['network/reweighted_snr'][found_excl]
near_test = np.zeros(found_excl.sum(), dtype=bool)
for j, (inj_time, inj_bestnr) in enumerate(zip(candidate_times,
                                               candidate_bestnrs)):
    # 0 is the zero-lag timeslide
    is_near = np.abs(trig_time[0]-inj_time) < cluster_window
    near_bestnr = trig_bestnr[0][is_near]
    near_test[j] = not (near_bestnr * glitch_check_fac > inj_bestnr).any()

# Apply the local test to the injections that were flagged as found
found_excl[found_excl] = near_test

# Loop over each random instance of the injection set: instance 0 is
# accumulated under 'no_mc', every other instance under 'mc'
for k in range(num_mc_injs+1):
    key = 'mc' if k else 'no_mc'
    found_dists = found_inj_dist_mc[k, :]
    missed_dists = missed_inj_dist_mc[k, :]
    long_dists = long_inj['dist_mc'][k, :]

    # Loop over the distance bins
    for j, dist_bin in enumerate(dist_bins):
        bin_start, bin_end = dist_bin[0], dist_bin[1]

        # Injections of each kind falling in [bin_start, bin_end)
        found_in_bin = (bin_start <= found_dists) & (found_dists < bin_end)
        missed_in_bin = (bin_start <= missed_dists) & (missed_dists < bin_end)
        long_in_bin = (bin_start <= long_dists) & (long_dists < bin_end)

        # All injections in this bin, those louder than the loudest
        # background (zero FAR) and those found for exclusion purposes
        num_pass = found_in_bin.sum() + missed_in_bin.sum() \
            + long_in_bin.sum()
        num_zero_far = (found_in_bin & max_bestnr_cut).sum()
        num_excl = (found_in_bin & found_excl).sum()

        # Update both the per-bin slot and the trailing overall total
        for counters, count in ((num_injections, num_pass),
                                (found_max_bestnr, num_zero_far),
                                (found_on_bestnr, num_excl)):
            counters[key][j] += count
            counters[key][-1] += count

logging.info("Found/missed injection efficiency calculations completed.")

# Post-processing of injections ends here

# ==========
# Make plots
# ==========
# Calculate distances (horizontal axis) as means of the bin edges
dist_plot_vals = [np.asarray(dist_bin).mean() for dist_bin in dist_bins]

# Calculate error bars for efficiency/distance plots and datafiles
# using max BestNR of background, with ('mc') and without ('no_mc')
# Monte-Carlo marginalisation
yerr_low_mc, yerr_high_mc, fraction_mc = efficiency_with_errs(
    found_max_bestnr['mc'], num_injections['mc'], num_mc_injs=num_mc_injs)
yerr_low_no_mc, yerr_high_no_mc, fraction_no_mc = efficiency_with_errs(
    found_max_bestnr['no_mc'], num_injections['no_mc'])

# Calculate the 50% sensitive distance from the unmarginalised efficiency:
#   -1 if the efficiency never drops below 50%,
#    0 if it is already below 50% in the first bin,
#   otherwise a linear interpolation between the first bin below 50% and
#   the bin before it
eff_low = fraction_no_mc
eff_idx = np.where(eff_low < 0.5)[0]
if eff_idx.size == 0:
    sens_dist = -1
    logging.error("Efficiency does not drop below 50%!")
elif eff_idx[0] == 0:
    sens_dist = 0
    logging.error("Efficiency below 50% in first bin!")
else:
    i = eff_idx[0]
    d = dist_plot_vals[i]
    d_low = dist_plot_vals[i-1]
    e = eff_low[i]
    e_low = eff_low[i-1]
    sens_dist = d + (e - 0.5) * (d - d_low) / (e_low - e)

# Plot efficiency using loudest background
fig = plt.figure()
ax = fig.gca()
ax.plot(dist_plot_vals, fraction_no_mc, 'g-',
        label='No marginalisation')
ax.errorbar(dist_plot_vals, fraction_no_mc,
            yerr=[yerr_low_no_mc, yerr_high_no_mc], c='green')
# The marginalised curve is drawn only if it holds something to show
marg_eff = fraction_mc
if np.nansum(marg_eff) > 0:
    ax.plot(dist_plot_vals, marg_eff, 'r-', label='Marginalised')
    ax.errorbar(dist_plot_vals, marg_eff, yerr=[yerr_low_mc, yerr_high_mc],
                c='red')
ax.legend()
ax.grid()
ax.set_ylim([0, 1])
ax.set_xlim(0, 2.*upper_dist - lower_dist)
ax.set_ylabel("Fraction of injections found louder than loudest background")
ax.set_xlabel("Distance (Mpc)")
plot_title = "Detection efficiency - "+inj_set_name
plot_caption = "Injection recovery efficiency using "
plot_caption += "BestNR as detection statistic.  "
plot_caption += "Injections louder than loudest background trigger."
fig_path = opts.background_output_file
save_fig_with_metadata(fig, fig_path, cmd=' '.join(sys.argv),
                       title=plot_title, caption=plot_caption)
plt.close()

# Calculate error bars for efficiency/distance plots and datafiles
# using max BestNR of foreground
yerr_low_no_mc, yerr_high_no_mc, fraction_no_mc = efficiency_with_errs(
    found_on_bestnr['no_mc'], num_injections['no_mc'])
yerr_low, yerr_high, fraction_mc = efficiency_with_errs(
    found_on_bestnr['mc'], num_injections['mc'], num_mc_injs=num_mc_injs)

# Marginalized efficiency lowered by the counting error
# (isf = inverse survival function)
red_efficiency = fraction_mc - yerr_low * scipy.stats.norm.isf(0.1)

# Calculate the 50% and 90% exclusion distances.
# NOTE(review): excl_dist is overwritten at every iteration, so only the 90%
# value survives the loop; the '50%' entry written to the JSON file below is
# the sensitive distance sens_dist -- confirm this is intended
for percentile in [50, 90]:
    threshold = percentile / 100.
    # Use the marginalised efficiency; fall back to the unmarginalised one
    # if the former never drops below the threshold
    excl_efficiency = red_efficiency
    eff_idx = np.where(excl_efficiency < threshold)[0]
    if eff_idx.size == 0:
        excl_efficiency = fraction_no_mc
        eff_idx = np.where(excl_efficiency < threshold)[0]
    if eff_idx.size and eff_idx[0] != 0:
        # Linear interpolation between the first bin below the threshold
        # and the bin before it
        i = eff_idx[0]
        d = dist_plot_vals[i]
        d_low = dist_plot_vals[i-1]
        e = excl_efficiency[i]
        e_low = excl_efficiency[i-1]
        excl_dist = d + (e - threshold) * (d - d_low) / (e_low - e)
    else:
        # Warn the user if the exclusion distance cannot be established,
        # but let the workflow continue: the user will see the plot(s) and
        # repeat with more injections
        excl_dist = 0
        if eff_idx.size:
            logging.warning("Efficiency below %d%% in first bin!",
                            percentile)
        else:
            logging.warning("Efficiency does not drop below %d%%!",
                            percentile)

# Write the distances to JSON file
# Also include injection set name and trial name
excl_dist_dict = {
    'inj_set': inj_set_name,
    'trial_name': opts.trial_name,
    '50%': sens_dist,
    '90%': excl_dist,
}
with open(opts.exclusion_dist_output_file, 'w') as excl_dist_file:
    json.dump(excl_dist_dict, excl_dist_file)

# Plot efficiency using loudest foreground
fig = plt.figure()
ax = fig.gca()
ax.plot(dist_plot_vals, fraction_no_mc, 'g-',
        label='No marginalisation')
ax.errorbar(dist_plot_vals, fraction_no_mc,
            yerr=[yerr_low_no_mc, yerr_high_no_mc], c='green')
# The marginalised curves are drawn only if they hold something to show
# (same guard as in the loudest-background plot)
marg_eff = fraction_mc
if np.nansum(marg_eff) > 0:
    ax.plot(dist_plot_vals, marg_eff, 'r-', label='Marginalised')
    ax.errorbar(dist_plot_vals, marg_eff, yerr=[yerr_low, yerr_high], c='red')
if np.nansum(red_efficiency) > 0:
    ax.plot(dist_plot_vals, red_efficiency, 'm-',
            label='Inc. counting errors')
ax.legend()
ax.get_legend().get_frame().set_alpha(0.5)
# A bare grid() call toggles the grid, so it is called exactly once
ax.grid()
ax.set_ylabel("Fraction of injections found louder than "
              "loudest foreground")
ax.set_xlabel("Distance (Mpc)")
# Mark the 90% exclusion distance
ax.plot([excl_dist], [0.9], 'gx')
ax.set_ylim([0, 1])
ax.set_xlim(0, 2.*upper_dist - lower_dist)
plot_title = "Exclusion distance - "+inj_set_name
plot_caption = "Injection recovery efficiency using "
plot_caption += "BestNR as detection statistic.  "
plot_caption += "Injections louder than loudest foreground trigger.\n"
# NOTE(review): "%%" is not an escape inside an f-string, so both percent
# signs end up in the caption -- left unchanged pending confirmation of how
# the caption is consumed downstream
plot_caption += f" 90%% exclusion distance: {excl_dist} Mpc\n"
plot_caption += f" 50%% sensitive distance: {sens_dist} Mpc"
fig_path = opts.onsource_output_file
save_fig_with_metadata(fig, fig_path, cmd=' '.join(sys.argv),
                       title=plot_title, caption=plot_caption)
plt.close()

logging.info("Plots complete.")
