/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.concurrent.Future;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.BinaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * Federated instruction for the aggregate binary opcode {@code ba+*}: a matrix multiplication
 * (the operator is built with {@code InstructionUtils.getMatMultOperator}) where at least one
 * input is a federated matrix.
 */
public class AggregateBinaryFEDInstruction
extends BinaryFEDInstruction {
    private static final Log LOG = LogFactory.getLog(AggregateBinaryFEDInstruction.class.getName());

    public AggregateBinaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.AggregateBinary, op, in1, in2, out, opcode, istr);
    }

    public AggregateBinaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, FEDInstruction.FederatedOutput fedOut) {
        super(FEDInstruction.FEDType.AggregateBinary, op, in1, in2, out, opcode, istr, fedOut);
    }

    /**
     * Creates the federated counterpart of a CP aggregate binary instruction when its inputs
     * call for federated execution.
     *
     * @param inst the CP instruction to convert
     * @param ec   execution context used to resolve the input matrices
     * @return a federated instruction (with {@code FederatedOutput.NONE}) if both inputs are
     *         matrices and input 1 is ROW- or COL-federated or input 2 is ROW-federated, each
     *         also satisfying {@code isFederatedExcept(BROADCAST)}; {@code null} otherwise
     */
    public static AggregateBinaryFEDInstruction parseInstruction(AggregateBinaryCPInstruction inst, ExecutionContext ec) {
        if (!inst.input1.isMatrix() || !inst.input2.isMatrix()) {
            return null;
        }
        MatrixObject mo1 = ec.getMatrixObject(inst.input1);
        MatrixObject mo2 = ec.getMatrixObject(inst.input2);
        boolean federatedInput =
            (mo1.isFederated(FTypes.FType.ROW) && mo1.isFederatedExcept(FTypes.FType.BROADCAST))
            || (mo2.isFederated(FTypes.FType.ROW) && mo2.isFederatedExcept(FTypes.FType.BROADCAST))
            || (mo1.isFederated(FTypes.FType.COL) && mo1.isFederatedExcept(FTypes.FType.BROADCAST));
        return federatedInput ? AggregateBinaryFEDInstruction.parseInstruction(inst) : null;
    }

    /** Copies operator, operands, opcode and instruction string of a CP instruction. */
    private static AggregateBinaryFEDInstruction parseInstruction(AggregateBinaryCPInstruction instr) {
        return new AggregateBinaryFEDInstruction(instr.getOperator(), instr.input1, instr.input2, instr.output, instr.getOpcode(), instr.getInstructionString(), FEDInstruction.FederatedOutput.NONE);
    }

    /**
     * Parses a serialized {@code ba+*} federated instruction.
     *
     * @param str instruction string with fields opcode, in1, in2, out, k, fedOut
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code ba+*}
     */
    public static AggregateBinaryFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("ba+*")) {
            throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        InstructionUtils.checkNumFields(parts, 5);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        int k = Integer.parseInt(parts[4]);
        FEDInstruction.FederatedOutput fedOut = FEDInstruction.FederatedOutput.valueOf(parts[5]);
        return new AggregateBinaryFEDInstruction(InstructionUtils.getMatMultOperator(k), in1, in2, out, opcode, str, fedOut);
    }

    /**
     * Executes the multiplication; the first matching case is used:
     * <ol>
     * <li>input 1 COL-federated, input 2 ROW-federated and aligned ({@code COL_T}):
     *     per-worker products are summed locally;</li>
     * <li>input 1 ROW-federated: input 2 is broadcast; the result stays federated or is
     *     collected locally, depending on the federated-output flag and input shapes;</li>
     * <li>input 2 ROW-federated: input 1 is broadcast in slices; products are summed locally;</li>
     * <li>input 1 COL-federated: input 2 is broadcast in slices; products are summed locally.</li>
     * </ol>
     *
     * @throws DMLRuntimeException if none of the cases applies
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixLineagePair mo1 = ec.getMatrixLineagePair(this.input1);
        MatrixLineagePair mo2 = ec.getMatrixLineagePair(this.input2);
        if (mo1.isFederated(FTypes.FType.COL) && mo2.isFederated(FTypes.FType.ROW)
            && mo1.getFedMapping().isAligned(mo2.getFedMapping(), FTypes.AlignType.COL_T)) {
            processAlignedColRow(mo1, mo2, ec);
        }
        else if (mo1.isFederated(FTypes.FType.ROW)) {
            processRowFederatedLeft(mo1, mo2, ec);
        }
        else if (mo2.isFederated(FTypes.FType.ROW)) {
            processRowFederatedRight(mo1, mo2, ec);
        }
        else if (mo1.isFederated(FTypes.FType.COL)) {
            processColFederatedLeft(mo1, mo2, ec);
        }
        else {
            throw new DMLRuntimeException("Federated AggregateBinary not supported with the following federated objects: " + mo1.isFederated() + ":" + mo1.getFedMapping() + " " + mo2.isFederated() + ":" + mo2.getFedMapping());
        }
    }

    /** Case 1: aligned COL x ROW federated inputs; per-worker products are summed locally. */
    private void processAlignedColRow(MatrixLineagePair mo1, MatrixLineagePair mo2, ExecutionContext ec) {
        FederationMap fedMap = mo1.getFedMapping();
        FederatedRequest frMult = callMatMult(fedMap.getID(), mo2.getFedMapping().getID());
        if (this._fedOut.isForcedFederated()) {
            // the result is aggregated locally, so a forced federated output is ignored
            this.writeInfoLog(mo1, mo2);
        }
        this.aggregateLocally(fedMap, true, ec, frMult);
    }

    /** Case 2: ROW-federated input 1 times broadcast input 2. */
    private void processRowFederatedLeft(MatrixLineagePair mo1, MatrixLineagePair mo2, ExecutionContext ec) {
        FederationMap fedMap = mo1.getFedMapping();
        FederatedRequest frBroadcast = fedMap.broadcast(mo2);
        FederatedRequest frMult = callMatMult(fedMap.getID(), frBroadcast.getID());
        boolean isVector = mo2.getNumColumns() == 1L;
        boolean isPartOut = mo1.isFederated(FTypes.FType.PART) || (!isVector && mo2.isFederated(FTypes.FType.PART));
        if (isPartOut && this._fedOut.isForcedFederated()) {
            this.writeInfoLog(mo1, mo2);
        }
        // keep the result on the workers when federated output is forced, or for non-vector
        // results unless local output is forced -- but never for PART federated maps
        boolean keepFederated = (this._fedOut.isForcedFederated() || (!isVector && !this._fedOut.isForcedLocal())) && !isPartOut;
        if (keepFederated) {
            Future<FederatedResponse>[] ffr = fedMap.execute(this.getTID(), true, frBroadcast, frMult);
            this.setOutputFedMapping(fedMap, mo1, mo2, FederationUtils.sumNonZeros(ffr), frMult.getID(), ec);
        }
        else if (mo1.isFederated(FTypes.FType.BROADCAST) && mo2.isFederated(FTypes.FType.BROADCAST)) {
            this.aggregateLocallySingleWorker(fedMap, ec, frBroadcast, frMult);
        }
        else {
            this.aggregateLocally(fedMap, false, ec, frBroadcast, frMult);
        }
    }

    /** Case 3: input 1 sliced-broadcast against ROW-federated input 2; products summed locally. */
    private void processRowFederatedRight(MatrixLineagePair mo1, MatrixLineagePair mo2, ExecutionContext ec) {
        FederationMap fedMap = mo2.getFedMapping();
        FederatedRequest[] frSliced = fedMap.broadcastSliced(mo1, true);
        FederatedRequest frMult = callMatMult(frSliced[0].getID(), fedMap.getID());
        if (this._fedOut.isForcedFederated()) {
            this.writeInfoLog(mo1, mo2);
        }
        this.aggregateLocally(fedMap, true, ec, frSliced, new FederatedRequest[]{frMult});
    }

    /** Case 4: COL-federated input 1 against sliced-broadcast input 2; products summed locally. */
    private void processColFederatedLeft(MatrixLineagePair mo1, MatrixLineagePair mo2, ExecutionContext ec) {
        FederationMap fedMap = mo1.getFedMapping();
        FederatedRequest[] frSliced = fedMap.broadcastSliced(mo2, true);
        FederatedRequest frMult = callMatMult(fedMap.getID(), frSliced[0].getID());
        if (this._fedOut.isForcedFederated()) {
            this.writeInfoLog(mo1, mo2);
        }
        this.aggregateLocally(fedMap, true, ec, frSliced, new FederatedRequest[]{frMult});
    }

    /** Builds the request that runs this instruction on the workers for the two given variable IDs. */
    private FederatedRequest callMatMult(long in1ID, long in2ID) {
        return FederationUtils.callInstruction(this.instString, this.output, new CPOperand[]{this.input1, this.input2}, new long[]{in1ID, in2ID}, true);
    }

    /** Logs that a forced federated output was ignored, including both inputs' federation types. */
    private void writeInfoLog(MatrixLineagePair mo1, MatrixLineagePair mo2) {
        FTypes.FType mo1FType = mo1.getFedMapping() == null ? null : mo1.getFedMapping().getType();
        FTypes.FType mo2FType = mo2.getFedMapping() == null ? null : mo2.getFedMapping().getType();
        LOG.info("Federated output flag would result in PART federated map and has been ignored in " + this.instString);
        LOG.info("Input 1 FType is " + mo1FType + " and input 2 FType " + mo2FType);
    }

    /** Sets the output's dimensions, blocksize, nnz and a copy of {@code federationMap} under {@code outputID}. */
    private void setOutputFedMapping(FederationMap federationMap, MatrixLineagePair mo1, MatrixLineagePair mo2, long nnz, long outputID, ExecutionContext ec) {
        MatrixObject out = ec.getMatrixObject(this.output);
        out.getDataCharacteristics().setDimension(mo1.getNumRows(), mo2.getNumColumns()).setBlocksize(mo1.getBlocksize()).setNonZeros(nnz);
        out.setFedMapping(federationMap.copyWithNewID(outputID, mo2.getNumColumns()));
    }

    /** Non-sliced variant of {@link #aggregateLocally(FederationMap, boolean, ExecutionContext, FederatedRequest[], FederatedRequest...)}. */
    private void aggregateLocally(FederationMap fedMap, boolean aggAdd, ExecutionContext ec, FederatedRequest ... fr) {
        this.aggregateLocally(fedMap, aggAdd, ec, (FederatedRequest[]) null, fr);
    }

    /**
     * Executes {@code fr}, fetches the result of the last request from every worker, cleans it
     * up, and combines the responses into the local output matrix.
     *
     * @param aggAdd   {@code true} to sum responses via {@code FederationUtils.aggAdd};
     *                 {@code false} to combine them via {@code FederationUtils.bind(ffr, false)}
     *                 (presumably row-wise concatenation -- TODO confirm)
     * @param frSliced optional per-worker sliced requests, or {@code null}
     * @param fr       requests to execute; the last one produces the result variable
     */
    private void aggregateLocally(FederationMap fedMap, boolean aggAdd, ExecutionContext ec, FederatedRequest[] frSliced, FederatedRequest ... fr) {
        FederatedRequest[] requests = this.withGetAndCleanup(fedMap, fr);
        Future<FederatedResponse>[] ffr = frSliced != null
            ? fedMap.execute(this.getTID(), frSliced, requests)
            : fedMap.execute(this.getTID(), requests);
        MatrixBlock ret = aggAdd ? FederationUtils.aggAdd(ffr) : FederationUtils.bind(ffr, false);
        ec.setMatrixOutput(this.output.getName(), ret);
    }

    /**
     * Executes {@code fr} and uses only the first worker's response as the output. This is used
     * when both inputs are BROADCAST-federated (all responses presumably identical -- verify).
     *
     * @throws DMLRuntimeException wrapping any failure while retrieving the response
     */
    private void aggregateLocallySingleWorker(FederationMap fedMap, ExecutionContext ec, FederatedRequest ... fr) {
        Future<FederatedResponse>[] ffr = fedMap.execute(this.getTID(), this.withGetAndCleanup(fedMap, fr));
        try {
            MatrixBlock ret = (MatrixBlock) ffr[0].get().getData()[0];
            ec.setMatrixOutput(this.output.getName(), ret);
        }
        catch (InterruptedException ex) {
            // restore the interrupt flag so callers up the stack can still observe it
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(ex);
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }

    /**
     * Returns {@code fr} followed by a GET_VAR request and a cleanup request, both targeting the
     * variable ID of the last request in {@code fr}.
     */
    private FederatedRequest[] withGetAndCleanup(FederationMap fedMap, FederatedRequest[] fr) {
        long callInstID = fr[fr.length - 1].getID();
        FederatedRequest frG = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, callInstID);
        FederatedRequest frC = fedMap.cleanup(this.getTID(), callInstID);
        return ArrayUtils.addAll(fr, frG, frC);
    }
}

