#!/usr/bin/python3.13 -s

"""Supervise the periodic re-fitting of PyCBC Live single-detector triggers,
and the associated plots.
"""

import argparse
import logging
import os
import re
import shutil
import subprocess
from datetime import datetime, timedelta, timezone

import numpy as np
from dateutil.relativedelta import relativedelta

import lal
from lal import gpstime

import pycbc
from pycbc.io.hdf import HFile
from pycbc.live import supervision as sv
from pycbc.types.config import InterpolatingConfigParser as icp

def read_options(args):
    """
    Load the configuration file and return it as nested dictionaries

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments; only ``args.config_file`` is read.

    Returns
    -------
    dict
        Maps each section name to a dictionary of that section's
        option/value pairs. The 'environment' section is removed before
        returning, so it must be present in the config file.
    """
    logging.info("Reading config file")
    parser = icp(configFiles=[args.config_file])
    config_opts = {}
    for section in parser.sections():
        config_opts[section] = dict(parser[section].items())
    # The environment section is not handed on to the individual stages
    del config_opts['environment']
    return config_opts

def trigger_collation(
        day_dt,
        day_str,
        collation_control_options,
        collation_options,
        output_dir,
        controls
    ):
    """
    Collate one day of live triggers with pycbc_live_collate_triggers

    Parameters
    ----------
    day_dt : datetime.datetime
        Start of the day to collate; the collation covers this time to
        one day later.
    day_str : str
        Not used by this function.
    collation_control_options : dict
        Must contain 'collated-triggers-format', a format string with the
        fields ``ifos``, ``start`` and ``duration``.
    collation_options : dict
        Options converted to command-line arguments for the executable.
    output_dir : str
        Directory in which the collated trigger file is written.
    controls : dict
        Control options; 'ifos' is read here, and the dictionary is also
        handed to ``sv.run_and_error``.

    Returns
    -------
    str
        Path of the collated trigger file.
    """
    logging.info("Performing trigger collation")
    command = ['pycbc_live_collate_triggers']
    command.extend(sv.dict_to_args(collation_options))

    # GPS bounds of the day being collated
    next_day_dt = day_dt + timedelta(days=1)
    gps_start = gpstime.utc_to_gps(day_dt).gpsSeconds
    gps_end = gpstime.utc_to_gps(next_day_dt).gpsSeconds

    # The ifo list is written without whitespace in the file name
    ifo_string = ''.join(controls['ifos'].split())
    merged_name = collation_control_options['collated-triggers-format'].format(
        ifos=ifo_string,
        start=gps_start,
        duration=(gps_end - gps_start)
    )
    trig_merge_file = os.path.join(output_dir, merged_name)

    command.extend(['--gps-start-time', format(gps_start, 'd')])
    command.extend(['--gps-end-time', format(gps_end, 'd')])
    command.extend(['--output-file', trig_merge_file])

    sv.run_and_error(command, controls)

    return trig_merge_file


def fit_by_template(
        trigger_merge_file,
        day_str,
        fbt_control_options,
        fbt_options,
        output_dir,
        ifo,
        controls
    ):
    """
    Run pycbc_fit_sngls_by_template on a collated trigger file

    Parameters
    ----------
    trigger_merge_file : str
        Collated trigger file given as --trigger-file.
    day_str : str
        Date string used in the output file name and returned unchanged.
    fbt_control_options : dict
        Must contain 'fit-by-template-format', a format string with the
        fields ``date`` and ``ifo``.
    fbt_options : dict
        Options converted to command-line arguments for the executable.
    output_dir : str
        Directory in which the fit file is written.
    ifo : str
        Detector to fit.
    controls : dict
        Control options handed to ``sv.run_and_error``.

    Returns
    -------
    tuple of (str, str)
        Path of the fit file and ``day_str``.
    """
    logging.info("Performing daily fit_by_template")
    name_format = fbt_control_options['fit-by-template-format']
    fbt_out_full = os.path.join(
        output_dir,
        name_format.format(date=day_str, ifo=ifo)
    )

    command = [
        'pycbc_fit_sngls_by_template',
        '--trigger-file', trigger_merge_file,
    ]
    command.extend(sv.dict_to_args(fbt_options))
    command.extend(['--output', fbt_out_full, '--ifo', ifo])
    sv.run_and_error(command, controls)

    return fbt_out_full, day_str


