#! /usr/bin/env python3

"""\
%(prog)s <yodafile1> [<yodafile2> ...] -r <yodareffile_or_dir> [-r ...]

Rescale YODA histograms by constant factors or to match (partial) normalizations
of histos in a reference file.

Examples:

%(prog)s run1.yoda run2.yoda -r refdir/ -r reffile.yoda
  Rescales histograms in run1 and run2 to their namesakes in the given ref files
  (including all .yoda found in refdir). Writes out to run{1,2}-scaled.yoda

%(prog)s -c '.* 0.123x' foo.yoda
  Rescales all histograms in foo.yoda by the factor 0.123.

This latter example is a demonstration of the general scaling specification
syntax. Any number of such specifications can be given, either with -c/--spec
command-line options or in a file read with the -f/--specfile option.

The scaling specification syntax is a histogram path pattern followed by a
definition of the scaling to be done on that histogram (or histograms). The path
pattern consists of a regular expression, optionally followed by a range
specifier. The best explanation is probably a few examples:

  /path/to/hist      (match all bins in that particular histogram)
  /.*pt              (match all bins in histograms ending in 'pt')
  /myhist@5          (match bins with high edge >= 5 in /myhist)
  /myhist@@10        (match bins with low edge <= 10 in /myhist)
  /myhist@3@10       (match bins with edges within 3-10 in /myhist)
  /myhist#3:10       (match bins with index 3-10 inclusive in /myhist)
  /myhist#5          (match bins with index >= 5 in /myhist)
  /myhist##10        (match bins with index <= 10 in /myhist)
Mixing of bin index and bin position range locators is not currently allowed.

The scaling definition which follows this path pattern can also take several
forms. Again, examples:
  10         (scale the matching histos/bin ranges to normalize to 10)
  2x         (scale the matching histos by a factor of 2)
  2xREF      (scale the matching histos/ranges to 2x the normalization of
              their reference namesake)
  3.14x/some/other/path  (rescale to the normalization of a different ref histo)

This scheme is fairly complex and may evolve slightly as we try to make the
syntax as natural as possible, particularly for the simplest & most common cases,
while retaining the general power evident from the examples above. Please report
bugs and wishes to the YODA authors at yoda@projects.hepforge.org

TODO:
 * x scaling of only the given range?
 * check that ref norm scaling respects the range limits
 * check that ranges can also be given on the RHS path
 * y-scaling of profile histograms (requires scaleY & scaleZ etc. methods)
 * add overflow inclusion in normalization for binned types
"""

## Parse command line args
## The module docstring doubles as the usage message, so it is passed as `usage`.
import argparse
parser = argparse.ArgumentParser(usage=__doc__)
# TODO: how to look up ref histos?
## Positional arguments: the files whose contents get rescaled (one or more required)
parser.add_argument("INFILES", nargs="+",
                    help="YODA file(s) containing the histograms to be rescaled")
parser.add_argument("-r", "--ref", dest="REFS", action="append", default=[],
                    help="file or folder with reference histos")
parser.add_argument("--ref-prefix", dest="REF_PREFIX", metavar="NAME", default="REF",
                    help="treat /NAME/foo as a reference plot for /foo, and don't rescale /NAME histos")
parser.add_argument("-c", "--spec", dest="SPECS", metavar="SPECSTR", action="append", default=[],
                    help="provide a single scaling specification on the command line. Multiple -c options " +
                    "may be given. Specs will be _appended_ to any read from a file with -f/--specfile")
parser.add_argument("--norm-to-area", "--norm-to-integral", dest="NORM_AREA", action="store_true", default=False,
                    help="divide by the integral / area under the curve")
parser.add_argument("--norm-to-sum", dest="NORM_SUM", action="store_true", default=False,
                    help="divide by the sum of weights")
parser.add_argument("-f", "--specfile", dest="SPECFILE", metavar="FILE", default=None,
                    help="specify a file with histogram path patterns (and bin ranges) that are to be normalised")
