/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.IContainDefaultTuple;
import org.apache.sysds.runtime.compress.colgroup.IFrameOfReferenceGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.estim.encoding.ConstEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.DenseEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.EmptyEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.compress.estim.encoding.SparseEncoding;
import org.apache.sysds.runtime.compress.lib.CLALibUtils;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Static utilities for combining column groups of a {@link CompressedMatrixBlock} into column groups that cover the
 * union of the input groups' columns.
 */
public final class CLALibCombineGroups {
    protected static final Log LOG = LogFactory.getLog(CLALibCombineGroups.class.getName());

    private CLALibCombineGroups() {
        // static utility class, not instantiable
    }

    /**
     * Combine the column groups of {@code cmb}, creating a thread pool when {@code k > 1}.
     * <p>
     * This overload supplies no {@link CompressedSizeInfo} plan. Without a plan there is nothing to combine, so a
     * copy of the block's current column groups is returned.
     *
     * @param cmb the compressed matrix block whose column groups are combined
     * @param k   the requested parallelism; a pool is only created if {@code k > 1}
     * @return a new list of column groups
     * @throws DMLCompressionException wrapping any failure during combining
     */
    public static List<AColGroup> combine(CompressedMatrixBlock cmb, int k) {
        ExecutorService pool = null;
        try {
            pool = k > 1 ? CommonThreadPool.get(k) : null;
            return combine(cmb, null, pool);
        }
        catch (Exception e) {
            throw new DMLCompressionException("Compression Failed", e);
        }
        finally {
            // always release the pool we created, even on failure
            if (pool != null) {
                pool.shutdown();
            }
        }
    }

    /**
     * Combine the column groups of {@code cmb} according to the column sets described in {@code csi}.
     * <p>
     * For every entry of {@code csi.getInfo()}, all input groups whose columns overlap that entry's columns are
     * merged into a single group. If {@link CLALibUtils#shouldFilterFOR(List)} requests it, the input groups are
     * first passed through {@link CLALibUtils#filterFOR(List, double[])}. That call fills a per-column vector, which
     * is added back to each combined group via {@code addVector}.
     *
     * @param cmb  the compressed matrix block whose column groups are combined
     * @param csi  the combination plan; if {@code null}, a copy of the existing column groups is returned unchanged
     * @param pool an executor; NOTE(review): currently unused by this method
     * @return a new list with one combined group per plan entry (or the existing groups if {@code csi} is null)
     */
    public static List<AColGroup> combine(CompressedMatrixBlock cmb, CompressedSizeInfo csi, ExecutorService pool) {
        final List<AColGroup> original = cmb.getColGroups();
        if (csi == null) {
            // Previously csi was dereferenced unconditionally, so combine(cmb, k) always failed with an NPE.
            return new ArrayList<>(original);
        }

        List<AColGroup> input = original;
        final boolean filterFor = CLALibUtils.shouldFilterFOR(input);
        final double[] constV = filterFor ? new double[cmb.getNumColumns()] : null;
        if (filterFor) {
            input = CLALibUtils.filterFOR(input, constV);
        }

        final List<AColGroup> ret = new ArrayList<>();
        for (CompressedSizeInfoColGroup gi : csi.getInfo()) {
            // NOTE(review): an input group overlapping several plan entries is included in each of them.
            final AColGroup combined = combineN(findGroupsInIndex(gi.getColumns(), input));
            ret.add(filterFor ? combined.addVector(constV) : combined);
        }
        return ret;
    }

    /**
     * Select the groups whose column indices share at least one column with {@code idx}.
     *
     * @param idx    the column index set to match against
     * @param groups the candidate groups
     * @return a new list of the matching groups, in input order
     */
    public static List<AColGroup> findGroupsInIndex(IColIndex idx, List<AColGroup> groups) {
        final List<AColGroup> ret = new ArrayList<>();
        for (AColGroup g : groups) {
            if (g.getColIndices().containsAny(idx)) {
                ret.add(g);
            }
        }
        return ret;
    }

    /**
     * Fold a list of groups into one group by repeatedly applying {@link #combine(AColGroup, AColGroup)} from left
     * to right.
     *
     * @param groups the groups to combine; must contain at least one group
     * @return the combined group (the single element itself if the list has size one)
     * @throws DMLCompressionException if {@code groups} is null or empty
     */
    public static AColGroup combineN(List<AColGroup> groups) {
        if (groups == null || groups.isEmpty()) {
            throw new DMLCompressionException("Invalid call to combineN: no column groups to combine");
        }
        AColGroup base = groups.get(0);
        for (int i = 1; i < groups.size(); i++) {
            base = combine(base, groups.get(i));
        }
        return base;
    }

    /**
     * Combine two column groups into one group over the union of their columns.
     * <p>
     * Uncompressed inputs are first recompressed. If both groups are then compressed, their encodings are combined.
     * Otherwise the groups are decompressed into one uncompressed group.
     *
     * @param a the first group; must not be a frame-of-reference group
     * @param b the second group; must not be a frame-of-reference group
     * @return the combined group
     * @throws DMLCompressionException wrapping any failure, with both inputs included in the message
     */
    public static AColGroup combine(AColGroup a, AColGroup b) {
        try {
            if (a instanceof IFrameOfReferenceGroup || b instanceof IFrameOfReferenceGroup) {
                throw new DMLCompressionException("Invalid call with frame of reference group to combine");
            }
            final IColIndex combinedColumns = ColIndexFactory.combine(a, b);

            if (a instanceof ColGroupUncompressed) {
                a = a.recompress();
            }
            if (b instanceof ColGroupUncompressed) {
                b = b.recompress();
            }

            if (a instanceof AColGroupCompressed && b instanceof AColGroupCompressed) {
                return combineCompressed(combinedColumns, (AColGroupCompressed) a, (AColGroupCompressed) b);
            }
            // recompress may still yield an uncompressed group; fall back to a dense materialization
            if (a instanceof ColGroupUncompressed || b instanceof ColGroupUncompressed) {
                return combineUC(combinedColumns, a, b);
            }
            throw new NotImplementedException(
                "Not implemented combine for " + a.getClass().getSimpleName() + " - " + b.getClass().getSimpleName());
        }
        catch (Exception e) {
            final StringBuilder sb = new StringBuilder();
            sb.append("Failed to combine:\n\n");
            sb.append(a);
            sb.append("\n\n");
            sb.append(b);
            throw new DMLCompressionException(sb.toString(), e);
        }
    }

