/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.rules.ImmutableJoinExpandOrToUnionRule;
import org.apache.calcite.rel.rules.JoinCommuteRule;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.immutables.value.Value;

/**
 * Planner rule that matches a {@link Join} whose condition is a disjunction
 * ({@code OR}) and rewrites it into a {@code UNION ALL} of joins, one per part
 * of the disjunction.
 *
 * <p>The condition is split by {@code splitCond}. A disjunct that contains at
 * least one equi-join key becomes its own part. Any of its other conjuncts may
 * reference at most one input. Consecutive disjuncts that do not qualify are
 * recombined into a single {@code OR}. Branch {@code i} joins on part {@code i}
 * and additionally requires every earlier part to be NOT TRUE, so the branches
 * are disjoint.
 *
 * <p>INNER, LEFT, RIGHT, FULL and ANTI joins are handled. LEFT/RIGHT/FULL joins
 * add chained ANTI joins, padded with NULLs, for rows without a match. Other
 * join types are left untouched.
 */
@Value.Enclosing
public class JoinExpandOrToUnionRule
extends RelRule<JoinExpandOrToUnionRule.Config>
implements TransformationRule {
    protected JoinExpandOrToUnionRule(Config config) {
        super(config);
    }

    /** Fires only when the join condition has more than one disjunct. */
    @Override
    public boolean matches(RelOptRuleCall call) {
        Join join = (Join) call.rel(0);
        List<RexNode> orConds = RelOptUtil.disjunctions(join.getCondition());
        return orConds.size() > 1;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final Join join = (Join) call.rel(0);
        if (!isSupportedJoinType(join.getJoinType())) {
            return;
        }
        final List<RexNode> parts = this.splitCond(join);
        if (parts.size() < 2) {
            // No disjunct qualified as its own branch: splitCond merged them all
            // back into one OR. Expanding would only reproduce the original
            // condition, or an equivalent but costlier UNION of INNER + ANTI joins.
            return;
        }
        final RelBuilder relBuilder = call.builder();
        final RelNode expanded;
        switch (join.getJoinType()) {
            case INNER:
                expanded = this.expandInnerJoin(join, parts, relBuilder);
                break;
            case ANTI:
                expanded = this.expandAntiJoin(join, parts, relBuilder);
                break;
            case LEFT:
                expanded = this.expandLeftOrRightJoin(join, parts, true, relBuilder);
                break;
            case RIGHT:
                expanded = this.expandLeftOrRightJoin(join, parts, false, relBuilder);
                break;
            case FULL:
                expanded = this.expandFullJoin(join, parts, relBuilder);
                break;
            default:
                throw new AssertionError("unsupported join type: " + join.getJoinType());
        }
        call.transformTo(expanded);
    }

    /** Returns whether this rule knows how to expand a join of the given type. */
    private static boolean isSupportedJoinType(JoinRelType joinType) {
        switch (joinType) {
            case INNER:
            case ANTI:
            case LEFT:
            case RIGHT:
            case FULL:
                return true;
            default:
                return false;
        }
    }

    /**
     * Splits the join condition into the parts that become separate UNION ALL
     * branches.
     *
     * <p>Each disjunct accepted by {@code isValidCond} is emitted as its own part.
     * Consecutive disjuncts that are not accepted are combined with OR into one
     * part. The original disjunct order is preserved.
     *
     * @param join join whose condition is split
     * @return the parts, in order; their OR is equivalent to the join condition
     */
    private List<RexNode> splitCond(Join join) {
        RexBuilder builder = join.getCluster().getRexBuilder();
        List<RexNode> orConds = RelOptUtil.disjunctions(join.getCondition());
        int leftFieldCount = join.getLeft().getRowType().getFieldCount();
        List<RexNode> result = new ArrayList<>();
        List<RexNode> otherBuffer = new ArrayList<>();
        for (RexNode cond : orConds) {
            if (this.isValidCond(cond, leftFieldCount)) {
                // Flush the pending run of non-qualifying disjuncts first so that
                // the original order is kept.
                if (!otherBuffer.isEmpty()) {
                    result.add(RexUtil.composeDisjunction(builder, otherBuffer));
                    otherBuffer.clear();
                }
                result.add(cond);
            } else {
                otherBuffer.add(cond);
            }
        }
        if (!otherBuffer.isEmpty()) {
            result.add(RexUtil.composeDisjunction(builder, otherBuffer));
        }
        return result;
    }

    /**
     * Returns whether a single disjunct can form its own branch.
     *
     * <p>This requires all of the following:
     * <ul>
     *   <li>none of its conjuncts contains a sub-query or a correlation;</li>
     *   <li>every conjunct that references both inputs is an equi-join
     *       condition;</li>
     *   <li>at least one such equi-join condition exists.</li>
     * </ul>
     *
     * @param node one disjunct of the join condition
     * @param leftFieldCount number of fields of the join's left input
     */
    private boolean isValidCond(RexNode node, int leftFieldCount) {
        boolean hasJoinKeyCond = false;
        for (RexNode cond : RelOptUtil.conjunctions(node)) {
            if (RexUtil.SubQueryFinder.find(cond) != null || RexUtil.containsCorrelation(cond)) {
                return false;
            }
            if (cond instanceof RexCall && this.isEquiJoinCond((RexCall) cond, leftFieldCount)) {
                // An equi-join key legitimately references both inputs, so it is
                // checked before the "references both inputs" rejection below.
                hasJoinKeyCond = true;
                continue;
            }
            if (!this.doesNotReferToBothInputs(cond, leftFieldCount)) {
                return false;
            }
        }
        return hasJoinKeyCond;
    }

    /**
     * Returns whether {@code call} is {@code =} or {@code IS NOT DISTINCT FROM}
     * between one input ref from each side of the join.
     */
    private boolean isEquiJoinCond(RexCall call, int leftFieldCount) {
        if (call.getKind() != SqlKind.EQUALS && call.getKind() != SqlKind.IS_NOT_DISTINCT_FROM) {
            return false;
        }
        RexNode left = call.getOperands().get(0);
        RexNode right = call.getOperands().get(1);
        if (left instanceof RexInputRef && right instanceof RexInputRef) {
            int leftIndex = ((RexInputRef) left).getIndex();
            int rightIndex = ((RexInputRef) right).getIndex();
            return leftIndex < leftFieldCount && rightIndex >= leftFieldCount
                || rightIndex < leftFieldCount && leftIndex >= leftFieldCount;
        }
        return false;
    }

    /**
     * Returns true when {@code rex} references fields of at most one join input.
     * Input refs with an index below {@code leftFieldCount} count as left-side refs.
     */
    private boolean doesNotReferToBothInputs(RexNode rex, int leftFieldCount) {
        RexInputRefCounter counter = new RexInputRefCounter(leftFieldCount);
        rex.accept(counter);
        return counter.doesNotReferToBothInputs();
    }

    /**
     * Expands a LEFT (or RIGHT) join into UNION ALL of one INNER join per part,
     * plus the rows of the preserved side that match no part, padded with NULLs.
     *
     * @param isLeftJoin true for LEFT, false for RIGHT
     */
    private RelNode expandLeftOrRightJoin(Join join, List<RexNode> parts, boolean isLeftJoin, RelBuilder relBuilder) {
        List<RelNode> branches = new ArrayList<>(this.expandInnerJoinToRelNodes(join, parts, relBuilder));
        branches.add(this.expandAntiJoinToRelNode(join, parts, isLeftJoin, true, relBuilder));
        return relBuilder.pushAll(branches).union(true, branches.size()).build();
    }

    /**
     * Expands a FULL join into UNION ALL of one INNER join per part plus the
     * unmatched rows of each side, padded with NULLs. The union output is then
     * cast field-by-field to the original join's row type where the types differ.
     */
    private RelNode expandFullJoin(Join join, List<RexNode> parts, RelBuilder relBuilder) {
        List<RelNode> branches = new ArrayList<>(this.expandInnerJoinToRelNodes(join, parts, relBuilder));
        branches.add(this.expandAntiJoinToRelNode(join, parts, false, true, relBuilder));
        branches.add(this.expandAntiJoinToRelNode(join, parts, true, true, relBuilder));
        relBuilder.pushAll(branches).union(true, branches.size());
        final RexBuilder rexBuilder = relBuilder.getRexBuilder();
        final List<RexNode> projects = join.getRowType().getFieldList().stream()
            .map(field -> {
                final RexNode ref = relBuilder.field(field.getIndex());
                return field.getType().equals(ref.getType())
                    ? ref
                    : rexBuilder.makeCast(field.getType(), ref, true, false);
            })
            .collect(Collectors.toList());
        return relBuilder.project(projects).build();
    }

    /** Expands an INNER join into UNION ALL of one INNER join per part. */
    private RelNode expandInnerJoin(Join join, List<RexNode> parts, RelBuilder relBuilder) {
        List<RelNode> joins = this.expandInnerJoinToRelNodes(join, parts, relBuilder);
        return relBuilder.pushAll(joins).union(true, joins.size()).build();
    }

    /**
     * Builds one INNER join of the original inputs per part.
     *
     * <p>Join {@code i} uses {@code parts[i] AND parts[j] IS NOT TRUE} for every
     * {@code j < i}. Each matching row pair is therefore produced only by the
     * branch of the first part that is TRUE for it.
     */
    private List<RelNode> expandInnerJoinToRelNodes(Join join, List<RexNode> parts, RelBuilder relBuilder) {
        List<RelNode> joins = new ArrayList<>(parts.size());
        for (int i = 0; i < parts.size(); ++i) {
            List<RexNode> conjuncts = new ArrayList<>(i + 1);
            conjuncts.add(parts.get(i));
            for (int j = 0; j < i; ++j) {
                // IS NOT TRUE rather than NOT: if an earlier part is UNKNOWN (e.g. a
                // NULL join key), NOT would also be UNKNOWN. The row pair would then
                // be dropped from every branch, although the original OR accepts it.
                conjuncts.add(relBuilder.call(SqlStdOperatorTable.IS_NOT_TRUE, parts.get(j)));
            }
            RexNode cond = relBuilder.and(conjuncts);
            relBuilder.push(join.getLeft()).push(join.getRight()).join(JoinRelType.INNER, cond);
            joins.add(relBuilder.build());
        }
        return joins;
    }

    /** Expands an ANTI join into a chain of ANTI joins, one per part. */
    private RelNode expandAntiJoin(Join join, List<RexNode> parts, RelBuilder relBuilder) {
        return this.expandAntiJoinToRelNode(join, parts, true, false, relBuilder);
    }

    /**
     * Keeps the rows of one join input that match no part of the condition, by
     * chaining one ANTI join per part against the other input.
     *
     * @param isLeftAnti true to keep left-input rows; false to keep right-input
     *                   rows, in which case each part is swapped via
     *                   {@link JoinCommuteRule#swapJoinCond}
     * @param isAppendNulls true to project NULLs for the other input's fields,
     *                      producing the original join's field layout
     */
    private RelNode expandAntiJoinToRelNode(Join join, List<RexNode> parts, boolean isLeftAnti, boolean isAppendNulls, RelBuilder relBuilder) {
        final RelNode kept = isLeftAnti ? join.getLeft() : join.getRight();
        final RelNode other = isLeftAnti ? join.getRight() : join.getLeft();
        final RexBuilder rexBuilder = relBuilder.getRexBuilder();
        RelNode top = kept;
        for (RexNode part : parts) {
            RexNode cond = isLeftAnti ? part : JoinCommuteRule.swapJoinCond(part, join, rexBuilder);
            relBuilder.push(top).push(other).join(JoinRelType.ANTI, cond);
            top = relBuilder.build();
        }
        if (!isAppendNulls) {
            return top;
        }
        relBuilder.push(top);
        final List<RexNode> nulls = new ArrayList<>();
        final int otherFieldCount = other.getRowType().getFieldCount();
        for (int i = 0; i < otherFieldCount; ++i) {
            RexLiteral nullLiteral = rexBuilder.makeNullLiteral(other.getRowType().getFieldList().get(i).getType());
            nulls.add(nullLiteral);
        }
        final List<RexNode> projects = new ArrayList<>();
        if (isLeftAnti) {
            projects.addAll(relBuilder.fields());
            projects.addAll(nulls);
        } else {
            projects.addAll(nulls);
            projects.addAll(relBuilder.fields());
        }
        return relBuilder.project(projects).build();
    }

    /** Rule configuration. */
    @Value.Immutable
    public interface Config
    extends RelRule.Config {
        Config DEFAULT = ImmutableJoinExpandOrToUnionRule.Config.of().withOperandFor(Join.class);

        @Override
        default JoinExpandOrToUnionRule toRule() {
            return new JoinExpandOrToUnionRule(this);
        }

        /** Defines an operand tree matching the given join class with any inputs. */
        default Config withOperandFor(Class<? extends Join> joinClass) {
            return this.withOperandSupplier(b -> b.operand(joinClass).anyInputs()).as(Config.class);
        }
    }

    /**
     * Visitor that counts input refs pointing into the left input (index below
     * {@code leftFieldCount}) and the right input (all other indexes).
     */
    private static class RexInputRefCounter
    extends RexVisitorImpl<Void> {
        private final int leftFieldCount;
        private int leftInputRefCount = 0;
        private int rightInputRefCount = 0;

        RexInputRefCounter(int leftFieldCount) {
            super(true);
            this.leftFieldCount = leftFieldCount;
        }

        @Override
        public Void visitInputRef(RexInputRef inputRef) {
            if (inputRef.getIndex() < this.leftFieldCount) {
                ++this.leftInputRefCount;
            } else {
                ++this.rightInputRefCount;
            }
            return null;
        }

        /** Returns true if at most one side had input refs. */
        public boolean doesNotReferToBothInputs() {
            return this.leftInputRefCount == 0 || this.rightInputRefCount == 0;
        }
    }
}

