/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanVisitor;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.ColumnSchema;
import org.apache.iotdb.db.queryengine.plan.relational.planner.DataOrganizationSpecification;
import org.apache.iotdb.db.queryengine.plan.relational.planner.OrderingScheme;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationTableScanNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.DeviceTableScanNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.FillNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.MergeSortNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.SortNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TableFunctionProcessorNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ValueFillNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanOptimizer;

public class TransformAggregationToStreamable
implements PlanOptimizer {
    @Override
    public PlanNode optimize(PlanNode plan, PlanOptimizer.Context context) {
        if (!context.getAnalysis().isQuery() || !context.getAnalysis().containsAggregationQuery()) {
            return plan;
        }
        return plan.accept(new Rewriter(), null);
    }

    private static class Rewriter
    extends PlanVisitor<PlanNode, Void> {
        private Rewriter() {
        }

        @Override
        public PlanNode visitPlan(PlanNode node, Void context) {
            for (PlanNode child : node.getChildren()) {
                child.accept(this, context);
            }
            return node;
        }

        @Override
        public PlanNode visitAggregation(AggregationNode node, Void context) {
            node.getChild().accept(this, context);
            ImmutableSet expectedGroupingKeys = ImmutableSet.copyOf(node.getGroupingKeys());
            node.setPreGroupedSymbols(node.getChild().accept(new DeriveGroupProperties(), new GroupContext((Set)expectedGroupingKeys)));
            return node;
        }

        @Override
        public PlanNode visitAggregationTableScan(AggregationTableScanNode node, Void context) {
            throw new RuntimeException("This optimizer should be used before optimizer of PushAggregationIntoTableScan");
        }
    }

    private static class GroupContext {
        private final Set<Symbol> groupingKeys;

        private GroupContext(Set<Symbol> groupingKeys) {
            this.groupingKeys = groupingKeys;
        }
    }

    private static class DeriveGroupProperties
    extends PlanVisitor<List<Symbol>, GroupContext> {
        private DeriveGroupProperties() {
        }

        @Override
        public List<Symbol> visitPlan(PlanNode node, GroupContext context) {
            List result = node.getChildren().stream().map(child -> child.accept(new DeriveGroupProperties(), context)).distinct().collect(Collectors.toList());
            return result.size() == 1 ? (List)result.get(0) : ImmutableList.of();
        }

        @Override
        public List<Symbol> visitMergeSort(MergeSortNode node, GroupContext context) {
            return this.getMatchedPrefixSymbols(context, node.getOrderingScheme());
        }

        private List<Symbol> getMatchedPrefixSymbols(GroupContext context, OrderingScheme orderingScheme) {
            Set expectedGroupingKeys = context.groupingKeys;
            List<Symbol> orderKeys = orderingScheme.getOrderBy();
            for (int i = 0; i < orderKeys.size(); ++i) {
                if (expectedGroupingKeys.contains(orderKeys.get(i))) continue;
                return orderKeys.subList(0, i);
            }
            return ImmutableList.of();
        }

        @Override
        public List<Symbol> visitProject(ProjectNode node, GroupContext context) {
            if (ImmutableSet.copyOf(node.getOutputSymbols()).containsAll((Collection)context.groupingKeys)) {
                return node.getChild().accept(this, context);
            }
            return ImmutableList.of();
        }

        @Override
        public List<Symbol> visitFill(FillNode node, GroupContext context) {
            if (node instanceof ValueFillNode) {
                return ImmutableList.of();
            }
            return node.getChild().accept(this, context);
        }

        @Override
        public List<Symbol> visitSort(SortNode node, GroupContext context) {
            return this.getMatchedPrefixSymbols(context, node.getOrderingScheme());
        }

        @Override
        public List<Symbol> visitTableFunctionProcessor(TableFunctionProcessorNode node, GroupContext context) {
            if (node.getChildren().isEmpty()) {
                return ImmutableList.of();
            }
            if (node.isRowSemantic()) {
                return this.visitPlan((PlanNode)node, context);
            }
            Optional<DataOrganizationSpecification> dataOrganizationSpecification = node.getDataOrganizationSpecification();
            return dataOrganizationSpecification.map(organizationSpecification -> organizationSpecification.getPartitionBy().stream().filter(context.groupingKeys::contains).collect(Collectors.toList())).orElseGet(ImmutableList::of);
        }

        @Override
        public List<Symbol> visitDeviceTableScan(DeviceTableScanNode node, GroupContext context) {
            Set expectedGroupingKeys = context.groupingKeys;
            Map<Symbol, ColumnSchema> assignments = node.getAssignments();
            return expectedGroupingKeys.stream().filter(k -> {
                ColumnSchema columnSchema = (ColumnSchema)assignments.get(k);
                if (columnSchema != null) {
                    return columnSchema.getColumnCategory() == TsTableColumnCategory.TAG || columnSchema.getColumnCategory() == TsTableColumnCategory.ATTRIBUTE;
                }
                return false;
            }).collect(Collectors.toList());
        }

        @Override
        public List<Symbol> visitAggregation(AggregationNode node, GroupContext context) {
            return ImmutableSet.copyOf(node.getGroupingKeys()).equals((Object)context.groupingKeys) ? node.getGroupingKeys() : ImmutableList.of();
        }

        @Override
        public List<Symbol> visitAggregationTableScan(AggregationTableScanNode node, GroupContext context) {
            throw new RuntimeException("This optimizer should be used before optimizer of PushAggregationIntoTableScan");
        }
    }
}

