#!/usr/bin/python3.13 -s
""" Convert inference file to parallel temperred compatible start format
"""

import argparse
import numpy
import h5py
from numpy.random import choice

from pycbc import add_common_pycbc_options, init_logging
from pycbc.inference.io import loadfile
from pycbc.inference.sampler import samplers

# Command-line interface. The input/output paths and the ensemble dimensions
# are mandatory: the conversion below cannot run without them, so let argparse
# report a missing option instead of failing later with an obscure error.
parser = argparse.ArgumentParser(description=__doc__)
add_common_pycbc_options(parser)
parser.add_argument('--input-file', required=True,
    help="Inference file to draw the starting samples from")
parser.add_argument('--output-file', required=True,
    help="Start file to create for the output sampler")
parser.add_argument('--sampler', default='emcee_pt',
    help="The output sampler type, if none, we assume emcee_pt")
parser.add_argument('--ntemps', type=int, required=True,
    help="Number of temperatures in the output file")
parser.add_argument('--nwalkers', type=int, required=True,
    help="Number of walkers per temperature in the output file")
args = parser.parse_args()

# The parsed options live in ``args``; there is no ``opts`` in this script.
init_logging(args.verbose)

# Populate a start file for the requested sampler (emcee_pt by default)
# with values chosen from the input file (e.g. a dynesty run).
# Each temperature and walker gets a distinct random point from the
# input samples.

ntemps = args.ntemps
nwalkers = args.nwalkers
if ntemps is None or nwalkers is None:
    parser.error("--ntemps and --nwalkers must both be given")
nneeded = ntemps * nwalkers

# Context managers guarantee both files are closed even if something below
# raises (the output is closed first, then the input).
with loadfile(args.input_file) as f:
    params = list(f.variable_params) + ['loglikelihood']
    samples = f.read_samples(params)
    nsample = len(samples)

    # Points are drawn without replacement, so the input must hold at least
    # one sample per (temperature, walker) slot. Check before creating the
    # output file so a failure does not leave a half-written file behind.
    if nneeded > nsample:
        raise ValueError(
            "need ntemps * nwalkers = {} samples but {} only contains {}"
            .format(nneeded, args.input_file, nsample))

    # These are the ids we'll use for the temps / walkers
    use = choice(nsample, replace=False, size=nneeded)

    out_type = samplers[args.sampler]._io.name
    with loadfile(args.output_file, 'w', filetype=out_type) as o:
        for k in params:
            data = samples[k][use]
            # One value per temperature and walker; the trailing axis has
            # length 1 (presumably a single iteration -- confirm against the
            # sampler's file layout).
            o['samples/' + k] = data.reshape(ntemps, nwalkers, 1)

        o.attrs['static_params'] = []
        o.attrs['variable_params'] = f.variable_params
        o.create_group('sampler_info')
        o['sampler_info'].attrs['nchains'] = nwalkers