    /**
     * Combine two compressed groups by combining their encodings and building the matching dictionary.
     *
     * @param combinedColumns the column index set of the result
     * @param ac              the first compressed group
     * @param bc              the second compressed group
     * @return a DDC, empty, const, or SDC group depending on the combined encoding
     */
    private static AColGroup combineCompressed(IColIndex combinedColumns, AColGroupCompressed ac,
        AColGroupCompressed bc) {
        final IEncode ae = ac.getEncoding();
        final IEncode be = bc.getEncoding();
        if (ae instanceof SparseEncoding && !(be instanceof SparseEncoding)) {
            // Put the sparse encoding second unless both are sparse.
            // NOTE(review): combinedColumns keeps its order while ac/bc are swapped; this presumes the dictionary
            // combination maps columns to combinedColumns order - confirm.
            return combineCompressed(combinedColumns, bc, ac);
        }

        final Pair<IEncode, Map<Integer, Integer>> cec = ae.combineWithMap(be);
        final IEncode ce = cec.getLeft();
        final Map<Integer, Integer> filter = cec.getRight();

        if (ce instanceof DenseEncoding) {
            final DenseEncoding ced = (DenseEncoding) ce;
            final IDictionary cd = DictionaryFactory.combineDictionaries(ac, bc, filter);
            return ColGroupDDC.create(combinedColumns, cd, ced.getMap(), null);
        }
        if (ce instanceof EmptyEncoding) {
            return new ColGroupEmpty(combinedColumns);
        }
        if (ce instanceof ConstEncoding) {
            final IDictionary cd = DictionaryFactory.combineDictionaries(ac, bc, filter);
            return ColGroupConst.create(combinedColumns, cd);
        }
        if (ce instanceof SparseEncoding) {
            final SparseEncoding sed = (SparseEncoding) ce;
            final IDictionary cd = DictionaryFactory.combineDictionariesSparse(ac, bc);
            final double[] defaultTuple = constructDefaultTuple(ac, bc);
            return ColGroupSDC.create(combinedColumns, sed.getNumRows(), cd, defaultTuple, sed.getOffsets(),
                sed.getMap(), null);
        }
        throw new NotImplementedException(
            "Not implemented combine for " + ac.getClass().getSimpleName() + " - " + bc.getClass().getSimpleName());
    }

    /**
     * Combine two groups, at least one of which is uncompressed, by decompressing both into a single dense
     * {@link MatrixBlock} laid out over {@code combinedColumns}.
     *
     * @param combinedColumns the column index set of the result
     * @param a               the first group
     * @param b               the second group
     * @return an uncompressed group over {@code combinedColumns}
     */
    private static AColGroup combineUC(IColIndex combinedColumns, AColGroup a, AColGroup b) {
        // the row count is taken from whichever input is uncompressed
        final int nRow = a instanceof ColGroupUncompressed ? ((ColGroupUncompressed) a).getData().getNumRows()
            : ((ColGroupUncompressed) b).getData().getNumRows();
        final MatrixBlock target = new MatrixBlock(nRow, combinedColumns.size(), false);
        target.allocateBlock();
        final DenseBlock db = target.getDenseBlock();

        // remap each group's columns to positions within combinedColumns, then decompress into the shared block
        final IColIndex aTempCols = ColIndexFactory.getColumnMapping(combinedColumns, a.getColIndices());
        a.copyAndSet(aTempCols).decompressToDenseBlock(db, 0, nRow, 0, 0);
        final IColIndex bTempCols = ColIndexFactory.getColumnMapping(combinedColumns, b.getColIndices());
        b.copyAndSet(bTempCols).decompressToDenseBlock(db, 0, nRow, 0, 0);

        target.recomputeNonZeros();
        return ColGroupUncompressed.create(combinedColumns, target, false);
    }

    /**
     * Build the default tuple of a combined group by concatenating the default tuples of {@code ac} and
     * {@code bc}. A group that does not implement {@link IContainDefaultTuple} contributes zeros.
     *
     * @param ac the first group; its values occupy the first {@code ac.getNumCols()} entries
     * @param bc the second group; its values follow those of {@code ac}
     * @return a new array of length {@code ac.getNumCols() + bc.getNumCols()}
     */
    public static double[] constructDefaultTuple(AColGroupCompressed ac, AColGroupCompressed bc) {
        final int aCols = ac.getNumCols();
        final double[] ret = new double[aCols + bc.getNumCols()];
        if (ac instanceof IContainDefaultTuple) {
            final double[] defa = ((IContainDefaultTuple) ac).getDefaultTuple();
            System.arraycopy(defa, 0, ret, 0, defa.length);
        }
        if (bc instanceof IContainDefaultTuple) {
            final double[] defb = ((IContainDefaultTuple) bc).getDefaultTuple();
            System.arraycopy(defb, 0, ret, aCols, defb.length);
        }
        return ret;
    }
}