parser.add_argument("-o", "--output", default=None, dest="OUTPUT_FILE", metavar="PATH",
                    help="write output to specified path")
parser.add_argument("-i", "--in-place", dest="IN_PLACE", default=False, action="store_true",
                    help="overwrite input file(s) rather than making <input>-scaled.yoda[.gz]")
## -q and -v share one destination: 0 = errors only, 1 = default, 2 = debug
parser.add_argument("-q", "--quiet", dest="VERBOSITY", action="store_const", const=0, default=1,
                    help="reduce printouts to errors-only")
parser.add_argument("-v", "--debug", dest="VERBOSITY", action="store_const", const=2, default=1,
                    help="increase printouts to include debug info")
args = parser.parse_args()

## Reject mutually exclusive option combinations before doing any work
import sys
if args.OUTPUT_FILE is not None and args.IN_PLACE:
    sys.stderr.write("ERROR: The --output and --in-place options are incompatible\n")
    sys.exit(1)

if args.NORM_AREA and args.NORM_SUM:
    sys.stderr.write("ERROR: The --norm-to-area/--norm-to-integral and --norm-to-sum options are incompatible\n")
    sys.exit(1)

## Define parser for scale specification strings
from yoda.search import PointMatcher

def parse_specstr(line):
    """Parse one scaling-specification string.

    The string consists of a whitespace-separated histogram path pattern
    followed by an optional scaling definition. The scaling definition is an
    optional number (with a trailing ``x`` marking a multiplicative factor) or
    the keyword ``UNSCALE``, optionally followed by a reference path pattern.
    Anything from a whitespace-preceded ``#`` onwards is treated as a comment.

    Returns a 4-tuple ``(PointMatcher(pathpatt), PointMatcher(refpatt),
    scaleop, scalearg)``, where ``scaleop`` is ``"x"`` (multiply) or ``"="``
    (normalise to), and ``scalearg`` is a float, or the string ``"UNSCALE"``.
    A missing scale argument defaults to the factor ``1x``.

    Raises ValueError if the string holds no path pattern (e.g. it is empty or
    only a comment) or if the numeric part cannot be converted to a float.
    """
    import re

    ## Strip comments (the prepended space lets a leading '#' be recognised too)
    line = " " + line
    if " #" in line:
        line = line[:line.index(" #")]

    ## Split whitespace-separated target and ref parts
    parts = line.split()
    if not parts:
        # Previously this surfaced as a bare IndexError from parts[0]
        raise ValueError("Empty scaling spec string: no histogram path pattern given")
    pathpatt, scalespec = parts[0], " ".join(parts[1:])

    ## Match and extract the spec command structure
    m = re.match(r"([\d\.eE\+\-]+x?|UNSCALE)?(.*)", scalespec)
    if not m:
        raise ValueError("Invalid scaling spec string: '%s'" % scalespec)
    scalearg, refpatt = m.groups()
    if not scalearg:
        scalearg = "1x"

    try:
        if scalearg == "UNSCALE":
            ## The keyword is passed through unconverted as the scale argument
            scaleop = "x"
        elif scalearg.endswith("x"):
            scaleop = "x"
            scalearg = float(scalearg[:-1])
        else:
            scaleop = "="
            scalearg = float(scalearg)
    except ValueError as err:
        raise ValueError("Invalid scaling spec string: '%s'" % scalespec) from err

    return (PointMatcher(pathpatt), PointMatcher(refpatt), scaleop, scalearg)


## Parse spec file and command-line specs
## Specs from the file come first; -c/--spec strings are appended after them.
SPECS = []
if args.SPECFILE:
    with open(args.SPECFILE, "r") as f:
        for line in f:
            line = line.strip()
            ## Skip blank lines and whole-line comments: neither holds a path
            ## pattern, so passing them to the parser would raise an error
            if not line or line.startswith("#"):
                continue
            SPECS.append( parse_specstr(line) )