def find_daily_fit_files(
        combined_control_options,
        daily_fname_format,
        daily_files_dir,
        ifo=None,
        reference_date=None
    ):
    """
    Find the most recent daily files, searching backwards one day at a time

    Files are looked for at
    ``<daily_files_dir>/<YYYY_MM_DD>/<daily_fname_format>``, starting at
    the reference date and stepping back until either 'combined-days'
    files have been found or too many consecutive days have no file.

    Parameters
    ----------
    combined_control_options : dict
        Must contain 'combined-days'. May contain 'maximum_missed_files'
        (default 10). If 'replay-start-time' is present, then
        'true-start-time' and 'replay-duration' are also required and the
        search starts from the original time of the data being replayed.
    daily_fname_format : str
        File name format string with the fields ``date`` and ``ifo``.
    daily_files_dir : str
        Directory containing one sub-directory per day.
    ifo : str, optional
        Detector name substituted into the file name.
    reference_date : datetime.datetime, optional
        Day from which to start searching. If not given, the module-level
        ``day_dt`` set by the script section of this file is used.

    Returns
    -------
    daily_files : list of str
        Paths of the files found, newest first.
    first_date : str
        Date string (YYYY_MM_DD) of the oldest file found.
    end_date : str
        Date string (YYYY_MM_DD) of the newest file found.

    Raises
    ------
    RuntimeError
        If no files are found.
    """
    log_str = (
        f"Finding files in {daily_files_dir} with format {daily_fname_format}"
    )
    if ifo is not None:
        log_str += f" in detector {ifo}"
    logging.info(log_str)

    if reference_date is None:
        # Backward compatible: fall back to the day set at module level
        reference_date = day_dt

    # Config values arrive as strings, so convert before comparing
    combined_days = int(combined_control_options['combined-days'])
    # Maximum consecutive number of days between files before a warning is
    # raised; 10 days of the detector being off would be unusual for
    # current detectors
    max_nmissed = int(combined_control_options.get('maximum_missed_files', 10))

    if 'replay-start-time' in combined_control_options:
        replay_start_time = int(combined_control_options['replay-start-time'])
        true_start_time = int(combined_control_options['true-start-time'])
        replay_duration = int(combined_control_options['replay-duration'])

        # Time since the start of this replay
        elapsed = (
            reference_date - _gps_to_naive_utc(replay_start_time)
        ).total_seconds()
        time_since_replay = float(np.remainder(elapsed, replay_duration))

        # Add this on to the original start time to get the original time
        # of the data being replayed right now
        current_date = _gps_to_naive_utc(true_start_time) + timedelta(
            seconds=time_since_replay
        )
    else:
        current_date = reference_date

    daily_files = []
    first_date = end_date = None
    missed_files = 0
    date_test = current_date
    while len(daily_files) < combined_days and missed_files < max_nmissed:
        date_out = date_test.strftime("%Y_%m_%d")
        # Step back now so that the 'continue' below cannot skip it
        date_test -= timedelta(days=1)

        daily_full = os.path.join(
            daily_files_dir,
            date_out,
            daily_fname_format.format(date=date_out, ifo=ifo),
        )
        if not os.path.exists(daily_full):
            missed_files += 1
            logging.info("File %s does not exist - skipping", daily_full)
            continue

        if not daily_files:
            # The first file found is the newest
            end_date = date_out
        # This is now the oldest file
        first_date = date_out
        # Only consecutive missing days count towards the limit
        missed_files = 0
        daily_files.append(daily_full)

    if not daily_files:
        raise RuntimeError("No files found")

    if missed_files >= max_nmissed:
        # If more than a set maximum days between files, something
        # is wrong with the analysis. Warn about this and use fewer
        # files
        logging.warning(
            'More than %d days between files, only using %d files!',
            max_nmissed,
            len(daily_files)
        )

    return daily_files, first_date, end_date


def _gps_to_naive_utc(gps_seconds):
    """Convert integer GPS seconds to a naive datetime in UTC"""
    # The first six fields of the lal.GPSToUTC result are used as
    # year, month, day, hour, minute, second
    return datetime(*lal.GPSToUTC(gps_seconds)[:6])


