/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.rewrite.HopRewriteRule;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;

/**
 * Hop rewrite that reorders chains of element-wise multiplications ({@code OpOp2.MULT}).
 * <p>
 * Starting from a multiply hop whose dimensions are known, the rule collects every directly nested
 * multiply (the "e-mult" set) and the non-multiply operands at the bottom of the chain (the
 * "leaves", counted as a multiset). If the chain contains at least two multiplies and none of the
 * nested multiplies is consumed by a hop outside the chain, the chain is replaced by a new sub-DAG:
 * <ul>
 *   <li>a leaf that occurs {@code k > 1} times becomes {@code leaf ^ k} (a {@code POW} hop with a
 *       literal exponent);</li>
 *   <li>operands are ordered by {@link #compareByDataType} and grouped into: scalars and column
 *       vectors, row vectors, other matrices, and operands of any other data type;</li>
 *   <li>the result is {@code ((other * matrices) * rowVectors) * scalarsAndColVectors}, where a
 *       missing group is simply skipped.</li>
 * </ul>
 * All newly created hops are marked visited so the ongoing traversal does not process them again.
 */
public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {

    /**
     * Orders the operands of an e-mult chain.
     * <p>
     * The order is determined by these keys, applied in turn:
     * <ol>
     *   <li>Data type: scalar, matrix, tensor, frame, any other type, list.</li>
     *   <li>For two matrices only: column vectors ({@code dim2 == 1}, including 1x1) before row
     *       vectors ({@code dim1 == 1}) before general matrices.</li>
     *   <li>Within each matrix shape class: more non-zeros first.</li>
     *   <li>Finally, ascending hop ID.</li>
     * </ol>
     * The hop-ID tie-break matters because the operands are collected in a {@link TreeSet}, which
     * drops elements that compare equal. Assuming hop IDs are unique, this keeps distinct hops apart.
     */
    private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
        @Override
        public final int compare(Hop o1, Hop o2) {
            int c = Integer.compare(dataTypeRank(o1.getDataType()), dataTypeRank(o2.getDataType()));
            if (c != 0) {
                return c;
            }
            // Equal rank: MATRIX is the only type with rank 1, so either both or neither are matrices.
            if (o1.getDataType() != Types.DataType.MATRIX) {
                return Long.compare(o1.getHopID(), o2.getHopID());
            }
            c = Integer.compare(matrixShapeRank(o1), matrixShapeRank(o2));
            if (c != 0) {
                return c;
            }
            // Descending number of non-zeros, then hop ID as the final tie-breaker.
            c = Long.compare(o2.getNnz(), o1.getNnz());
            if (c != 0) {
                return c;
            }
            return Long.compare(o1.getHopID(), o2.getHopID());
        }
    };

    /**
     * Sort rank of a data type used by {@link #compareByDataType}; every type not listed explicitly
     * shares rank 4 (between FRAME and LIST).
     */
    private static int dataTypeRank(Types.DataType dataType) {
        switch (dataType) {
            case SCALAR:
                return 0;
            case MATRIX:
                return 1;
            case TENSOR:
                return 2;
            case FRAME:
                return 3;
            case LIST:
                return 5;
            default:
                return 4;
        }
    }

    /** Shape rank of a matrix hop: 0 = column vector (incl. 1x1), 1 = row vector, 2 = other. */
    private static int matrixShapeRank(Hop hop) {
        if (hop.getDim2() == 1L) {
            return 0;
        }
        if (hop.getDim1() == 1L) {
            return 1;
        }
        return 2;
    }

    /**
     * Applies the rewrite to every root in place.
     *
     * @param roots the DAG roots; may be {@code null}
     * @param state rewrite status (unused by this rule)
     * @return the same list with each root replaced by its rewritten form, or {@code null} if
     *         {@code roots} was {@code null}
     */
    @Override
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
        if (roots == null) {
            return null;
        }
        for (int i = 0; i < roots.size(); ++i) {
            roots.set(i, rule_RewriteEMult(roots.get(i)));
        }
        return roots;
    }

    /**
     * Applies the rewrite to a single DAG.
     *
     * @param root  the DAG root; may be {@code null}
     * @param state rewrite status (unused by this rule)
     * @return the rewritten root, or {@code null} if {@code root} was {@code null}
     */
    @Override
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
        if (root == null) {
            return null;
        }
        return rule_RewriteEMult(root);
    }

    /** Returns whether {@code hop} is a {@link BinaryOp} performing {@code OpOp2.MULT}. */
    private static boolean isBinaryMult(Hop hop) {
        return hop instanceof BinaryOp && ((BinaryOp) hop).getOp() == Types.OpOp2.MULT;
    }

    /**
     * Visits {@code root} once (guarded by its visit flag) and rewrites it if it heads a rewritable
     * e-mult chain; otherwise descends into its inputs.
     *
     * @return the hop that should take {@code root}'s place in its parent's input list
     */
    private static Hop rule_RewriteEMult(Hop root) {
        if (root.isVisited()) {
            return root;
        }
        root.setVisited();

        if (isBinaryMult(root) && root.dimsKnown()) {
            Set<BinaryOp> emults = new HashSet<>();
            Map<Hop, Integer> leaves = new HashMap<>();
            findEMultsAndLeaves((BinaryOp) root, emults, leaves);

            // Only rewrite chains of >= 2 e-mults whose nested e-mults are not shared with hops
            // outside the chain (otherwise those intermediates would still be needed).
            if (emults.size() >= 2
                && hasOnlyChainParents(emults, root.getInput().get(0))
                && hasOnlyChainParents(emults, root.getInput().get(1))) {
                Hop replacement = constructReplacement(emults, leaves);
                if (LOG.isDebugEnabled()) {
                    LOG.debug(String.format(
                        "Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
                        emults.size(), root.getHopID(), replacement.getHopID()));
                }
                Hop newRoot = HopRewriteUtils.rewireAllParentChildReferences(root, replacement);
                // Leaves were marked visited by constructPower, so descend into their inputs directly.
                for (Hop leaf : leaves.keySet()) {
                    recurseInputs(leaf);
                }
                return newRoot;
            }
        }
        recurseInputs(root);
        return root;
    }

    /** Rewrites each input of {@code parent} and stores the (possibly new) hop back in its slot. */
    private static void recurseInputs(Hop parent) {
        ArrayList<Hop> inputs = parent.getInput();
        for (int i = 0; i < inputs.size(); ++i) {
            inputs.set(i, rule_RewriteEMult(inputs.get(i)));
        }
    }

    /**
     * Builds the reordered replacement for an e-mult chain.
     *
     * @param emults the multiply hops of the chain being replaced
     * @param leaves the chain's non-multiply operands, mapped to their multiplicity
     * @return the root of the new sub-DAG, or {@code null} if {@code leaves} is empty
     */
    private static Hop constructReplacement(Set<BinaryOp> emults, Map<Hop, Integer> leaves) {
        TreeSet<Hop> sorted = new TreeSet<>(compareByDataType);
        for (Map.Entry<Hop, Integer> entry : leaves.entrySet()) {
            Hop leaf = entry.getKey();
            // Drop the leaf's parent links into the e-mults that are being replaced.
            leaf.getParent().removeIf(parent -> parent instanceof BinaryOp && emults.contains(parent));
            sorted.add(constructPower(leaf, entry.getValue()));
        }

        ArrayList<Hop> ordered = new ArrayList<>(sorted);
        int count = ordered.size();
        int pos = 0;

        // Scalars and column vectors; each further operand becomes the LEFT input.
        Hop colVectorsScalars = null;
        for (; pos < count && isScalarOrColVector(ordered.get(pos)); pos++) {
            colVectorsScalars = combineByMult(ordered.get(pos), colVectorsScalars);
        }
        Hop rowVectors = null;
        for (; pos < count && isRowVector(ordered.get(pos)); pos++) {
            rowVectors = combineByMult(rowVectors, ordered.get(pos));
        }
        Hop matrices = null;
        for (; pos < count && ordered.get(pos).getDataType() == Types.DataType.MATRIX; pos++) {
            matrices = combineByMult(matrices, ordered.get(pos));
        }
        // Everything remaining (tensors, frames, lists, other types).
        Hop other = null;
        for (; pos < count; pos++) {
            other = combineByMult(other, ordered.get(pos));
        }

        Hop top = combineByMult(other, matrices);
        top = combineByMult(top, rowVectors);
        return combineByMult(top, colVectorsScalars);
    }

    /** Returns whether {@code hop} is a scalar or a matrix with exactly one column. */
    private static boolean isScalarOrColVector(Hop hop) {
        return hop.getDataType() == Types.DataType.SCALAR
            || hop.getDataType() == Types.DataType.MATRIX && hop.getDim2() == 1L;
    }

    /** Returns whether {@code hop} is a matrix with exactly one row. */
    private static boolean isRowVector(Hop hop) {
        return hop.getDataType() == Types.DataType.MATRIX && hop.getDim1() == 1L;
    }

    /**
     * Returns {@code left * right} as a new, visited e-mult hop. If either side is {@code null},
     * the other side is returned unchanged, so two {@code null}s yield {@code null}.
     */
    private static Hop combineByMult(Hop left, Hop right) {
        if (left == null) {
            return right;
        }
        if (right == null) {
            return left;
        }
        Hop product = HopRewriteUtils.createBinary(left, right, Types.OpOp2.MULT);
        product.setVisited();
        return product;
    }

    /**
     * Marks {@code hop} visited and returns it unchanged for {@code cnt == 1}, or a new visited
     * {@code hop ^ cnt} power hop otherwise.
     *
     * @param cnt multiplicity of {@code hop} in the chain; must be at least 1
     */
    private static Hop constructPower(Hop hop, int cnt) {
        assert (cnt >= 1);
        hop.setVisited();
        if (cnt == 1) {
            return hop;
        }
        BinaryOp pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Types.OpOp2.POW);
        pow.setVisited();
        return pow;
    }

    /**
     * Returns {@code true} if {@code hop} is not an e-mult, or is an e-mult accepted by
     * {@link #checkForeignParent}.
     */
    private static boolean hasOnlyChainParents(Set<BinaryOp> emults, Hop hop) {
        return !isBinaryMult(hop) || checkForeignParent(emults, (BinaryOp) hop);
    }

    /**
     * Returns {@code true} when neither {@code child} nor any e-mult nested below it has a "foreign"
     * parent, i.e. a parent outside {@code emults}.
     * <p>
     * A single parent is accepted without checking, since it is the chain e-mult that led here.
     * (Despite its name, {@code true} means "no foreign parent found".)
     */
    private static boolean checkForeignParent(Set<BinaryOp> emults, BinaryOp child) {
        ArrayList<Hop> parents = child.getParent();
        if (parents.size() > 1) {
            for (Hop parent : parents) {
                if (!(parent instanceof BinaryOp && emults.contains(parent))) {
                    return false;
                }
            }
        }
        ArrayList<Hop> inputs = child.getInput();
        return hasOnlyChainParents(emults, inputs.get(0)) && hasOnlyChainParents(emults, inputs.get(1));
    }

    /**
     * Collects {@code root} and all transitively nested e-mults into {@code emults}, and counts
     * every non-e-mult operand reached into {@code leaves}.
     * <p>
     * Shared sub-DAGs are counted once per path.
     */
    private static void findEMultsAndLeaves(BinaryOp root, Set<BinaryOp> emults, Map<Hop, Integer> leaves) {
        emults.add(root);
        ArrayList<Hop> inputs = root.getInput();
        collectOperand(inputs.get(0), emults, leaves);
        collectOperand(inputs.get(1), emults, leaves);
    }

    /** Recurses into {@code operand} if it is an e-mult, otherwise counts it as a leaf. */
    private static void collectOperand(Hop operand, Set<BinaryOp> emults, Map<Hop, Integer> leaves) {
        if (isBinaryMult(operand)) {
            findEMultsAndLeaves((BinaryOp) operand, emults, leaves);
        } else {
            addMultiset(leaves, operand);
        }
    }

    /** Increments the multiplicity of {@code k} in {@code map}, starting at 1 for a new key. */
    private static <K> void addMultiset(Map<K, Integer> map, K k) {
        map.merge(k, 1, Integer::sum);
    }
}