SPECS += [parse_specstr(s) for s in args.SPECS]


## Read reference histograms
import yoda, os, glob, math

## Expand each -r argument into a list of files: directories contribute every
## "*.yoda*" file they contain, and individual .yoda / .yoda.gz files are taken as-is
reffiles = []
for r in args.REFS:
    if os.path.isdir(r):
        reffiles += glob.glob( os.path.join(r, "*.yoda*") ) #< TODO: Add a yoda.util function for finding files that it can read
    elif r.endswith((".yoda.gz", ".yoda")):
        reffiles.append(r)
    elif args.VERBOSITY > 0:
        ## Tell the user rather than silently dropping an unusable argument
        print("WARNING: Ignoring reference '%s': not a directory or a .yoda[.gz] file" % r)

## Map of (prefix-stripped) path -> reference object; later files override earlier ones
aos_ref = {}
for r in reffiles:
    aos = yoda.read(r)
    for aopath, ao in aos.items():
        ## Use /REF_PREFIX/foo as a ref plot for /foo, if the ref prefix is set
        if args.REF_PREFIX and aopath.startswith("/%s/" % args.REF_PREFIX):
            aopath = aopath.replace("/%s/" % args.REF_PREFIX, "/", 1) # NB. ao.path is unchanged
        aos_ref[aopath] = ao
if args.VERBOSITY > 1:
    print("DEBUG: %d reference histos" % len(aos_ref))