def fit_over_multiparam(
        fit_over_controls,
        fit_over_options,
        ifo,
        day_str,
        output_dir,
        controls
    ):
    """
    Smooth the daily template fits with pycbc_fit_sngls_over_multiparam

    Parameters
    ----------
    fit_over_controls : dict
        Must contain 'fit-by-format' and 'fit-over-format', plus the keys
        read by ``find_daily_fit_files``. If 'variable-fit-over-param' is
        present, it is formatted with ``ifo`` and the output is linked
        there with ``sv.symlink``.
    fit_over_options : dict
        Options converted to command-line arguments for the executable.
    ifo : str
        Detector whose fits are smoothed.
    day_str : str
        Not used by this function.
    output_dir : str
        Directory in which the smoothed fit file is written.
    controls : dict
        Control options; 'output-directory' is the directory searched for
        daily fit files.

    Returns
    -------
    tuple of (str, str)
        Path of the smoothed fit file and the '<first>-<last>' date range
        of the daily files used.
    """
    found = find_daily_fit_files(
        fit_over_controls,
        fit_over_controls['fit-by-format'],
        controls['output-directory'],
        ifo=ifo
    )
    daily_files, first_date, end_date = found
    logging.info(
        "Smoothing fits using fit_over_multiparam with %d files and "
        "specified parameters",
        len(daily_files)
    )
    file_id_str = f'{first_date}-{end_date}'
    out_fname = fit_over_controls['fit-over-format'].format(
        dates=file_id_str,
        ifo=ifo,
    )

    command = ['pycbc_fit_sngls_over_multiparam', '--template-fit-file']
    command.extend(daily_files)
    command.extend(sv.dict_to_args(fit_over_options))
    fit_over_full = os.path.join(output_dir, out_fname)
    command.extend(['--output', fit_over_full])
    sv.run_and_error(command, controls)

    if 'variable-fit-over-param' in fit_over_controls:
        link_target = fit_over_controls['variable-fit-over-param'].format(
            ifo=ifo
        )
        sv.symlink(fit_over_full, link_target)

    return fit_over_full, file_id_str

def plot_fits(
        fits_file,
        ifo,
        day_title_str,
        plot_fit_options,
        controls,
        smoothed=False
    ):
    """
    Plot a template fits file with pycbc_plot_bank_corner, and link the
    plot to the public directory if one is configured

    The plot is written next to ``fits_file``, with the file extension
    replaced by '.png'.

    Parameters
    ----------
    fits_file : str
        Path of the fits file to plot.
    ifo : str
        Not used by this function.
    day_title_str : str
        Date description appended to the plot title.
    plot_fit_options : dict
        Options converted to command-line arguments for the executable.
    controls : dict
        Control options; if 'public-dir' is present the plot is linked
        into a dated sub-directory of it.
    smoothed : bool, optional
        If true, ', smoothed' is appended to the plot title.
    """
    # Swap the extension instead of assuming it is three characters long
    fits_plot_output = os.path.splitext(fits_file)[0] + '.png'
    logging.info(
        "Plotting template fits %s to %s",
        fits_file,
        fits_plot_output
    )
    fits_plot_arguments = [
        'pycbc_plot_bank_corner',
        '--fits-file',
        fits_file,
        '--output-plot-file',
        fits_plot_output,
    ]
    fits_plot_arguments += sv.dict_to_args(plot_fit_options)

    title = "Fit parameters for pycbc-live, triggers from " + day_title_str
    if smoothed:
        title += ', smoothed'
    fits_plot_arguments += ['--title', title]
    sv.run_and_error(fits_plot_arguments, controls)

    if 'public-dir' in controls:
        # NOTE: day_str is not a parameter here; it is read from the
        # module-level variable set by the script section of this file
        public_dir = os.path.abspath(os.path.join(
            controls['public-dir'],
            *day_str.split('_')
        ))
        sv.symlink(fits_plot_output, public_dir)


def single_significance_fits(
        daily_controls,
        daily_options,
        output_dir,
        day_str,
        day_dt,
        controls,
        test_options,
        stat_files=None,
    ):
    """
    Run pycbc_live_single_significance_fits for a single day

    Note that ``daily_options`` is modified in place: the 'output',
    'gps-start-time' and 'gps-end-time' entries are set here.

    Parameters
    ----------
    daily_controls : dict
        Must contain 'sig-daily-format', a format string with the field
        ``date``.
    daily_options : dict
        Options converted to command-line arguments for the executable.
    output_dir : str
        Directory in which the fit file is written.
    day_str : str
        Date string used in the output file name.
    day_dt : datetime.datetime
        Start of the day to fit; the fit covers this time to one day
        later.
    controls : dict
        Control options handed to ``sv.run_and_error``.
    test_options : dict
        Not used by this function.
    stat_files : list of str, optional
        If not None, given to the executable as --statistic-files.

    Returns
    -------
    str
        Path of the significance fit file.
    """
    output_name = daily_controls['sig-daily-format'].format(date=day_str)
    daily_options['output'] = os.path.join(output_dir, output_name)
    command = ['pycbc_live_single_significance_fits']

    # GPS bounds of the day being fitted
    next_day_dt = day_dt + timedelta(days=1)
    gps_start_time = gpstime.utc_to_gps(day_dt).gpsSeconds
    gps_end_time = gpstime.utc_to_gps(next_day_dt).gpsSeconds
    daily_options['gps-start-time'] = format(gps_start_time, 'd')
    daily_options['gps-end-time'] = format(gps_end_time, 'd')

    command.extend(sv.dict_to_args(daily_options))
    if stat_files is not None:
        command.extend(['--statistic-files'] + stat_files)

    sv.run_and_error(command, controls)

    return daily_options['output']

