/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.sql.calcite.rule;

import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.DruidCorrelateUnnestRel;
import org.apache.druid.sql.calcite.rel.DruidRel;
import org.apache.druid.sql.calcite.rel.DruidUnnestRel;
import org.apache.druid.sql.calcite.rel.PartialDruidQuery;

public class DruidCorrelateUnnestRule
extends RelOptRule {
    private final PlannerContext plannerContext;

    public DruidCorrelateUnnestRule(PlannerContext plannerContext) {
        super(DruidCorrelateUnnestRule.operand(Correlate.class, (RelOptRuleOperand)DruidCorrelateUnnestRule.operand(DruidRel.class, (RelOptRuleOperandChildren)DruidCorrelateUnnestRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[]{DruidCorrelateUnnestRule.operand(DruidUnnestRel.class, (RelOptRuleOperandChildren)DruidCorrelateUnnestRule.any())}));
        this.plannerContext = plannerContext;
    }

    public boolean matches(RelOptRuleCall call) {
        DruidRel left = (DruidRel)call.rel(1);
        return left.getPartialDruidQuery() != null;
    }

    public void onMatch(RelOptRuleCall call) {
        ImmutableBitSet requiredCols;
        RexNode newUnnestRexNode;
        CorrelationId newCorrelationId;
        DruidRel newLeft;
        Correlate correlate = (Correlate)call.rel(0);
        DruidRel left = (DruidRel)call.rel(1);
        DruidUnnestRel right = (DruidUnnestRel)call.rel(2);
        RexBuilder rexBuilder = correlate.getCluster().getRexBuilder();
        ArrayList<RexInputRef> pulledUpProjects = new ArrayList<RexInputRef>();
        if (left.getPartialDruidQuery().stage() == PartialDruidQuery.Stage.SELECT_PROJECT) {
            RelNode leftScan = left.getPartialDruidQuery().getScan();
            Project leftProject = left.getPartialDruidQuery().getSelectProject();
            pulledUpProjects.addAll(leftProject.getProjects());
            Filter leftFilter = left.getPartialDruidQuery().getWhereFilter();
            newLeft = left.withPartialQuery(PartialDruidQuery.create(leftScan).withWhereFilter(leftFilter));
            newCorrelationId = correlate.getCluster().createCorrel();
            PushCorrelatedFieldAccessPastProject correlatedFieldRewriteShuttle = new PushCorrelatedFieldAccessPastProject(correlate.getCorrelationId(), newCorrelationId, leftProject);
            newUnnestRexNode = correlatedFieldRewriteShuttle.apply(right.getInputRexNode());
            requiredCols = ImmutableBitSet.of((Iterable)correlatedFieldRewriteShuttle.getRequiredColumns());
        } else {
            for (int i = 0; i < left.getRowType().getFieldCount(); ++i) {
                pulledUpProjects.add(rexBuilder.makeInputRef(((RelDataTypeField)correlate.getRowType().getFieldList().get(i)).getType(), i));
            }
            newLeft = left;
            newUnnestRexNode = right.getInputRexNode();
            requiredCols = correlate.getRequiredColumns();
            newCorrelationId = correlate.getCorrelationId();
        }
        for (int i = 0; i < right.getRowType().getFieldCount(); ++i) {
            pulledUpProjects.add(rexBuilder.makeInputRef(((RelDataTypeField)correlate.getRowType().getFieldList().get(left.getRowType().getFieldCount() + i)).getType(), newLeft.getRowType().getFieldCount() + i));
        }
        DruidCorrelateUnnestRel druidCorrelateUnnest = DruidCorrelateUnnestRel.create(correlate.copy(correlate.getTraitSet(), (RelNode)newLeft, (RelNode)right.withUnnestRexNode(newUnnestRexNode), newCorrelationId, requiredCols, correlate.getJoinType()), this.plannerContext);
        RelBuilder relBuilder = call.builder().push((RelNode)druidCorrelateUnnest).project((Iterable)RexUtil.fixUp((RexBuilder)rexBuilder, pulledUpProjects, (List)RelOptUtil.getFieldTypeList((RelDataType)druidCorrelateUnnest.getRowType())));
        relBuilder.convert(correlate.getRowType(), false);
        RelNode build = relBuilder.build();
        call.transformTo(build);
    }

    private static class PushCorrelatedFieldAccessPastProject
    extends RexShuttle {
        private final CorrelationId correlationId;
        private final CorrelationId newCorrelationId;
        private final Project project;
        private final IntSet requiredColumns = new IntAVLTreeSet();

        public PushCorrelatedFieldAccessPastProject(CorrelationId correlationId, CorrelationId newCorrelationId, Project project) {
            this.correlationId = correlationId;
            this.newCorrelationId = newCorrelationId;
            this.project = project;
        }

        public IntSet getRequiredColumns() {
            return this.requiredColumns;
        }

        public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
            if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
                RexCorrelVariable encounteredCorrelVariable = (RexCorrelVariable)fieldAccess.getReferenceExpr();
                if (encounteredCorrelVariable.id.equals((Object)this.correlationId)) {
                    RexNode projectExpr = (RexNode)this.project.getProjects().get(fieldAccess.getField().getIndex());
                    RexBuilder rexBuilder = this.project.getCluster().getRexBuilder();
                    final RexNode newCorrel = rexBuilder.makeCorrel(this.project.getInput().getRowType(), this.newCorrelationId);
                    return new RexShuttle(){

                        public RexNode visitInputRef(RexInputRef inputRef) {
                            requiredColumns.add(inputRef.getIndex());
                            return project.getCluster().getRexBuilder().makeFieldAccess(newCorrel, inputRef.getIndex());
                        }
                    }.apply(projectExpr);
                }
            }
            return super.visitFieldAccess(fieldAccess);
        }
    }
}

