/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.codegen.cplan;

import java.util.Arrays;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * Code-generation plan node for an operation with exactly three inputs.
 *
 * <p>The concrete operation is selected by {@link TernaryType}. Output dimensions are derived
 * from that type when the node is constructed (see {@link #setOutputDims()}).
 */
public class CNodeTernary extends CNode {
    /** Number of inputs every ternary node is constructed with. */
    private static final int NUM_INPUTS = 3;

    private final TernaryType _type;

    /**
     * Creates a ternary node over the given inputs and computes its output dimensions.
     *
     * @param in1 first input (its column count is used for {@code LOOKUP_RVECT1})
     * @param in2 second input
     * @param in3 third input
     * @param type the ternary operation this node represents
     */
    public CNodeTernary(CNode in1, CNode in2, CNode in3, TernaryType type) {
        _inputs.add(in1);
        _inputs.add(in2);
        _inputs.add(in3);
        _type = type;
        // NOTE(review): overridable method called from the constructor; _type must already be
        // assigned here because setOutputDims() switches on it.
        setOutputDims();
    }

    /** Returns the ternary operation this node represents. */
    public TernaryType getType() {
        return _type;
    }

    /**
     * Generates source code for this node, preceded by the code of its three inputs.
     *
     * <p>Returns an empty string if this node was already generated. Otherwise the inputs are
     * generated first, in order. Then a fresh variable name is created for this node and the
     * template for {@link #getType()} is instantiated as follows:
     * <ul>
     *   <li>{@code %TMP%} becomes the new variable name.</li>
     *   <li>For each input j in 1..3, {@code %INjv%}, {@code %INji%} and {@code %INj%} become
     *       the input's variable name.</li>
     *   <li>{@code "vals"} / {@code "ix"} is appended to the first two when that name starts
     *       with {@code "a"}.</li>
     * </ul>
     *
     * @param sparse whether a sparse template variant may be used
     * @param api target code-generator API
     * @return the generated code, or {@code ""} if this node was already generated
     */
    @Override
    public String codegen(boolean sparse, SpoofCompiler.GeneratorAPI api) {
        if (isGenerated()) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        // Keep this order: all inputs are emitted before this node's own variable is created.
        for (int i = 0; i < NUM_INPUTS; i++) {
            sb.append(_inputs.get(i).codegen(sparse, api));
        }

        // The sparse template is only chosen when the first input is a non-literal data node
        // whose variable name starts with "a" (presumably a sparse-row variable — TODO confirm).
        CNode first = _inputs.get(0);
        boolean lsparse = sparse
            && first instanceof CNodeData
            && first.getVarname().startsWith("a")
            && !first.isLiteral();

        String var = createVarname();
        String tmp = getLanguageTemplateClass(this, api).getTemplate(_type, lsparse);
        tmp = tmp.replace("%TMP%", var);
        for (int j = 1; j <= NUM_INPUTS; j++) {
            String varj = _inputs.get(j - 1).getVarname();
            boolean aVar = varj.startsWith("a");
            // Replace the suffixed placeholders first; "%INj%" does not match "%INjv%"/"%INji%".
            tmp = tmp.replace("%IN" + j + "v%", aVar ? varj + "vals" : varj);
            tmp = tmp.replace("%IN" + j + "i%", aVar ? varj + "ix" : varj);
            tmp = tmp.replace("%IN" + j + "%", varj);
        }
        sb.append(tmp);

        _generated = true;
        return sb.toString();
    }

    /**
     * Returns a short label for this node's operation.
     *
     * @return the label for the node's type, or {@code super.toString()} for any other type
     */
    @Override
    public String toString() {
        switch (_type) {
            case PLUS_MULT:
                return "t(+*)";
            case MINUS_MULT:
                return "t(-*)";
            case BIASADD:
                return "t(bias+)";
            case BIASMULT:
                return "t(bias*)";
            case REPLACE:
            case REPLACE_NAN:
                return "t(rplc)";
            case IFELSE:
                return "t(ifelse)";
            case LOOKUP_RC1:
                // NOTE(review): the lookup labels use a "u(" prefix unlike the other ternary
                // labels; kept unchanged because callers may depend on these strings.
                return "u(ixrc1)";
            case LOOKUP_RVECT1:
                return "u(ixrv1)";
            default:
                return super.toString();
        }
    }

    /**
     * Sets the output shape from the operation type.
     *
     * <p>Every type except {@code LOOKUP_RVECT1} produces a scalar, with rows and cols set to 0.
     * {@code LOOKUP_RVECT1} produces a 1 x n matrix, where n is the column count of the first
     * input.
     */
    @Override
    public void setOutputDims() {
        switch (_type) {
            case PLUS_MULT:
            case MINUS_MULT:
            case BIASADD:
            case BIASMULT:
            case REPLACE:
            case REPLACE_NAN:
            case IFELSE:
            case LOOKUP_RC1:
                _rows = 0L;
                _cols = 0L;
                _dataType = Types.DataType.SCALAR;
                break;
            case LOOKUP_RVECT1:
                _rows = 1L;
                _cols = _inputs.get(0)._cols;
                _dataType = Types.DataType.MATRIX;
                break;
        }
    }

    /**
     * Lazily computes and caches a hash that combines the base-class hash with the operation
     * type. A cached value of 0 is treated as "not yet computed".
     */
    @Override
    public int hashCode() {
        if (_hash == 0) {
            _hash = UtilFunctions.intHashCode(super.hashCode(), _type.hashCode());
        }
        return _hash;
    }

    /**
     * Two ternary nodes are equal when the base-class comparison succeeds and they have the
     * same operation type.
     */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof CNodeTernary)) {
            return false;
        }
        CNodeTernary that = (CNodeTernary) o;
        return super.equals(that) && _type == that._type;
    }

    /**
     * Returns whether code for this node can be generated for {@code api}.
     *
     * <p>Only the CUDA and JAVA generators are accepted. For those, every input must itself be
     * supported; inputs are checked in order and checking stops at the first unsupported one.
     */
    @Override
    public boolean isSupported(SpoofCompiler.GeneratorAPI api) {
        if (api != SpoofCompiler.GeneratorAPI.CUDA && api != SpoofCompiler.GeneratorAPI.JAVA) {
            return false;
        }
        for (CNode in : _inputs) {
            if (!in.isSupported(api)) {
                return false;
            }
        }
        return true;
    }

    /** The operations a {@link CNodeTernary} can represent. */
    public enum TernaryType {
        PLUS_MULT,
        MINUS_MULT,
        BIASADD,
        BIASMULT,
        REPLACE,
        REPLACE_NAN,
        IFELSE,
        LOOKUP_RC1,
        LOOKUP_RVECT1;

        /**
         * Returns whether {@code value} is exactly the name of one of these constants.
         *
         * @param value candidate constant name; {@code null} yields {@code false}
         * @return {@code true} if some constant's {@code name()} equals {@code value}
         */
        public static boolean contains(String value) {
            return Arrays.stream(values()).anyMatch(tt -> tt.name().equals(value));
        }

        /**
         * Returns {@code true} only for {@link #LOOKUP_RVECT1}, the one type whose output is a
         * row-vector matrix rather than a scalar.
         */
        public boolean isVectorPrimitive() {
            return this == LOOKUP_RVECT1;
        }
    }
}