def plot_single_significance_fits(daily_output, daily_plot_options, controls):
    """
    Plot daily significance fits, and link to public directory if wanted

    One plot per detector is expected, named after ``daily_output`` with
    the file extension replaced by '_<ifo>.png'.

    Parameters
    ----------
    daily_output : str
        Path of the daily significance fits file to plot.
    daily_plot_options : dict
        Options converted to command-line arguments for the executable.
    controls : dict
        Control options; if 'public-dir' is present, the plot for each
        detector in 'ifos' is linked into a dated sub-directory of it.
    """
    # Swap the extension instead of assuming it is three characters long.
    # The doubled braces leave a literal {ifo} field in the name format.
    plot_stem = os.path.splitext(daily_output)[0]
    daily_plot_output = f'{plot_stem}_{{ifo}}.png'
    logging.info(
        "Plotting daily significance fits from %s to %s",
        daily_output,
        daily_plot_output
    )
    daily_plot_arguments = [
        'pycbc_live_plot_single_significance_fits',
        '--trigger-fits-file',
        daily_output,
        '--output-plot-name-format',
        daily_plot_output,
    ]
    daily_plot_arguments += sv.dict_to_args(daily_plot_options)
    sv.run_and_error(daily_plot_arguments, controls)

    # Link the plots to the public-dir if wanted
    if 'public-dir' in controls:
        daily_plot_outputs = [
            daily_plot_output.format(ifo=ifo)
            for ifo in controls['ifos'].split()
        ]
        logging.info("Linking daily fits plots")
        # NOTE: day_str is not a parameter here; it is read from the
        # module-level variable set by the script section of this file.
        # The directory is the same for every detector, so build it once.
        public_dir = os.path.abspath(os.path.join(
            controls['public-dir'],
            *day_str.split('_')
        ))
        for dpo in daily_plot_outputs:
            sv.symlink(dpo, public_dir)


def combine_significance_fits(
        combined_options,
        combined_controls,
        output_dir,
        day_str,
        controls
    ):
    """
    Combine daily significance fits with
    pycbc_live_combine_single_significance_fits

    Note that ``combined_options`` is modified in place: the 'output' and
    'trfits-files' entries are set here.

    Parameters
    ----------
    combined_options : dict
        Options converted to command-line arguments for the executable.
    combined_controls : dict
        Must contain 'daily-format' and 'outfile-format', plus the keys
        read by ``find_daily_fit_files``. If 'variable-significance-fits'
        is present, the output is linked there with ``sv.symlink``.
    output_dir : str
        Directory in which the combined file is written.
    day_str : str
        Date string available to 'outfile-format' as ``date``.
    controls : dict
        Control options; 'output-directory' is the directory searched for
        daily files.

    Returns
    -------
    tuple of (str, str)
        Path of the combined file and the '<first>-<last>' date range of
        the daily files used.
    """
    found = find_daily_fit_files(
        combined_controls,
        combined_controls['daily-format'],
        controls['output-directory'],
    )
    daily_files, first_date, end_date = found
    logging.info(
        "Smoothing significance fits over %d files",
        len(daily_files)
    )
    date_range = f'{first_date}-{end_date}'
    outfile_name = combined_controls['outfile-format'].format(
        date=day_str,
        date_range=date_range,
    )
    combined_file = os.path.join(output_dir, outfile_name)
    combined_options['output'] = combined_file
    # The daily files are passed as one whitespace-separated option value
    combined_options['trfits-files'] = ' '.join(daily_files)

    command = ['pycbc_live_combine_single_significance_fits']
    command.extend(sv.dict_to_args(combined_options))

    sv.run_and_error(command, controls)

    if 'variable-significance-fits' in combined_controls:
        logging.info("Linking to variable significance fits file")
        sv.symlink(
            combined_options['output'],
            combined_controls['variable-significance-fits']
        )

    return combined_options['output'], date_range