## Loop over input files for rescaling
## For each input file: read all analysis objects, work out a scale factor for
## each one from the first matching spec (if any), apply it, and write the result.
for infile in args.INFILES:
    aos_out = dict()
    aos_in = yoda.read(infile)
    for aopath, ao in aos_in.items():

        ## Default is to write out the unrescaled AO if no match is found
        aos_out[aopath] = ao

        ## Don't rescale /REF_PREFIX objects
        if aopath.startswith("/%s/" % args.REF_PREFIX):
            if args.VERBOSITY > 1:
                print("DEBUG: not rescaling ref object '%s'" % aopath)
            continue


        ## Match specs to MC AO path
        ## Each entry of aospecs is [ref AO or None, ref matcher, mode, factor]
        aospecs = []
        for s in SPECS:
            if s[0].search_path(aopath):
                aoref, (matcher, mode, factor) = None, s[1:]
                # TODO: s[1] -> matcher?
                ## Match to ref AOs if any were given
                if s[1].path is not None:
                    refpaths = []
                    if s[1].path.pattern == "REF":
                        ## "REF" means the reference object with the same path
                        if aopath in aos_ref:
                            refpaths = [aopath]
                    else:
                        refpaths = [p for p in list(aos_ref.keys()) if (s[1] and s[1].search_path(p))]
                    if args.VERBOSITY > 1:
                        print("DEBUG:", aopath, "=>", refpaths)
                    if not refpaths:
                        print("ERROR: No reference histogram '%s' found for '%s'. Not rescaling" % (s[1].path.pattern, aopath))
                        # TODO: should we actually skip it from the output entirely in this case?
                        continue
                    if len(refpaths) > 1:
                        if args.VERBOSITY > 0:
                            print("WARNING: Multiple reference histograms found for '%s'. Using first: '%s'" % (aopath, refpaths[0]))
                    aoref = aos_ref[refpaths[0]]

                aospecs.append( [aoref, s[1], mode, factor] )


        ## Identify and check spec to be used for scaling determination
        if not aospecs:
            aoref, matcher, mode, factor = None, None, "x", 1.0
            if args.VERBOSITY > 1:
                print("DEBUG: No scaling spec found for '%s'. Output is unscaled" % aopath)
        else:
            aoref, matcher, mode, factor = aospecs[0]
            if len(aospecs) > 1 and args.VERBOSITY > 1:
                print("WARNING: Multiple scaling specs found for '%s'. Using first: '%r'" % (aopath, aospecs[0]))


        ## Work out scalefactor
        ## Reset for every object so that a value computed for a previous object
        ## can never leak into this one if no branch below assigns it
        sf = 1.0
        if aoref is None and mode == "x":
            ## Plain multiplicative scaling, optionally divided by area or sum
            if factor != "UNSCALE":
                denom = 1.0
                if ao.dim() > 1 and args.NORM_AREA:
                    # calculate the integral
                    if 'Estimate' in ao.type():
                        denom = ao.auc()
                    elif 'Scatter' in ao.type():
                        # For scatters, the highest dimension is assumed to be the
                        # dependent axis, and the sum of the two error components
                        # at each point corresponds to the "bin width" along that
                        # axis. The integral is then given by the "height" times
                        # the product of the widths:
                        denom = sum(p.val(ao.dim()-1) * math.prod( sum(p.errs(i)) for i in range(ao.dim()-1) ) \
                                                                                  for p in ao.points())
                    else:
                        denom = ao.integral()
                elif ao.dim() > 1 and args.NORM_SUM:
                    if 'Estimate' in ao.type():
                        denom = sum(ao.vals())
                    elif 'Scatter' in ao.type():
                        denom = sum(ao.vals(ao.dim()-1))
                    else:
                        # Object-level sum of weights, mirroring the object-level
                        # ao.integral() used for area normalisation above
                        # NOTE(review): overflow handling follows the sumW() default -- confirm
                        denom = ao.sumW()
                if denom == 0:
                    print("ERROR: Normalisation of '%s' is 0; cannot divide by it. Result will be unscaled" % aopath)
                    sf = 1.0
                else:
                    sf = factor/denom
            elif ao.hasAnnotation("ScaledBy"):
                ## UNSCALE: undo the factor recorded in the ScaledBy annotation
                sf = 1/float(ao.annotation("ScaledBy"))
            else:
                sf = 1.0

        else:

            # TODO: check binning compatibility (between types... hmm, check via equiv scatters?)
            # TODO: need a function on Scatter for this? Or would break symmetry? (including 3D scatters)

            ## Convert types to scatters
            # TODO: depends on mode and types. Histo/Profile have outflows, Scatters do not
            # TODO: provide a mode to only do rescales on Histos, not Profiles and Scatters?
            # TODO: ranges (index or val) from spec file. Need syntax for 2D ranges
            # TODO: assume full integral for now, but ignore overflows
            # TODO: check that the Scatter type attribute matches the non-Scatter type
            ## Note that we are assuming that it is the Scatter y (or z) axis that is to be rescaled
            s = ao.mkScatter()
            sref = None
            if aoref:
                sref = aoref.mkScatter()
            if sref and type(s) is not type(sref):
                if args.VERBOSITY > 0:
                    print("WARNING: Type mismatch between Scatter reps of '%s'. Are input and ref histos of same dimension?" % aopath)
                continue

            def matchpoint(i, p):
                """Return the matcher's verdict on whether point p (index i) is in range."""
                class Pt(object):
                    ## Minimal record with the attributes handed to match_pos
                    def __init__(self, n, xmin, xmax):
                        self.n    = n
                        self.xmin = xmin
                        self.xmax = xmax

                pt = Pt(i, p.xMin(), p.xMax())
                result = matcher.match_pos(pt)
                return result


            ## Work out normalisations
            # TODO: Are these at all appropriate (with width factor?) for profiles? And ratios? Need an "ignore width" mode flag?
            if type(s) is yoda.Scatter1D:
                norm = sum(p.x() for i, p in enumerate(s.points()) if matchpoint(i, p))
                if norm == 0:
                    print("ERROR: Normalisation of given range is 0; cannot rescale this to a finite value. Result will be unscaled")
                    sf = 1.0
                elif mode == "=":
                    sf = factor/norm
                elif mode == "x":
                    refnorm = sum(p.x() for i, p in enumerate(sref.points()) if matchpoint(i, p))
                    sf = factor*refnorm/norm

            elif type(s) is yoda.Scatter2D:
                ## NOTE: sum(errs) only works if there's only one +- pair
                assert all(len(p.xErrs()) == 2 for p in s.points())
                norm = sum(p.y() * sum(p.xErrs()) for i, p in enumerate(s.points()) if matchpoint(i, p))
                if norm == 0:
                    print("ERROR: Normalisation of given range is 0; cannot rescale this to a finite value. Result will be unscaled")
                    sf = 1.0
                elif mode == "=":
                    sf = factor/norm
                elif mode == "x":
                    assert all(len(p.xErrs()) == 2 for p in sref.points())
                    refnorm = sum(p.y() * sum(p.xErrs()) for i, p in enumerate(sref.points()) if matchpoint(i, p))
                    sf = factor*refnorm/norm

            elif type(s) is yoda.Scatter3D:
                if args.VERBOSITY > 0:
                    print("WARNING: Point/bin range restrictions are not yet implemented for 2D scatter / 3D histo types")
                ## The check and the norm must be computed whatever the verbosity
                ## NOTE: sum(errs) only works if there's only one +- pair
                assert all(len(p.xErrs()) == 2 and len(p.yErrs()) == 2 for p in s.points())
                norm = sum(p.z() * sum(p.xErrs()) * sum(p.yErrs()) for i, p in enumerate(s.points()) if matchpoint(i, p))
                if norm == 0:
                    print("ERROR: Normalisation of given range is 0; cannot rescale this to a finite value. Result will be unscaled")
                    sf = 1.0
                elif mode == "=":
                    sf = factor/norm
                elif mode == "x":
                    assert all(len(p.xErrs()) == 2 and len(p.yErrs()) == 2 for p in sref.points())
                    refnorm = sum(p.z() * sum(p.xErrs()) * sum(p.yErrs()) for p in sref.points())
                    sf = factor*refnorm/norm

            elif args.VERBOSITY > 0:
                ## No normalisation rule for this scatter type: sf stays at 1.0
                print("WARNING: No normalisation defined for the Scatter type of '%s'. Result will be unscaled" % aopath)


        ## Rescale
        if args.VERBOSITY > 1:
            print("DEBUG: '%s' rescaled by factor %.3g" % (aopath, sf))
        if type(ao) is yoda.Counter or 'Histo' in ao.type():
            ao.scaleW(sf)
        elif 'Profile' in ao.type():
            # TODO: should this scale y or sumW? A: sumW... only relevant for merging & "norm" is wrong concept
            if args.VERBOSITY > 0:
                print("WARNING:", aopath, "no scaling applied to " + ao.type())
        elif 'Estimate' in ao.type():
            ao.scale(sf)
        elif 'Scatter' in ao.type():
            ## Scale only the last (dependent) axis
            ao.scale(ao.dim()-1, sf)
        elif args.VERBOSITY > 0:
            print("WARNING: Unknown type '%s' encountered. No scaling applied." % ao.type())


        ## Store rescaled result
        aos_out[aopath] = ao

    ## Write out rescaled file, possibly in-place (i.e. replace input -- not the default behaviour!)
    ## NOTE: with -o/--output the same path is used for every input file, so
    ## with several inputs each write replaces the previous one
    if args.OUTPUT_FILE:
        outfile = args.OUTPUT_FILE
    elif args.IN_PLACE:
        outfile = infile
    else:
        ## Default name is <stem>-scaled.yoda[.gz], written to the current directory
        gzext = ""
        stem, ext = os.path.splitext(os.path.basename(infile))
        if ext == ".gz":
            gzext = ".gz"
            stem, _ = os.path.splitext(stem)
        outfile = stem + "-scaled.yoda" + gzext

    yoda.write(aos_out.values(), outfile)
