#! /usr/bin/python3

"""\
%(prog)s [-o outfile] <yodafile1>[:<scale1>] <yodafile2>[:<scale2>] ...
  e.g. %(prog)s run1.yoda run2.yoda run3.yoda  (unweighted merging of three runs)
    or %(prog)s run1.yoda:2.0 run2.yoda:3.142  (weighted merging of two runs)

Merge analysis objects from multiple YODA files, combining the statistics of
objects whose names are found in multiple files. May be used either to merge
disjoint collections of data objects, or to combine multiple statistically
independent runs of the same data objects into one high-statistics run. Optional
scaling parameters may be given to rescale the weights of the objects on a
per-file basis before merging.

By default the output is written to stdout since we can't guess what would be
a good automatic filename choice! Use the -o option to provide an output filename.


IMPORTANT!
  This script is not meant to handle all run merging situations or data objects:
  there are limitations to what can be inferred from data objects alone. If you
  need to do something more complex than the common cases handled by this script,
  please write your own script / program to load and process the data objects.


SCATTERS (E.G. HISTOGRAM RATIOS) CAN'T BE MERGED

  Note that 'scatter' data objects, as opposed to histograms, cannot be merged
  by this tool since they do not preserve sufficient statistical
  information. The canonical example of this is a ratio plot: there are
  infinitely many combinations of numerator and denominator which could give the
  same ratio, and the result does not indicate anything about which of those
  infinite inputs is right (or the effects of correlations in the division).

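  If you need to merge Scatter2D objects, you can write your own Python script
  or C++ program using the YODA interface, and apply whatever case-specific
  treatment is appropriate. The --s1d-mode, --s2d-mode and --s3d-mode options
  select the approximate treatment this script applies instead: the default
  'assume_mean' mode averages the point values assuming asymptotic statistics
  and equal run sizes, while 'first' returns the first copy encountered as the
  'merged' output, with no actual merging having been done.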

NORMALIZED, UNNORMALIZED, OR A MIX?

  An important detail in histogram merging is whether a statistical treatment
  for normalized or unnormalized histograms should be used: in the former case
  the normalization scaling must be undone *before* the histograms are added
  together, and then re-applied afterwards. This script examines the ScaledBy
  attribute of each histogram to determine if it has been normalized. We make the
  assumption that if ScaledBy exists (i.e. h.scaleW has been called) then the
  histogram is normalized and we normalize the resulting merged histogram to the
  weighted average of input norms; if there is no ScaledBy, we assume that the
  histogram is not normalized.

  This is not an infallible approach, but we believe it is more robust than
  heuristics to determine whether norms are sufficiently close to be considered equal.
  In complicated situations you will again be better off writing your own
  script or program to do the merging: the merging machinery of this script is
  available directly in the yoda Python module.

See the source of this script (e.g. use 'less `which %(prog)s`') for more discussion.
"""

# MORE NOTES
#
# If all the input histograms with a particular path have ScaledBy attributes,
# indicating that a histogram weight scaling has been applied in producing the
# input histograms, each histogram in that group will first be unscaled by its
# own factor, then merged, and then re-normalized to the weighted average of the
# input normalizations. Otherwise the weights from
# each histogram copy will be directly added together with no attempt to guess an
# appropriate normalization. The normalization guesses (and they are guesses --
# see below) are made *before* application of the per-file scaling arguments.
#
# IMPORTANT: note from the above that this script can't work out what to do
# re. scaling and normalization of output histograms from the input data files
# alone. It may be possible (although unlikely) that input histograms have the
# same normalization but are meant to be added directly. It may also be the case
# (and much more likely) that histograms which should be normalized to a common
# value will not trigger the appropriate treatment due to e.g. statistical
# fluctuations in each run's calculation of a cross-section used in the
# normalization. And anything more complex than a global scaling (e.g. calculation
# of a ratio or asymmetry) cannot be handled at all with a post-hoc scaling
# treatment. The --assume-normalized command line option used to force all
# histograms to be treated as if they were normalized in the input. It is now
# deprecated and does nothing, since normalization is identified from the
# ScaledBy attributes alone.
#
# Please use this script as a template if you need to do something more specific.
#
# NOTE: there are many possible desired behaviours when merging runs, depending on
# the factors above as well as whether the files being merged are of homogeneous
# type, heterogeneous type, or a combination of both. It is tempting, therefore,
# to add a large number of optional command-line parameters to this script, to
# handle these cases. Experience from Rivet 1.x suggests that this is a bad idea:
# if a problem is of programmatic complexity then a command-line interface which
# attempts to solve it in general is doomed to both failure and unusability. Hence
# we will NOT add extra arguments for applying different merging weights or
# strategies based on analysis object path regexes, auto-identifying 'types' of
# run, etc., etc.: if you need to merge data files in such complex ways, please
# use this script as a template around which to write logic that satisfies your
# particular requirements.


import yoda, argparse, sys, math
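## ILLUSTRATIVE SKETCH (not part of the original script, and not called below):
## the arithmetic of the normalized merge described in the notes above, worked
## on plain numbers rather than YODA objects. The function name and the
## (area, scaledby, userscale) tuple layout are hypothetical, chosen only for
## this example; the real merge operates on the histogram objects themselves.
def _sketch_normalized_merge(runs):
    """Return the merged area for runs given as (area, scaledby, userscale) tuples."""
    ## Undo each run's normalization, apply the user scale, and add the raw weights
    rawsum = sum(userscale * area / scaledby for (area, scaledby, userscale) in runs)
    ## Target rescaling factor, mirroring the 'rescale' variable in the merging code
    rescale = 1.0 / sum(userscale / scaledby for (_, scaledby, userscale) in runs)
    return rescale * rawsum
## Under these assumptions the result is the (userscale/ScaledBy)-weighted average
## of the input areas, so runs sharing one area and one ScaledBy keep that area.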

parser = argparse.ArgumentParser(usage=__doc__)
parser.add_argument("INFILES", nargs="+", help="datafile1 datafile2 [...]")
parser.add_argument("-o", "--output", default="-", dest="OUTPUT_FILE", metavar="PATH",
                    help="write output to specified path")
parser.add_argument("--s1d-mode", "--s1dmode", default="assume_mean", dest="S1D_MODE", metavar="MODE",
                    help="choose strategy for combining Scatter1D objects: one of 'first', 'combine', 'assume_mean', 'add'")
parser.add_argument("--s2d-mode", "--s2dmode", default="assume_mean", dest="S2D_MODE", metavar="MODE",
                    help="choose strategy for combining Scatter2D objects: one of 'first', 'combine', 'assume_mean', 'add'")
parser.add_argument("--s3d-mode", "--s3dmode", default="assume_mean", dest="S3D_MODE", metavar="MODE",
                    help="choose strategy for combining Scatter3D objects: one of 'first', 'combine', 'assume_mean', 'add'")
parser.add_argument("--type-mismatch-mode", default="scatter", dest="TYPE_MISMATCH_MODE", metavar="MODE",
                    help="choose strategy for combining objects whose types mismatch: one of 'first', 'scatter'")
parser.add_argument("--add", "--stack", action="store_true", default=False, dest="STACK",
                    help="force simple stacking (also forces all scatter modes to 'add')")
parser.add_argument("--no-veto-empty", action="store_false", default=True, dest="VETO_EMPTY",
                    help="disable the removal of empty (sumW=0) data objects _before_ applying merge heuristics. You probably want the default!")
parser.add_argument("--assume-normalized", action="store_true", default=False, dest="ASSUME_NORMALIZED",
                    help="DEPRECATED, AND DOES NOTHING. This option _used_ to bypass the detection heuristic for unnormalized histograms")
args = parser.parse_args()

## Include scatters in "add" mode
if args.STACK:
    args.S1D_MODE = "add"
    args.S2D_MODE = "add"
    args.S3D_MODE = "add"

## Put the incoming objects into a dict from each path to a list of histos and scalings
analysisobjects_in = {}
for fa in args.INFILES:
    filename, scale = fa, 1.0
    if ":" in fa:
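        try:
            filename, scale = fa.rsplit(":", 1)
            scale = float(scale)
        except ValueError:
            ## Bail out rather than carry on with a truncated filename and a non-numeric scale
            sys.stderr.write("Error processing arg '%s' with file:scale format\n" % fa)
            sys.exit(1)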
    aos = yoda.read(filename)
    for aopath, ao in aos.items():
        ao.setAnnotation("yodamerge_scale", scale)
        analysisobjects_in.setdefault(aopath, []).append(ao)
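## ILLUSTRATIVE SKETCH (not part of the original script, and not called anywhere):
## one possible way to split a "file[:scale]" argument as a reusable helper.
## The function name is hypothetical, and the fallback for a non-numeric suffix
## is an assumption made for this example, not the behaviour of the loop above.
def _sketch_split_file_scale(arg):
    """Return (filename, scale) for an argument of the form file[:scale]."""
    filename, sep, scalestr = arg.rpartition(":")
    if not sep:
        ## No colon at all: the whole argument is the filename
        return arg, 1.0
    try:
        return filename, float(scalestr)
    except ValueError:
        ## Assumption: a non-numeric suffix means the colon belongs to the filename
        return arg, 1.0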


analysisobjects_out = {}
for p, aos in analysisobjects_in.items():

    ## Identify the canonical aotype being handled from the type of the first entry in aos
    aotype = type(aos[0])

    ## Check that types match; if they don't, either convert everything to scatters or just output the first object
    if not all(type(ao) is aotype for ao in aos):
        msg = "WARNING: cannot merge mismatched analysis object types for path %s: " % p
        scatter_fail = False
        if args.TYPE_MISMATCH_MODE == "scatter":
            saos = []
            for ao in aos:
                sao = ao.mkScatter()
                sao.setAnnotation("yodamerge_scale", ao.annotation("yodamerge_scale"))
                saos.append(sao)
            saotype = type(saos[0])
            msg += "converting to %ss" % saotype.__name__
            if all(type(sao) is saotype for sao in saos):
                sys.stderr.write(msg + "\n")
                aos = saos
                aotype = saotype
            else:
                msg += "... failed, "
                scatter_fail = True
        if args.TYPE_MISMATCH_MODE == "first" or scatter_fail:
            sys.stderr.write(msg + "returning first object\n")
            analysisobjects_out[p] = aos[0]
            continue


    ## Remove empty fillable data objects, to avoid gotchas where e.g. histos are normalised and hence
    ## ScaledBy should be set... but isn't because the emptiness blocked rescaling to finite area
    if args.VETO_EMPTY:
        # TODO: Add a Fillable interface/ABC and use that for the type matching
        if aotype in (yoda.Counter, yoda.Histo1D, yoda.Histo2D, yoda.Profile1D, yoda.Profile2D):
            aos_nonzero = [ao for ao in aos if ao.sumW() != 0] #< possible that this doesn't mean no fills :-/
            ## Just output the first histo if they are all empty
            if not aos_nonzero:
                analysisobjects_out[p] = aos[0]
                continue
            ## Reset aos to only contain non-empty ones
            aos = aos_nonzero


    ## Counter, Histo and Profile (i.e. Fillable) merging
    # TODO: Add a Fillable interface/ABC and use that for the type matching
    if aotype in (yoda.Counter, yoda.Histo1D, yoda.Histo2D, yoda.Profile1D, yoda.Profile2D):

        ## Identify a target rescaling factor from the 1/scalefactor-weighted norms of each run
        rescale = None
        if len(aos) > 1 and args.STACK:
            pass # we're in dumb stacking mode
        elif all("ScaledBy" in ao.annotations() for ao in aos):
            try:
                rescale = 1.0 / sum(float(ao.annotation("yodamerge_scale"))/float(ao.annotation("ScaledBy")) for ao in aos)
            except ZeroDivisionError:
                sys.stderr.write("WARNING: Abandoning normalized merge of path %s because ScaledBy attributes are zero.\n" % p)
        elif all("ScaledBy" not in ao.annotations() for ao in aos):
            pass
        else:
            sys.stderr.write("WARNING: Abandoning normalized merge of path %s because some but not all inputs have ScaledBy attributes\n" % p)

        ## Now that the normalization-identifying heuristic is done, apply user scalings and undo the normalization scaling if appropriate
        for ao in aos:
            if rescale:
                ao.scaleW( 1.0/float(ao.annotation("ScaledBy")) )
            ao.scaleW( float(ao.annotation("yodamerge_scale")) )

        ## Make a copy of the (scaled & unnormalized) first object as the basis for the output
        ## and merge for histograms (including weights, normalization, and user scaling)
        ao_out = aos[0].clone()
        ao_out.rmAnnotation("yodamerge_scale")
        for ao in aos[1:]:
            ao_out += ao
        if rescale:
            ao_out.scaleW(rescale)


    ## Merge for Scatters, assuming equal run sizes, and applying user scaling
    else:

        ## Make a copy of the first object as the basis for merging (suitable for all Scatter types)
        ao_out = aos[0].clone()
        ao_out.rmAnnotation("yodamerge_scale")
        ## If there's only one object, there's no need to do any combining
        if len(aos) == 1:
            pass

        elif aotype in (yoda.Scatter1D,yoda.Scatter2D,yoda.Scatter3D):

            ## Retrieve dimensionality of the Scatter*D object
            dim = ao_out.dim
            SND_MODE = getattr(args, "S%dD_MODE" % dim)
            axis = ['','x','y','z']

            ## Use asymptotic mean+stderr convergence statistics
            if SND_MODE in ("assume_mean", "add"):

                msg = "WARNING: Scatter%dD %s merge assumes asymptotic statistics and equal run sizes" % (dim, p)
                if any(float(ao.annotation("yodamerge_scale")) != 1.0 for ao in aos):
                    msg += " (+ user scaling)"
                sys.stderr.write(msg + "\n")
                npoints = len(ao_out.points())
                for i in range(npoints):
                    val_i = scalesum = 0.0
                    ep_i = {} # will hold the values of the multiple error sources
                    em_i = {} # will hold the values of the multiple error sources
                    for ao in aos:
                        scale = float(ao.annotation("yodamerge_scale"))
                        variations = ao.variations()
                        scalesum += scale
                        val_i += scale * ao.point(i).val(dim)
                        for var in variations:
                            if var not in ep_i:
                                ep_i[var] = 0.
                                em_i[var] = 0.
                            ep_i[var] += (scale * ao.point(i).errs(dim,var)[0])**2
                            em_i[var] += (scale * ao.point(i).errs(dim,var)[1])**2
                    for var in ep_i.keys():
                        ep_i[var] = math.sqrt(ep_i[var])
                        em_i[var] = math.sqrt(em_i[var])
                    if SND_MODE == "assume_mean":
                        val_i /= scalesum
                        for var in ep_i.keys():
                            ep_i[var] /= scalesum
                            em_i[var] /= scalesum
                    setattr(ao_out.point(i),'%s' % axis[dim], val_i)
                    for var in ep_i.keys():
                        #setattr(ao_out.point(i),'set%sErrs' % axis[dim].upper(), ((ep_i[var], em_i[var]),var))
                        ao_out.point(i).setErrs(dim , (ep_i[var], em_i[var]),var)

            ## Add more points to the output scatter
            elif SND_MODE == "combine":
                for ao in aos[1:]:
                    ao_out.combineWith(ao)

            ## Just return the first AO unmodified & unmerged
            elif SND_MODE == "first":
                pass

            else:
                raise Exception("Unknown Scatter%dD merging mode: %s" % (dim, SND_MODE))

        ## Other data types (just warn, and write out the first object)
        else:
            sys.stderr.write("WARNING: Analysis object %s of type %s cannot be merged\n" % (p, str(aotype)))

    ## Put the output AO into the output dict
    analysisobjects_out[p] = ao_out
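## ILLUSTRATIVE SKETCH (not part of the original script, and not called anywhere):
## the per-point arithmetic of the "assume_mean" and "add" scatter modes above,
## worked on plain numbers with a single symmetric error per run. The function
## name and argument layout are hypothetical, chosen only for this example.
import math  #< already imported at the top; repeated so the sketch stands alone
def _sketch_combine_point(vals, errs, scales, mode="assume_mean"):
    """Combine one point's values and errors from several runs."""
    scalesum = sum(scales)
    ## Scaled values add linearly, scaled errors add in quadrature
    val = sum(s * v for s, v in zip(scales, vals))
    err = math.sqrt(sum((s * e)**2 for s, e in zip(scales, errs)))
    ## "assume_mean" divides by the summed scales, "add" leaves the plain sum
    if mode == "assume_mean":
        val /= scalesum
        err /= scalesum
    return val, err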

## Write output
yoda.writeYODA(analysisobjects_out, args.OUTPUT_FILE)