def plot_combined_significance_fits(
        csf_file,
        date_range,
        output_dir,
        combined_plot_options,
        combined_plot_control_options,
        controls
    ):
    """
    Plot combined significance fits, and link to public directory if wanted

    Parameters
    ----------
    csf_file : str
        Path of the combined significance fits file to plot.
    date_range : str
        Substituted for '{date_range}' in the plot name format.
    output_dir : str
        Directory in which the plots are written.
    combined_plot_options : dict
        Options converted to command-line arguments for the executable.
    combined_plot_control_options : dict
        Must contain 'output-plot-name-format', which must include
        '{date_range}'; it is later formatted with ``ifo`` and ``type``.
    controls : dict
        Control options; 'ifos' is read, and if 'public-dir' is present
        the plots are linked into a dated sub-directory of it.

    Raises
    ------
    RuntimeError
        If '{date_range}' is missing from the plot name format.
    """
    name_format = combined_plot_control_options['output-plot-name-format']
    if '{date_range}' not in name_format:
        raise RuntimeError(
            "Must specify {date_range} in output-plot-name-format"
        )
    # Only the date range is filled in here; the {ifo} and {type} fields
    # stay in the format passed to the executable
    oput_full = os.path.join(
        output_dir,
        name_format.replace('{date_range}', date_range)
    )

    command = [
        'pycbc_live_plot_combined_single_significance_fits',
        '--combined-fits-file', csf_file,
        '--output-plot-name-format', oput_full,
    ]
    command.extend(sv.dict_to_args(combined_plot_options))

    sv.run_and_error(command, controls)

    # List the plot files: all the fit_coeffs plots, then all the counts
    combined_plot_outputs = []
    for plot_type in ('fit_coeffs', 'counts'):
        for ifo in controls['ifos'].split():
            combined_plot_outputs.append(
                oput_full.format(ifo=ifo, type=plot_type)
            )

    if 'public-dir' not in controls:
        return

    logging.info("Linking combined fits to public dir")
    # NOTE: day_str is not a parameter here; it is read from the
    # module-level variable set by the script section of this file
    public_dir = os.path.abspath(os.path.join(
        controls['public-dir'],
        *day_str.split('_')
    ))
    for plot_file in combined_plot_outputs:
        sv.symlink(plot_file, public_dir)

def supervise_collation_fits_dq(args, day_dt, day_str):
    """
    Perform the trigger collation and the requested fits and plots for
    one day

    The triggers are always collated. The template fits, their smoothing,
    the daily significance fits and the combined significance fits are
    each run only if the matching command-line flag was given.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments; ``config_file`` and the four stage
        flags are read.
    day_dt : datetime.datetime
        Start of the day to analyse.
    day_str : str
        The same day as a 'YYYY_MM_DD' string, used for directory and
        file names.
    """
    # Read in the config file and pack into appropriate dictionaries.
    # Every section is looked up here, so a missing section fails
    # before any stage has been run.
    config_opts = read_options(args)
    controls = config_opts['control']
    collation_options = config_opts['collation']
    collation_control_options = config_opts['collation_control']
    fit_by_template_options = config_opts['fit_by_template']
    fit_by_template_control_options = config_opts['fit_by_template_control']
    fit_over_options = config_opts['fit_over_multiparam']
    fit_over_control_options = config_opts['fit_over_multiparam_control']
    plot_fit_options = config_opts['plot_fit']
    daily_options = config_opts['significance_daily_fits']
    daily_control_options = config_opts['significance_daily_fits_control']
    daily_plot_options = config_opts['plot_significance_daily']
    combined_options = config_opts['significance_combined_fits']
    combined_control_options = config_opts['significance_combined_fits_control']
    combined_plot_options = config_opts['plot_significance_combined']
    combined_plot_control_options = config_opts['plot_significance_combined_control']
    test_options = config_opts['test']

    # The main output directory will have a date subdirectory which we
    # put the output into
    sv.ensure_directories(controls, day_str)

    ifos = controls['ifos'].split()
    output_dir = os.path.join(
        controls['output-directory'],
        day_str
    )
    logging.info("Outputs to %s", output_dir)
    # The key is hyphenated, matching its use in the plotting functions
    if 'public-dir' in controls:
        public_dir = os.path.abspath(os.path.join(
            controls['public-dir'],
            *day_str.split('_')
        ))
        logging.info("Outputs to be linked to %s", public_dir)

    merged_triggers = trigger_collation(
        day_dt,
        day_str,
        collation_control_options,
        collation_options,
        output_dir,
        controls
    )
    # Store the locations of files needed for the statistic
    stat_files = []
    for ifo in ifos:
        if args.fit_by_template:
            fbt_file, fbt_date_str = fit_by_template(
                merged_triggers,
                day_str,
                fit_by_template_control_options,
                fit_by_template_options,
                output_dir,
                ifo,
                controls,
            )
            plot_fits(
                fbt_file,
                ifo,
                fbt_date_str,
                plot_fit_options,
                controls
            )

        if args.fit_over_multiparam:
            fom_file, fom_dates_str = fit_over_multiparam(
                fit_over_control_options,
                fit_over_options,
                ifo,
                day_str,
                output_dir,
                controls
            )
            stat_files.append(fom_file)
            plot_fits(
                fom_file,
                ifo,
                fom_dates_str,
                plot_fit_options,
                controls,
                smoothed=True,
            )

    if args.single_significance_fits:
        ssf_file = single_significance_fits(
            daily_control_options,
            daily_options,
            output_dir,
            day_str,
            day_dt,
            controls,
            test_options,
            stat_files=stat_files,
        )
        plot_single_significance_fits(
            ssf_file,
            daily_plot_options,
            controls
        )

    if args.combine_significance_fits:
        # Pass the day being analysed; a variable set by the per-ifo
        # stages above may not exist if those stages were not requested
        csf_file, combined_date_range = combine_significance_fits(
            combined_options,
            combined_control_options,
            output_dir,
            day_str,
            controls
        )
        plot_combined_significance_fits(
            csf_file,
            combined_date_range,
            output_dir,
            combined_plot_options,
            combined_plot_control_options,
            controls
        )


