/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops;

import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.MemoTable;
import org.apache.sysds.hops.MultiThreadedHop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.DnnTransform;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysds.runtime.matrix.data.DnnParameters;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.DnnUtils;

public class DnnOp
extends MultiThreadedHop {
    private static final Log LOG = LogFactory.getLog((String)DnnOp.class.getName());
    private static final boolean INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP = true;
    private static final boolean THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH = true;
    private Types.OpOpDnn op;
    private DnnParameters _cachedParams;

    /**
     * Creates an operator without label, operation type or inputs; every
     * cached convolution parameter starts out as {@code -1}. Used by
     * {@link #clone()}.
     */
    private DnnOp() {
        _cachedParams = new DnnParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, _maxNumThreads);
    }

    public DnnOp(String l, Types.DataType dt, Types.ValueType vt, Types.OpOpDnn o, ArrayList<Hop> inp) {
        super(l, dt, vt);
        this._cachedParams = new DnnParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, this._maxNumThreads);
        this.op = o;
        for (int i = 0; i < inp.size(); ++i) {
            Hop in = inp.get(i);
            this.getInput().add(i, in);
            in.getParent().add(this);
        }
        this.refreshSizeInformation();
    }

    /**
     * Verifies that this operator has at least one input; the failure is
     * reported through {@code HopsException.check}.
     */
    @Override
    public void checkArity() {
        final int numInputs = _input.size();
        HopsException.check(numInputs >= 1, this, "should have at least one input but has %d inputs", numInputs);
    }

    /**
     * @return the DNN operation type of this hop
     */
    public Types.OpOpDnn getOp() {
        return op;
    }

    /**
     * @return the string form of the operation type
     */
    @Override
    public String getOpString() {
        return op.toString();
    }

    /**
     * Tells whether DNN operators may be compiled to Spark.
     *
     * @return always {@code false}; {@link #optFindExecType(boolean)} uses
     *         this to redirect a Spark decision to CP
     */
    private static boolean isEligibleForSpark() {
        return false;
    }

    /**
     * @return the current value of {@code DMLScript.USE_ACCELERATOR}
     */
    @Override
    public boolean isGPUEnabled() {
        return DMLScript.USE_ACCELERATOR;
    }

    /**
     * @return always {@code true}, for every DNN operation type
     */
    @Override
    public boolean isMultiThreadedOpType() {
        return true;
    }

    /**
     * Builds (or returns the already built) lop for this operator.
     *
     * <p>The first group of operations below is accepted on CP and GPU, the
     * second group on GPU only; every other operation type is rejected.
     *
     * @return the lop attached to this hop
     * @throws HopsException if the operation type is not handled here or the
     *                       chosen execution type is not implemented for it
     */
    @Override
    public Lop constructLops() {
        // reuse a previously constructed lop
        if (getLops() != null) {
            return getLops();
        }

        final Types.ExecType et = optFindExecType();
        final ArrayList<Hop> inputs = getInput();

        final boolean execTypeImplemented;
        switch (op) {
            case MAX_POOL:
            case MAX_POOL_BACKWARD:
            case AVG_POOL:
            case AVG_POOL_BACKWARD:
            case CONV2D:
            case CONV2D_BACKWARD_DATA:
            case CONV2D_BACKWARD_FILTER:
            case BIASADD:
            case BIASMULT:
                execTypeImplemented = (et == Types.ExecType.CP || et == Types.ExecType.GPU);
                break;
            case BATCH_NORM2D_TEST:
            case CHANNEL_SUMS:
            case UPDATE_NESTEROV_X:
                execTypeImplemented = (et == Types.ExecType.GPU);
                break;
            default:
                throw new HopsException("Unsupported lops construction for operation type '" + op + "'.");
        }

        if (!execTypeImplemented) {
            throw new HopsException("Unimplemented DnnOp for execution type: " + et.name());
        }

        setLops(constructDnnLops(et, inputs));
        constructAndSetLopsDataFlowProperties();
        return getLops();
    }

    /**
     * Replaces the operation type of this hop. No size information is
     * refreshed by this call.
     *
     * @param op the new DNN operation type
     */
    public void setOp(Types.OpOpDnn op) {
        this.op = op;
    }

    /**
     * Returns how many input hops the current operation type must have;
     * {@link #constructDnnLops} rejects any other count.
     *
     * @return the required number of inputs (13 for every operation type not
     *         listed explicitly)
     */
    private int getNumExpectedInputs() {
        switch (op) {
            case BIASADD:
            case BIASMULT:
                return 2;
            case CHANNEL_SUMS:
                return 3;
            case UPDATE_NESTEROV_X:
                return 4;
            case BATCH_NORM2D_TEST:
                return 6;
            case MAX_POOL_BACKWARD:
            case AVG_POOL_BACKWARD:
            case CONV2D:
            case CONV2D_BACKWARD_DATA:
            case CONV2D_BACKWARD_FILTER:
                return 14;
            default:
                return 13;
        }
    }

    /**
     * Detects a ReLU written as a binary {@code max} with a literal zero on
     * either side.
     *
     * @param input hop to inspect
     * @return the non-zero operand of the {@code max}, or {@code null} if
     *         {@code input} is not such an expression
     */
    private static Hop isInputReLU(Hop input) {
        if (!HopRewriteUtils.isBinary(input, Types.OpOp2.MAX)) {
            return null;
        }
        final Hop left = input.getInput().get(0);
        final Hop right = input.getInput().get(1);
        if (HopRewriteUtils.isLiteralOfValue(left, 0.0)) {
            return right;
        }
        return HopRewriteUtils.isLiteralOfValue(right, 0.0) ? left : null;
    }

    /**
     * @param input hop to inspect
     * @return whether {@code HopRewriteUtils.isDnn} recognizes {@code input}
     *         as a {@code CONV2D} operator
     */
    private static boolean isInputConv2d(Hop input) {
        final Types.OpOpDnn wanted = Types.OpOpDnn.CONV2D;
        return HopRewriteUtils.isDnn(input, wanted);
    }

    /**
     * Checks that two parameter sets agree on stride, padding, R, S and
     * N/C/H/W, and that none of those values is negative.
     *
     * @param param1 first parameter set
     * @param param2 second parameter set
     * @return {@code true} only if every compared field is non-negative and
     *         equal in both sets
     */
    private static boolean isPoolingParametersEqualAndKnown(DnnParameters param1, DnnParameters param2) {
        final boolean sameStride = isEqualAndKnown(param1.stride_h, param2.stride_h)
            && isEqualAndKnown(param1.stride_w, param2.stride_w);
        final boolean samePadding = isEqualAndKnown(param1.pad_h, param2.pad_h)
            && isEqualAndKnown(param1.pad_w, param2.pad_w);
        final boolean sameWindow = isEqualAndKnown(param1.R, param2.R)
            && isEqualAndKnown(param1.S, param2.S);
        final boolean sameInputShape = isEqualAndKnown(param1.N, param2.N)
            && isEqualAndKnown(param1.C, param2.C)
            && isEqualAndKnown(param1.H, param2.H)
            && isEqualAndKnown(param1.W, param2.W);
        return sameStride && samePadding && sameWindow && sameInputShape;
    }

    /**
     * @return {@code true} if, after parsing the inputs, both strides are 1
     *         and both paddings are 0
     */
    public boolean isStride1Pad0() {
        final DnnParameters params = parseInput();
        final boolean unitStride = params.stride_h == 1 && params.stride_w == 1;
        final boolean zeroPadding = params.pad_h == 0 && params.pad_w == 0;
        return unitStride && zeroPadding;
    }

    /**
     * @param val1 first value
     * @param val2 second value
     * @return {@code true} if both values are non-negative and equal
     */
    private static boolean isEqualAndKnown(int val1, int val2) {
        // equal values share their sign, so one sign check is enough
        return val1 == val2 && val1 >= 0;
    }

    /**
     * For a pooling-backward operator, looks among the parents of the first
     * input for the matching forward pooling operator ({@code MAX_POOL} for
     * {@code MAX_POOL_BACKWARD}, {@code AVG_POOL} for
     * {@code AVG_POOL_BACKWARD}) whose cached pooling parameters are known
     * and equal to this operator's.
     *
     * @return the lop of the first such forward operator, or {@code null} if
     *         this is not a pooling-backward operator or none matches
     */
    private Lop getMaxPoolOutputLop() {
        final Types.OpOpDnn forwardType;
        if (op == Types.OpOpDnn.MAX_POOL_BACKWARD) {
            forwardType = Types.OpOpDnn.MAX_POOL;
        } else if (op == Types.OpOpDnn.AVG_POOL_BACKWARD) {
            forwardType = Types.OpOpDnn.AVG_POOL;
        } else {
            return null;
        }

        for (Hop consumer : getInput().get(0).getParent()) {
            if (consumer instanceof DnnOp) {
                final DnnOp candidate = (DnnOp) consumer;
                if (candidate.getOp() == forwardType
                        && isPoolingParametersEqualAndKnown(candidate._cachedParams, _cachedParams)) {
                    return candidate.constructLops();
                }
            }
        }
        return null;
    }

    /**
     * Builds the {@code DnnTransform} lop for this operator, fusing it with a
     * neighbouring operator where the conditions below allow it.
     *
     * <ul>
     *   <li>CP {@code MAX_POOL} / {@code MAX_POOL_BACKWARD} over a ReLU input
     *       becomes {@code RELU_MAX_POOL} / {@code RELU_MAX_POOL_BACKWARD}
     *       reading the ReLU's operand directly.</li>
     *   <li>{@code BIASADD} over a {@code CONV2D} input becomes
     *       {@code CONV2D_BIAS_ADD}, taking the convolution's inputs plus the
     *       bias.</li>
     * </ul>
     *
     * @param et     execution type the lop is built for
     * @param inputs input hops of this operator
     * @return the constructed lop
     * @throws HopsException if {@code inputs} does not have the size required
     *                       by the operation type
     */
    public Lop constructDnnLops(Types.ExecType et, ArrayList<Hop> inputs) {
        if (inputs.size() != getNumExpectedInputs()) {
            throw new HopsException("Incorrect number of inputs for " + op.name());
        }

        final Hop firstInput = inputs.get(0);
        final Hop reluOperand = isInputReLU(firstInput);
        final boolean fusionAllowed = OptimizerUtils.ALLOW_OPERATOR_FUSION;
        final boolean reluFusableOnCp = fusionAllowed && et == Types.ExecType.CP && reluOperand != null;

        Types.OpOpDnn lopOp = op;
        ArrayList<Hop> trailingInputSource = inputs;
        Lop biasLop = null;
        final Lop primaryLop;
        if (reluFusableOnCp && op == Types.OpOpDnn.MAX_POOL) {
            primaryLop = reluOperand.constructLops();
            lopOp = Types.OpOpDnn.RELU_MAX_POOL;
        } else if (reluFusableOnCp && op == Types.OpOpDnn.MAX_POOL_BACKWARD) {
            primaryLop = reluOperand.constructLops();
            lopOp = Types.OpOpDnn.RELU_MAX_POOL_BACKWARD;
        } else if (fusionAllowed && op == Types.OpOpDnn.BIASADD && isInputConv2d(firstInput)) {
            // take over the convolution's inputs; the bias rides along as an extra input
            lopOp = Types.OpOpDnn.CONV2D_BIAS_ADD;
            primaryLop = firstInput.getInput().get(0).constructLops();
            biasLop = inputs.get(1).constructLops();
            trailingInputSource = firstInput.getInput();
        } else {
            primaryLop = firstInput.constructLops();
        }

        double intermediateMemEstimate = computeIntermediateMemEstimate(-1L, -1L, -1L);
        final boolean onGpu = et == Types.ExecType.GPU;
        if (onGpu && getDim1() >= 0L && getDim2() >= 0L) {
            // what remains of the initial GPU budget after the output and the accounted inputs
            double remainingGpuBudget = (double) GPUContextPool.initialGPUMemBudget()
                - getOutputMemEstimate() - firstInput.getOutputMemEstimate();
            if (biasLop != null) {
                remainingGpuBudget -= inputs.get(1).getOutputMemEstimate();
            }
            intermediateMemEstimate = Math.max(intermediateMemEstimate, remainingGpuBudget);
        }

        final Lop forwardPoolOutput = onGpu ? getMaxPoolOutputLop() : null;

        // every input after the first one, in order
        final int numTrailing = trailingInputSource.size() - 1;
        final Lop[] trailingLops = new Lop[numTrailing];
        for (int k = 0; k < numTrailing; ++k) {
            trailingLops[k] = trailingInputSource.get(k + 1).constructLops();
        }

        final DnnTransform transform = new DnnTransform(primaryLop, lopOp, getDataType(), getValueType(), et,
            OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), intermediateMemEstimate);
        setOutputDimensions(transform);
        setLineNumbers(transform);
        primaryLop.addOutput(transform);

        if (biasLop != null) {
            transform.addInput(biasLop);
            biasLop.addOutput(transform);
        }
        for (Lop trailing : trailingLops) {
            transform.addInput(trailing);
            trailing.addOutput(transform);
        }
        if (forwardPoolOutput != null) {
            transform.addInput(forwardPoolOutput);
            forwardPoolOutput.addOutput(transform);
        }

        transform.updateLopProperties();
        return transform;
    }

    /**
     * Estimates the output size. The output is assumed dense (sparsity 1.0)
     * except for {@code BIASMULT} without accelerator, where the sparsity of
     * the first input is used.
     *
     * @param dim1 number of output rows
     * @param dim2 number of output columns
     * @param nnz  not used by this estimate
     * @return the result of {@code OptimizerUtils.estimateSizeExactSparsity}
     */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        final double sparsity;
        if (getOp() == Types.OpOpDnn.BIASMULT && !DMLScript.USE_ACCELERATOR) {
            sparsity = getInput().get(0).getSparsity();
        } else {
            sparsity = 1.0;
        }
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
    }

    /**
     * Combines the CP and GPU intermediate sizes into one estimate.
     *
     * <p>Without accelerator only the CP intermediates count, multiplied by
     * the number of workers (constrained thread count, capped by
     * {@code max(N, 1)}). With accelerator the result is the guarded maximum
     * of the GPU estimate and the CP estimate, where the CP estimate drops to
     * its single-worker value if that is what it takes to fit under the GPU
     * estimate.
     *
     * @param gpuIntermediates intermediates of the GPU operator
     * @param cpIntermediates  intermediates of the CP operator
     * @return the combined estimate
     */
    private double computeIntermediateMemEstimateHelper(ArrayList<IntermediateDimensions> gpuIntermediates, ArrayList<IntermediateDimensions> cpIntermediates) {
        final long batchCap = Math.max(getDim("N"), 1L);
        final int numWorkers = (int) Math.min((long) OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), batchCap);

        if (!DMLScript.USE_ACCELERATOR) {
            return IntermediateDimensions.addEstimateSizes(cpIntermediates, numWorkers);
        }

        final double gpuMemBudget = IntermediateDimensions.addEstimateSizes(gpuIntermediates, 1);
        double cpMemoryBudget = IntermediateDimensions.addEstimateSizes(cpIntermediates, numWorkers);
        if (cpMemoryBudget > gpuMemBudget) {
            final double singleWorkerBudget = IntermediateDimensions.addEstimateSizes(cpIntermediates, 1);
            if (singleWorkerBudget <= gpuMemBudget) {
                cpMemoryBudget = singleWorkerBudget;
            }
        }
        return IntermediateDimensions.guardedMax(cpMemoryBudget, gpuMemBudget);
    }

    /**
     * Estimates the memory for intermediates of the convolution and pooling
     * operations; all three arguments are ignored.
     *
     * @param ignoreDim1 unused
     * @param ignoreDim2 unused
     * @param ignoreNnz  unused
     * @return the combined CP/GPU estimate, or {@code 0.0} for operation types
     *         that register no intermediates here
     */
    @Override
    protected double computeIntermediateMemEstimate(long ignoreDim1, long ignoreDim2, long ignoreNnz) {
        final ArrayList<IntermediateDimensions> gpu = new ArrayList<IntermediateDimensions>();
        final ArrayList<IntermediateDimensions> cp = new ArrayList<IntermediateDimensions>();
        final Types.OpOpDnn type = getOp();

        if (type == Types.OpOpDnn.CONV2D) {
            gpu.add(new IntermediateDimensions(this, 1, "CHW"));
            gpu.add(new IntermediateDimensions(this, "K", "CRS"));
            cp.add(new IntermediateDimensions(this, "CRS", "PQ", getInput().get(0).getSparsity()));
        } else if (type == Types.OpOpDnn.CONV2D_BACKWARD_DATA) {
            gpu.add(new IntermediateDimensions(this, 1, "KPQ"));
            gpu.add(new IntermediateDimensions(this, "K", "CRS"));
            cp.add(new IntermediateDimensions(this, "PQ", "K", getInput().get(1).getSparsity()));
            cp.add(new IntermediateDimensions(this, "PQ", "CRS"));
        } else if (type == Types.OpOpDnn.CONV2D_BACKWARD_FILTER) {
            gpu.add(new IntermediateDimensions(this, 1, "CHW"));
            gpu.add(new IntermediateDimensions(this, 1, "KPQ"));
            cp.add(new IntermediateDimensions(this, "PQ", "K", getInput().get(1).getSparsity()));
            cp.add(new IntermediateDimensions(this, "CRS", "PQ", getInput().get(0).getSparsity()));
        } else if (type == Types.OpOpDnn.MAX_POOL || type == Types.OpOpDnn.AVG_POOL) {
            gpu.add(new IntermediateDimensions(this, 1, "CHW"));
        } else if (type == Types.OpOpDnn.MAX_POOL_BACKWARD || type == Types.OpOpDnn.AVG_POOL_BACKWARD) {
            gpu.add(new IntermediateDimensions(this, 1, "CHW"));
            gpu.add(new IntermediateDimensions(this, 1, "CPQ"));
        }

        if (gpu.isEmpty() && cp.isEmpty()) {
            return 0.0;
        }
        return computeIntermediateMemEstimateHelper(gpu, cp);
    }

    /**
     * Infers the output dimensions.
     *
     * <ul>
     *   <li>{@code BIASADD}, {@code BIASMULT}, {@code BATCH_NORM2D_TEST} and
     *       {@code UPDATE_NESTEROV_X}: the dimensions of the first input as
     *       recorded in {@code memo}.</li>
     *   <li>{@code CHANNEL_SUMS}: the value computed from the second input by
     *       one column; returned even if that value is unknown.</li>
     *   <li>everything else: the dimensions produced by
     *       {@link #refreshSizeInformation()}.</li>
     * </ul>
     *
     * @param memo memo table holding the statistics of the inputs
     * @return the inferred characteristics, or {@code null} where the
     *         dimensions are not fully known (except for
     *         {@code CHANNEL_SUMS}, see above)
     */
    @Override
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
        final boolean shapeOfFirstInput = op == Types.OpOpDnn.BIASADD
            || op == Types.OpOpDnn.BIASMULT
            || op == Types.OpOpDnn.BATCH_NORM2D_TEST
            || op == Types.OpOpDnn.UPDATE_NESTEROV_X;
        if (shapeOfFirstInput) {
            final DataCharacteristics first = memo.getAllInputStats(getInput())[0];
            final long rows = first.rowsKnown() ? first.getRows() : -1L;
            final long cols = first.colsKnown() ? first.getCols() : -1L;
            final DataCharacteristics out = new MatrixCharacteristics(rows, cols, -1, -1L);
            return out.dimsKnown() ? out : null;
        }

        if (op == Types.OpOpDnn.CHANNEL_SUMS) {
            final long numChannels = Hop.computeSizeInformation(getInput().get(1));
            return new MatrixCharacteristics(numChannels, 1L, -1, -1L);
        }

        refreshSizeInformation();
        final DataCharacteristics refreshed = _dc;
        if (refreshed.dimsKnown()) {
            return refreshed;
        }
        return null;
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    /**
     * Chooses the execution type of this operator.
     *
     * <p>A forced execution type wins. Otherwise the type comes from the
     * memory estimate when the optimizer level is memory based, and defaults
     * to Spark when it is not. A Spark decision is then turned into CP
     * because {@link #isEligibleForSpark()} reports {@code false}.
     *
     * @param transitive not used by this implementation
     * @return the selected execution type, also stored in {@code _etype}
     */
    @Override
    protected Types.ExecType optFindExecType(boolean transitive) {
        checkAndSetForcedPlatform();

        if (_etypeForced == null) {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                _etype = findExecTypeByMemEstimate();
            } else {
                _etype = Types.ExecType.SPARK;
            }
            checkAndSetInvalidCPDimsAndSize();
        } else {
            _etype = _etypeForced;
        }

        // redirect Spark to CP while Spark is not eligible
        if (!isEligibleForSpark() && _etype == Types.ExecType.SPARK) {
            _etype = Types.ExecType.CP;
        }

        setRequiresRecompileIfNecessary();
        return _etype;
    }

    /**
     * Fills the still-unknown entries of the cached convolution parameters
     * from the scalar input hops and derives what can be derived.
     *
     * <p>Steps:
     * <ol>
     *   <li>Read the parameter hops. Operations with 14 inputs have their
     *       scalar parameters one position later than the 13-input pooling
     *       operations.</li>
     *   <li>For pooling and {@code CONV2D}, try to infer unknown C/H/W from the
     *       producing operator.</li>
     *   <li>If R is unknown but H is known, and the script passed the very
     *       same hop for both, take R from H.</li>
     *   <li>Compute P and Q via {@code DnnUtils.getP/getQ} once all their
     *       operands are known.</li>
     * </ol>
     *
     * <p>Step 3 used to copy H into R unconditionally. An unknown filter
     * height was thereby silently assumed to equal the image height, and P was
     * then computed from that guess. The substitution is now limited to the
     * case where both arguments are the same hop.
     *
     * @return the cached parameter object (same instance on every call until
     *         {@link #refreshSizeInformation()} replaces it)
     */
    DnnParameters parseInput() {
        final ArrayList<Hop> in = getInput();
        final boolean has14Inputs = op == Types.OpOpDnn.MAX_POOL_BACKWARD
            || op == Types.OpOpDnn.AVG_POOL_BACKWARD
            || op == Types.OpOpDnn.CONV2D
            || op == Types.OpOpDnn.CONV2D_BACKWARD_FILTER
            || op == Types.OpOpDnn.CONV2D_BACKWARD_DATA;
        final int shift = has14Inputs ? 1 : 0;

        // NOTE(review): assumes setIfUnknown takes
        // (N, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, numThreads),
        // i.e. the 3rd argument is H and the 6th is R - confirm against DnnParameters.
        final Hop imageHeightHop = in.get(7 + shift);
        final Hop filterHeightHop = in.get(11 + shift);
        _cachedParams.setIfUnknown(
            in.get(5 + shift), in.get(6 + shift), imageHeightHop, in.get(8 + shift),
            in.get(9 + shift), filterHeightHop, in.get(12 + shift),
            in.get(1 + shift), in.get(2 + shift), in.get(3 + shift), in.get(4 + shift),
            _maxNumThreads);

        final boolean isPool = getOp() == Types.OpOpDnn.MAX_POOL || getOp() == Types.OpOpDnn.AVG_POOL;
        final boolean isConv = getOp() == Types.OpOpDnn.CONV2D;
        final boolean unknownCHWPQ = _cachedParams.C < 0 || _cachedParams.H < 0 || _cachedParams.W < 0
            || _cachedParams.P < 0 || _cachedParams.Q < 0;
        if ((isPool || isConv) && unknownCHWPQ) {
            inferCHWPQFromParentOp();
        }

        // R may only be taken from H when both were given as the same hop;
        // otherwise leave R unknown instead of guessing.
        if (imageHeightHop == filterHeightHop && _cachedParams.R < 0 && _cachedParams.H > 0) {
            _cachedParams.R = _cachedParams.H;
        }

        if (_cachedParams.P < 0 && _cachedParams.H >= 0 && _cachedParams.R >= 0
                && _cachedParams.stride_h >= 0 && _cachedParams.pad_h >= 0) {
            _cachedParams.P = (int) DnnUtils.getP(_cachedParams.H, _cachedParams.R, _cachedParams.stride_h, _cachedParams.pad_h);
        }
        if (_cachedParams.Q < 0 && _cachedParams.W >= 0 && _cachedParams.S >= 0
                && _cachedParams.stride_w >= 0 && _cachedParams.pad_w >= 0) {
            _cachedParams.Q = (int) DnnUtils.getQ(_cachedParams.W, _cachedParams.S, _cachedParams.stride_w, _cachedParams.pad_w);
        }
        return _cachedParams;
    }

    /**
     * @param hop hop to inspect
     * @return whether {@code HopRewriteUtils.isDnn} recognizes {@code hop} as
     *         a {@code BIASADD} operator
     */
    private static boolean isInputBiasAdd(Hop hop) {
        final Types.OpOpDnn wanted = Types.OpOpDnn.BIASADD;
        return HopRewriteUtils.isDnn(hop, wanted);
    }

    /**
     * Fails if both values are known (non-negative) and differ.
     *
     * @param dim1      first value
     * @param dim2      second value
     * @param paramType parameter name used in the error message
     * @throws DMLRuntimeException if both values are non-negative and unequal
     */
    private static void throwExceptionIfNotEqual(int dim1, int dim2, String paramType) {
        final boolean bothKnown = dim1 >= 0 && dim2 >= 0;
        if (!bothKnown || dim1 == dim2) {
            return;
        }
        throw new DMLRuntimeException("Inferred " + paramType + " from parent doesn't match with given " + paramType + ":" + dim1 + " != " + dim2);
    }

    /**
     * Fills unknown C/H/W of the cached parameters from the operator that
     * produces the first input.
     *
     * <p>A {@code BIASADD} directly on the first input and a ReLU (see
     * {@link #isInputReLU(Hop)}) are looked through. If the hop found that way
     * is a pooling operator, its [C, P, Q] become this operator's [C, H, W];
     * if it is a {@code CONV2D}, its [K, P, Q] do. Any other producer leaves
     * the parameters untouched.
     */
    private void inferCHWPQFromParentOp() {
        Hop producer = getInput().get(0);
        if (isInputBiasAdd(producer)) {
            producer = producer.getInput().get(0);
        }
        final Hop reluOperand = isInputReLU(producer);
        if (reluOperand != null) {
            producer = reluOperand;
        }
        if (!(producer instanceof DnnOp)) {
            return;
        }

        final DnnOp parentOp = (DnnOp) producer;
        final Types.OpOpDnn parentType = parentOp.getOp();
        if (parentType == Types.OpOpDnn.MAX_POOL || parentType == Types.OpOpDnn.AVG_POOL) {
            final DnnParameters parentParam = parentOp.parseInput();
            adoptInferredCHW(parentParam.C, parentParam.P, parentParam.Q, parentType);
        } else if (parentType == Types.OpOpDnn.CONV2D) {
            final DnnParameters parentParam = parentOp.parseInput();
            adoptInferredCHW(parentParam.K, parentParam.P, parentParam.Q, parentType);
        }
    }

    /**
     * Copies the given values into the cached C/H/W wherever those are still
     * unknown, logs the transition at debug level and runs the mismatch
     * checks.
     *
     * @param inferredC  channel count taken from the producer
     * @param inferredH  height taken from the producer
     * @param inferredW  width taken from the producer
     * @param parentType operation type of the producer, used in the log line
     */
    private void adoptInferredCHW(int inferredC, int inferredH, int inferredW, Types.OpOpDnn parentType) {
        final int prevC = _cachedParams.C;
        final int prevH = _cachedParams.H;
        final int prevW = _cachedParams.W;

        if (prevC < 0) {
            _cachedParams.C = inferredC;
        }
        if (prevH < 0) {
            _cachedParams.H = inferredH;
        }
        if (prevW < 0) {
            _cachedParams.W = inferredW;
        }

        if (LOG.isDebugEnabled()) {
            LOG.debug("Inferring [C,H,W] from " + parentType.name() + " parent: [" + prevC + "," + prevH + "," + prevW
                + "]-> [" + _cachedParams.C + "," + _cachedParams.H + "," + _cachedParams.W + "]");
        }

        // NOTE(review): a previously known value is never overwritten above, so
        // these comparisons cannot fail as written. Comparing the previous value
        // with the inferred one would make them effective, but that changes
        // behavior and should be decided separately.
        throwExceptionIfNotEqual(prevC, _cachedParams.C, "C");
        throwExceptionIfNotEqual(prevH, _cachedParams.H, "H");
        throwExceptionIfNotEqual(prevW, _cachedParams.W, "W");
    }

    /**
     * Recomputes the output dimensions of this operator; the number of
     * non-zeros is always reset to {@code -1}.
     *
     * <ul>
     *   <li>{@code BIASADD}, {@code BIASMULT}, {@code BATCH_NORM2D_TEST},
     *       {@code UPDATE_NESTEROV_X}: same dimensions as the first input.</li>
     *   <li>{@code CHANNEL_SUMS}: value computed from the second input by
     *       one column.</li>
     *   <li>pooling and convolution: the cached parameters are discarded and
     *       the dimensions are resolved through {@code getDim}.</li>
     * </ul>
     *
     * @throws RuntimeException for any other operation type
     */
    @Override
    public void refreshSizeInformation() {
        final boolean shapeOfFirstInput = op == Types.OpOpDnn.BIASADD
            || op == Types.OpOpDnn.BIASMULT
            || op == Types.OpOpDnn.BATCH_NORM2D_TEST
            || op == Types.OpOpDnn.UPDATE_NESTEROV_X;
        if (shapeOfFirstInput) {
            final Hop first = getInput().get(0);
            setDim1(first.getDim1());
            setDim2(first.getDim2());
            setNnz(-1L);
            return;
        }

        if (op == Types.OpOpDnn.CHANNEL_SUMS) {
            setDim1(Hop.computeSizeInformation(getInput().get(1)));
            setDim2(1L);
            setNnz(-1L);
            return;
        }

        // start from scratch so stale parameter values cannot leak into the new sizes
        _cachedParams = new DnnParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, _maxNumThreads);

        final String rowsKey;
        final String colsKey;
        switch (op) {
            case MAX_POOL:
            case AVG_POOL:
                rowsKey = "N";
                colsKey = "CPQ";
                break;
            case MAX_POOL_BACKWARD:
            case AVG_POOL_BACKWARD:
            case CONV2D_BACKWARD_DATA:
                rowsKey = "N";
                colsKey = "CHW";
                break;
            case CONV2D:
                rowsKey = "N";
                colsKey = "KPQ";
                break;
            case CONV2D_BACKWARD_FILTER:
                rowsKey = "K";
                colsKey = "CRS";
                break;
            default:
                throw new RuntimeException("The sizes are not refreshed for " + op.name());
        }

        setDim1(getDim(rowsKey));
        setDim2(getDim(colsKey));
        setNnz(-1L);
    }

    /**
     * Creates a copy of this operator: the inherited state is copied through
     * {@code clone(this, false)}, then the operation type and the thread
     * limit are carried over.
     *
     * @return the new {@code DnnOp}
     * @throws CloneNotSupportedException declared by the overridden method
     */
    @Override
    public Object clone() throws CloneNotSupportedException {
        final DnnOp copy = new DnnOp();
        copy.clone(this, false);
        copy.op = op;
        copy._maxNumThreads = _maxNumThreads;
        return copy;
    }

    /**
     * Compares this operator with another hop.
     *
     * @param that hop to compare with
     * @return {@code true} if {@code that} is a {@code DnnOp} with the same
     *         operation type, the same thread limit and, position by
     *         position, the identical input hop instances
     */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof DnnOp)) {
            return false;
        }
        final DnnOp other = (DnnOp) that;

        final boolean sameHeader = op == other.op
            && getInput().size() == that.getInput().size()
            && _maxNumThreads == other._maxNumThreads;
        if (!sameHeader) {
            return false;
        }

        // inputs are compared by reference, not by value
        for (int pos = 0; pos < _input.size(); ++pos) {
            if (getInput().get(pos) != other.getInput().get(pos)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Resolves a named dimension of this convolution/pooling operator.
     *
     * <p>Supported names are {@code K}, {@code CRS}, {@code N}, {@code CHW},
     * {@code KPQ}, {@code CPQ} and {@code PQ}. The value is taken from the
     * cached parameters and, where a matching matrix input exists for the
     * operation type, reconciled with that input's row or column count via
     * {@link #getNonNegative(long, long)}.
     *
     * @param dimString name of the requested dimension
     * @return the dimension, or {@code -1} if it is unknown
     * @throws RuntimeException if called for an operation type without
     *                          convolution parameters, if parsing the inputs
     *                          fails, if the name is not supported, or if
     *                          parameter and input disagree
     */
    private long getDim(String dimString) {
        if (op == Types.OpOpDnn.BIASADD || op == Types.OpOpDnn.BIASMULT || op == Types.OpOpDnn.BATCH_NORM2D_TEST
                || op == Types.OpOpDnn.CHANNEL_SUMS || op == Types.OpOpDnn.UPDATE_NESTEROV_X) {
            throw new RuntimeException("getDim method should not be invoked for " + op.name());
        }
        try {
            parseInput();
        } catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        }

        // which matrix inputs exist depends on the operation type
        final Types.OpOpDnn type = getOp();
        Hop filter = null;
        Hop input = null;
        Hop dout = null;
        Hop dout1 = null;
        if (type == Types.OpOpDnn.CONV2D) {
            input = getInput().get(0);
            filter = getInput().get(1);
        } else if (type == Types.OpOpDnn.CONV2D_BACKWARD_DATA) {
            filter = getInput().get(0);
            dout = getInput().get(1);
        } else if (type == Types.OpOpDnn.CONV2D_BACKWARD_FILTER) {
            input = getInput().get(0);
            dout = getInput().get(1);
        } else if (type == Types.OpOpDnn.MAX_POOL || type == Types.OpOpDnn.AVG_POOL) {
            input = getInput().get(0);
        } else if (type == Types.OpOpDnn.MAX_POOL_BACKWARD || type == Types.OpOpDnn.AVG_POOL_BACKWARD) {
            input = getInput().get(0);
            dout1 = getInput().get(1);
        }

        final DnnParameters p = _cachedParams;
        final long candidate;
        switch (dimString) {
            case "K":
                candidate = (filter != null)
                    ? getNonNegative(p.K, filter.getDim1())
                    : (p.K >= 0 ? (long) p.K : -1L);
                break;
            case "CRS":
                candidate = (filter != null)
                    ? getNonNegative(nonNegativeMultiply(p.C, p.R, p.S), filter.getDim2())
                    : nonNegativeMultiply(p.C, p.R, p.S);
                break;
            case "N":
                // the batch size is the row count of whichever matrix input is available first
                if (input != null) {
                    candidate = getNonNegative(p.N, input.getDim1());
                } else if (dout != null) {
                    candidate = getNonNegative(p.N, dout.getDim1());
                } else if (dout1 != null) {
                    candidate = getNonNegative(p.N, dout1.getDim1());
                } else {
                    candidate = p.N >= 0 ? (long) p.N : -1L;
                }
                break;
            case "CHW":
                candidate = (input != null)
                    ? getNonNegative(nonNegativeMultiply(p.C, p.H, p.W), input.getDim2())
                    : nonNegativeMultiply(p.C, p.H, p.W);
                break;
            case "KPQ":
                candidate = (dout != null)
                    ? getNonNegative(nonNegativeMultiply(p.K, p.P, p.Q), dout.getDim2())
                    : nonNegativeMultiply(p.K, p.P, p.Q);
                break;
            case "CPQ":
                candidate = (dout1 != null)
                    ? getNonNegative(nonNegativeMultiply(p.C, p.P, p.Q), dout1.getDim2())
                    : nonNegativeMultiply(p.C, p.P, p.Q);
                break;
            case "PQ":
                candidate = nonNegativeMultiply(p.P, p.Q);
                break;
            default:
                throw new RuntimeException("Unsupported dimension:" + dimString + " for operator " + getOp().name());
        }

        // maps any negative candidate to -1
        final long ret = getNonNegative(-1L, candidate);

        if (LOG.isDebugEnabled() && ret < 0L) {
            LOG.debug("Unknown dimension " + dimString + " for DnnOp:" + op.name() + " img_dim=[" + p.N + " " + p.C + " " + p.H + " " + p.W + "] filter_dim=[" + p.K + " " + p.C + " " + p.R + " " + p.S + "] output_feature_map=[" + p.P + " " + p.Q + "] stride=[" + p.stride_h + " " + p.stride_w + "] pad=[" + p.pad_h + " " + p.pad_w + "]");
        }
        return ret;
    }

    /**
     * Multiplies three values, treating any negative factor as "unknown".
     *
     * @param val1 first factor
     * @param val2 second factor
     * @param val3 third factor
     * @return the product, or {@code -1} if any factor is negative
     */
    private static long nonNegativeMultiply(long val1, long val2, long val3) {
        if (val1 < 0L || val2 < 0L || val3 < 0L) {
            return -1L;
        }
        return val1 * val2 * val3;
    }

    /**
     * Multiplies two values, treating any negative factor as "unknown".
     *
     * @param val1 first factor
     * @param val2 second factor
     * @return the product, or {@code -1} if any factor is negative
     */
    private static long nonNegativeMultiply(long val1, long val2) {
        if (val1 < 0L || val2 < 0L) {
            return -1L;
        }
        return val1 * val2;
    }

    /**
     * Merges two candidates for the same dimension, where a negative value
     * means "unknown".
     *
     * @param val1 first candidate
     * @param val2 second candidate
     * @return the known value if exactly one is known, the common value if
     *         both are known, {@code -1} if neither is
     * @throws RuntimeException if both are known but differ
     */
    private static long getNonNegative(long val1, long val2) {
        final boolean firstKnown = val1 >= 0L;
        final boolean secondKnown = val2 >= 0L;
        if (firstKnown && secondKnown) {
            if (val1 != val2) {
                throw new RuntimeException("Incorrect dimensions in DnnOp: " + val1 + " != " + val2);
            }
            return val1;
        }
        if (firstKnown) {
            return val1;
        }
        return secondKnown ? val2 : -1L;
    }

    /**
     * Dimensions and sparsity of one intermediate matrix used by the memory
     * estimates of {@link DnnOp}.
     *
     * <p>The dimensions are kept as {@code long}. They come from
     * {@code DnnOp.getDim}, which returns {@code long} products such as
     * C*H*W; narrowing them to {@code int} would silently corrupt any value
     * above {@code Integer.MAX_VALUE} before the size estimate is computed.
     */
    private static class IntermediateDimensions {
        /** Number of rows; negative if unknown. */
        long dim1;
        /** Number of columns; negative if unknown. */
        long dim2;
        /** Sparsity passed to the size estimate. */
        double sp;

        /**
         * @param h       operator whose named dimensions are resolved
         * @param dim1Str name of the row dimension
         * @param dim2Str name of the column dimension
         * @param sp      sparsity of the intermediate
         */
        public IntermediateDimensions(DnnOp h, String dim1Str, String dim2Str, double sp) {
            this.dim1 = h.getDim(dim1Str);
            this.dim2 = h.getDim(dim2Str);
            this.sp = sp;
        }

        /**
         * Dense variant (sparsity 1.0) with both dimensions given by name.
         *
         * @param h       operator whose named dimensions are resolved
         * @param dim1Str name of the row dimension
         * @param dim2Str name of the column dimension
         */
        public IntermediateDimensions(DnnOp h, String dim1Str, String dim2Str) {
            this(h, dim1Str, dim2Str, 1.0);
        }

        /**
         * Dense variant (sparsity 1.0) with a fixed number of rows.
         *
         * @param h       operator whose named dimension is resolved
         * @param dim1    number of rows
         * @param dim2Str name of the column dimension
         */
        public IntermediateDimensions(DnnOp h, int dim1, String dim2Str) {
            this.dim1 = dim1;
            this.dim2 = h.getDim(dim2Str);
            this.sp = 1.0;
        }

        /**
         * Adds two estimates, saturating at {@code OptimizerUtils.DEFAULT_SIZE}.
         *
         * @param val1 first estimate
         * @param val2 second estimate
         * @return the sum, or {@code DEFAULT_SIZE} if either operand is
         *         negative or the sum reaches {@code DEFAULT_SIZE}
         */
        static double guardedAdd(double val1, double val2) {
            if (val1 < 0.0 || val2 < 0.0) {
                return OptimizerUtils.DEFAULT_SIZE;
            }
            final double sum = val1 + val2;
            return (sum >= OptimizerUtils.DEFAULT_SIZE) ? OptimizerUtils.DEFAULT_SIZE : sum;
        }

        /**
         * Sums the size estimates of all intermediates, each multiplied by the
         * number of workers.
         *
         * @param intermediates intermediates to account for
         * @param numWorkers    multiplier applied to every single estimate
         * @return the guarded sum (see {@link #guardedAdd(double, double)})
         */
        public static double addEstimateSizes(ArrayList<IntermediateDimensions> intermediates, int numWorkers) {
            double memBudget = 0.0;
            for (IntermediateDimensions dims : intermediates) {
                memBudget = guardedAdd(memBudget,
                    OptimizerUtils.estimateSizeExactSparsity(dims.dim1, dims.dim2, dims.sp) * (long) numWorkers);
            }
            return memBudget;
        }

        /**
         * Maximum of two estimates, saturating at
         * {@code OptimizerUtils.DEFAULT_SIZE}.
         *
         * @param val1 first estimate
         * @param val2 second estimate
         * @return the larger value, or {@code DEFAULT_SIZE} if either operand
         *         is negative or the maximum reaches {@code DEFAULT_SIZE}
         */
        public static double guardedMax(double val1, double val2) {
            if (val1 < 0.0 || val2 < 0.0) {
                return OptimizerUtils.DEFAULT_SIZE;
            }
            final double larger = Math.max(val1, val2);
            return (larger >= OptimizerUtils.DEFAULT_SIZE) ? OptimizerUtils.DEFAULT_SIZE : larger;
        }
    }
}