def get_yesterday_date():
    """
    Get the date of yesterday's triggers, in UTC

    Returns
    -------
    day_dt : datetime.datetime
        Midnight at the start of yesterday (UTC), without timezone info.
    day_str : str
        The same date formatted as 'YYYY_MM_DD'.
    """
    yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).date()
    # Keep the result timezone-naive, as datetime.utcnow() gave before
    day_dt = datetime.combine(yesterday, datetime.min.time())
    day_str = day_dt.strftime('%Y_%m_%d')
    return day_dt, day_str

# Command-line interface
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument(
    '--config-file',
    required=True
)
parser.add_argument(
    '--date',
    help='Date to analyse, if not given, will analyse yesterday (UTC). '
         'Format YYYY_MM_DD. Do not use if using --run-daily-at.'
)
parser.add_argument(
    '--fit-by-template',
    action='store_true',
    help="Perform template fits calculation."
)
parser.add_argument(
    '--fit-over-multiparam',
    action='store_true',
    help="Perform template fits smoothing."
)
parser.add_argument(
    '--single-significance-fits',
    action='store_true',
    help="Perform daily singles significance fits."
)
parser.add_argument(
    '--combine-significance-fits',
    action='store_true',
    help="Do combination of singles significance fits."
)
parser.add_argument(
    '--run-daily-at',
    metavar='HH:MM:SS',
    help='Stay running and repeat the fitting daily at the given UTC hour.'
)
args = parser.parse_args()

pycbc.init_logging(args.verbose, default_level=1)

if args.run_daily_at is not None and args.date is not None:
    parser.error('Cannot take --run-daily-at and --date at the same time')

# NOTE: day_dt and day_str are deliberately assigned at module level below;
# some of the functions above read them as globals.
if args.run_daily_at is not None:
    # keep running and repeat the fitting every day at the given hour
    # fullmatch so that trailing characters after HH:MM:SS are rejected
    if not re.fullmatch(r'[0-9]{2}:[0-9]{2}:[0-9]{2}', args.run_daily_at):
        parser.error('--run-daily-at takes a UTC time in the format HH:MM:SS')
    logging.info('Starting in daily run mode')
    while True:
        sv.wait_for_utc_time(args.run_daily_at)
        day_dt, day_str = get_yesterday_date()
        logging.info('==== Time to update the single fits, waking up ====')
        supervise_collation_fits_dq(args, day_dt, day_str)
else:
    # run just once
    if args.date:
        day_str = args.date
        day_dt = datetime.strptime(args.date, '%Y_%m_%d')
    else:
        day_dt, day_str = get_yesterday_date()
    supervise_collation_fits_dq(args, day_dt, day_str)
