/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.rules.ImmutableJoinToMultiJoinRule;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Pair;
import org.apache.kylin.guava30.shaded.common.collect.ImmutableMap;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;

@Value.Enclosing
public class JoinToMultiJoinRule
extends RelRule<Config>
implements TransformationRule {
    /**
     * Creates a JoinToMultiJoinRule from the given configuration.
     *
     * @param config rule configuration, passed unchanged to the {@link RelRule} constructor
     */
    protected JoinToMultiJoinRule(Config config) {
        super(config);
    }

    /**
     * Creates a rule whose root operand is the given join class.
     *
     * @param clazz join class the rule should match
     * @deprecated builds the rule from {@code Config.DEFAULT.withOperandFor(clazz)};
     *     the same rule can be obtained through {@link Config#toRule()}.
     */
    @Deprecated
    public JoinToMultiJoinRule(Class<? extends Join> clazz) {
        this(Config.DEFAULT.withOperandFor(clazz));
    }

    /**
     * Creates a rule whose root operand is the given join class and which uses
     * the given builder factory.
     *
     * @param joinClass join class the rule should match
     * @param relBuilderFactory factory installed on the configuration before the
     *     operand is set
     * @deprecated builds the rule from {@link Config#DEFAULT} with the factory and
     *     operand applied; the same rule can be obtained through {@link Config#toRule()}.
     */
    @Deprecated
    public JoinToMultiJoinRule(Class<? extends Join> joinClass, RelBuilderFactory relBuilderFactory) {
        this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).as(Config.class).withOperandFor(joinClass));
    }

    /**
     * Decides whether the rule applies to the matched join.
     *
     * @param call rule call whose operand 0 is the join
     * @return the value of {@code projectsRight()} for the join's type
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        final Join join = call.rel(0);
        final JoinRelType joinType = join.getJoinType();
        return joinType.projectsRight();
    }

    /**
     * Replaces the matched join by a single {@link MultiJoin} built from the
     * join and its two inputs.
     *
     * <p>The inputs, projected-field lists, join specs, join filters, field
     * reference counts and post-join filters are each assembled by a dedicated
     * helper and then handed to the {@link MultiJoin} constructor.
     *
     * @param call rule call; operand 0 is the join, operands 1 and 2 are its
     *     left and right inputs
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final Join origJoin = call.rel(0);
        // These must be RelNode (not Object): every helper below takes RelNode inputs.
        final RelNode left = call.rel(1);
        final RelNode right = call.rel(2);

        // combineInputs fills both lists as a side effect, one entry per new input.
        final List<@Nullable ImmutableBitSet> projFieldsList = new ArrayList<>();
        final List<int[]> joinFieldRefCountsList = new ArrayList<>();
        final List<RelNode> newInputs = combineInputs(origJoin, left, right, projFieldsList, joinFieldRefCountsList);

        // One (join type, condition) pair per new input; the condition may be null.
        final List<Pair<JoinRelType, @Nullable RexNode>> joinSpecs = new ArrayList<>();
        combineOuterJoins(origJoin, newInputs, left, right, joinSpecs);

        final List<@Nullable RexNode> newJoinFilters = combineJoinFilters(origJoin, left, right);
        final ImmutableMap<Integer, ImmutableIntList> newJoinFieldRefCountsMap = addOnJoinFieldRefCounts(newInputs, origJoin.getRowType().getFieldCount(), origJoin.getCondition(), joinFieldRefCountsList);
        final List<@Nullable RexNode> newPostJoinFilters = combinePostJoinFilters(origJoin, left, right);

        final RexBuilder rexBuilder = origJoin.getCluster().getRexBuilder();
        final MultiJoin multiJoin = new MultiJoin(
            origJoin.getCluster(),
            newInputs,
            RexUtil.composeConjunction(rexBuilder, newJoinFilters),
            origJoin.getRowType(),
            origJoin.getJoinType() == JoinRelType.FULL,
            Pair.right(joinSpecs),
            Pair.left(joinSpecs),
            projFieldsList,
            newJoinFieldRefCountsMap,
            RexUtil.composeConjunction(rexBuilder, newPostJoinFilters, true));
        call.transformTo(multiJoin);
    }

    /**
     * Builds the input list of the new MultiJoin from the two inputs of a join.
     *
     * <p>The left input is processed first, then the right one. For every
     * element appended to the returned list, one entry is appended to
     * {@code projFieldsList} and one to {@code joinFieldRefCountsList}.
     *
     * @param join join whose type decides whether each side is null-generating
     * @param left left input of the join
     * @param right right input of the join
     * @param projFieldsList receives the projected-field bit set of each new input,
     *     or {@code null} for an input that was not flattened
     * @param joinFieldRefCountsList receives the reference-count array of each new input
     * @return the inputs of the new MultiJoin
     */
    private static List<RelNode> combineInputs(Join join, RelNode left, RelNode right, List<@Nullable ImmutableBitSet> projFieldsList, List<int[]> joinFieldRefCountsList) {
        final List<RelNode> flattened = new ArrayList<>();
        flattenInto(left, join.getJoinType().generatesNullsOnLeft(), flattened, projFieldsList, joinFieldRefCountsList);
        flattenInto(right, join.getJoinType().generatesNullsOnRight(), flattened, projFieldsList, joinFieldRefCountsList);
        return flattened;
    }

    /**
     * Appends one side of a join to the accumulating lists: the side's own
     * inputs when it can be combined, otherwise the side itself with a
     * {@code null} projection and an all-zero reference-count array.
     */
    private static void flattenInto(RelNode child, boolean nullGenerating, List<RelNode> inputs, List<@Nullable ImmutableBitSet> projFieldsList, List<int[]> joinFieldRefCountsList) {
        if (!canCombine(child, nullGenerating)) {
            inputs.add(child);
            projFieldsList.add(null);
            joinFieldRefCountsList.add(new int[child.getRowType().getFieldCount()]);
            return;
        }
        final MultiJoin multiJoin = (MultiJoin) child;
        final int inputCount = child.getInputs().size();
        for (int ordinal = 0; ordinal < inputCount; ordinal++) {
            inputs.add(multiJoin.getInput(ordinal));
            projFieldsList.add(multiJoin.getProjFields().get(ordinal));
            joinFieldRefCountsList.add(multiJoin.getJoinFieldRefCountsMap().get(ordinal).toIntArray());
        }
    }

    /**
     * Appends the (join type, condition) specs of the new MultiJoin to
     * {@code joinSpecs}, left side first and right side second.
     *
     * <p>For a LEFT join the join's own type and condition take the place of
     * the right side's specs; for a RIGHT join they take the place of the left
     * side's specs. A side that is not replaced contributes either the specs
     * copied from its MultiJoin (when it can be combined) or a single
     * {@code (INNER, null)} pair.
     *
     * @param joinRel join being converted
     * @param combinedInputs inputs of the new MultiJoin (not read by this method)
     * @param left left input of the join
     * @param right right input of the join
     * @param joinSpecs list that receives the specs
     */
    private static void combineOuterJoins(Join joinRel, List<RelNode> combinedInputs, RelNode left, RelNode right, List<Pair<JoinRelType, @Nullable RexNode>> joinSpecs) {
        final JoinRelType joinType = joinRel.getJoinType();
        final boolean leftCombined = canCombine(left, joinType.generatesNullsOnLeft());
        final boolean rightCombined = canCombine(right, joinType.generatesNullsOnRight());

        // Specs contributed for the left side.
        if (joinType == JoinRelType.RIGHT) {
            joinSpecs.add(Pair.of(joinType, joinRel.getCondition()));
        } else if (leftCombined) {
            copyOuterJoinInfo((MultiJoin) left, joinSpecs, 0, null, null);
        } else {
            joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null));
        }

        // Specs contributed for the right side.
        if (joinType == JoinRelType.LEFT) {
            joinSpecs.add(Pair.of(joinType, joinRel.getCondition()));
        } else if (rightCombined) {
            // Conditions of the right MultiJoin are shifted by the width of the left input.
            copyOuterJoinInfo((MultiJoin) right, joinSpecs, left.getRowType().getFieldCount(), right.getRowType().getFieldList(), joinRel.getRowType().getFieldList());
        } else {
            joinSpecs.add(Pair.of(JoinRelType.INNER, (RexNode) null));
        }
    }

    /**
     * Copies the (join type, outer join condition) pairs of a MultiJoin into
     * {@code destJoinSpecs}.
     *
     * <p>With an adjustment of 0 the pairs are copied unchanged. Otherwise every
     * non-null condition is rewritten through a
     * {@link RelOptUtil.RexInputConverter} that applies the same adjustment to
     * each source field.
     *
     * @param multiJoin MultiJoin supplying the pairs
     * @param destJoinSpecs list that receives the pairs
     * @param adjustmentAmount offset applied to input references; 0 means no rewrite
     * @param srcFields fields of the source row type; asserted non-null when the
     *     adjustment is non-zero
     * @param destFields fields of the destination row type; asserted non-null when
     *     the adjustment is non-zero
     */
    private static void copyOuterJoinInfo(MultiJoin multiJoin, List<Pair<JoinRelType, @Nullable RexNode>> destJoinSpecs, int adjustmentAmount, @Nullable List<RelDataTypeField> srcFields, @Nullable List<RelDataTypeField> destFields) {
        final List<Pair<JoinRelType, @Nullable RexNode>> specs = Pair.zip(multiJoin.getJoinTypes(), multiJoin.getOuterJoinConditions());
        if (adjustmentAmount == 0) {
            destJoinSpecs.addAll(specs);
            return;
        }

        assert srcFields != null;
        assert destFields != null;
        final int[] adjustments = new int[srcFields.size()];
        for (int pos = 0; pos < adjustments.length; pos++) {
            adjustments[pos] = adjustmentAmount;
        }

        for (Pair<JoinRelType, @Nullable RexNode> spec : specs) {
            final RexNode condition = spec.right;
            final RexNode shifted;
            if (condition == null) {
                shifted = null;
            } else {
                shifted = condition.accept(new RelOptUtil.RexInputConverter(multiJoin.getCluster().getRexBuilder(), srcFields, destFields, adjustments));
            }
            destJoinSpecs.add(Pair.of(spec.left, shifted));
        }
    }

    /**
     * Collects the filters that are AND-ed into the join filter of the new MultiJoin.
     *
     * <p>The list holds, in this order: the join's own condition unless the join
     * is a LEFT or RIGHT join, the join filter of the left input when it can be
     * combined, and the shifted join filter of the right input when it can be
     * combined. Entries may be {@code null}.
     *
     * @param join join being converted
     * @param left left input of the join
     * @param right right input of the join
     * @return the collected filters
     */
    private static List<@Nullable RexNode> combineJoinFilters(Join join, RelNode left, RelNode right) {
        final JoinRelType joinType = join.getJoinType();
        final List<@Nullable RexNode> collected = new ArrayList<>();

        final boolean oneSidedOuter = joinType == JoinRelType.LEFT || joinType == JoinRelType.RIGHT;
        if (!oneSidedOuter) {
            collected.add(join.getCondition());
        }

        if (canCombine(left, joinType.generatesNullsOnLeft())) {
            final MultiJoin leftMultiJoin = (MultiJoin) left;
            collected.add(leftMultiJoin.getJoinFilter());
        }

        if (canCombine(right, joinType.generatesNullsOnRight())) {
            final MultiJoin rightMultiJoin = (MultiJoin) right;
            final RexNode rightFilter = rightMultiJoin.getJoinFilter();
            collected.add(shiftRightFilter(join, left, rightMultiJoin, rightFilter));
        }
        return collected;
    }

    /**
     * Tells whether an input of a join can be merged into the new MultiJoin.
     *
     * @param input input to test
     * @param nullGenerating whether the input is on a null-generating side of the join
     * @return {@code true} only if the input is a {@link MultiJoin} that is not
     *     a full outer join, contains no outer join, and is not null-generating
     */
    private static boolean canCombine(RelNode input, boolean nullGenerating) {
        if (!(input instanceof MultiJoin)) {
            return false;
        }
        final MultiJoin multiJoin = (MultiJoin) input;
        if (multiJoin.isFullOuterJoin() || multiJoin.containsOuter()) {
            return false;
        }
        return !nullGenerating;
    }

    /**
     * Rewrites a filter of the right input so that its input references are
     * offset by the number of fields of the left input.
     *
     * @param joinRel join whose row type is the destination of the rewrite
     * @param left left input of the join; its field count is the offset
     * @param right right input of the join; its row type is the source of the rewrite
     * @param rightFilter filter to rewrite, may be {@code null}
     * @return the rewritten filter, or {@code null} if {@code rightFilter} is {@code null}
     */
    private static @Nullable RexNode shiftRightFilter(Join joinRel, RelNode left, MultiJoin right, @Nullable RexNode rightFilter) {
        if (rightFilter == null) {
            return null;
        }

        final int offset = left.getRowType().getFieldList().size();
        final List<RelDataTypeField> rightFields = right.getRowType().getFieldList();

        // Every field of the right input moves by the same amount.
        final int[] adjustments = new int[rightFields.size()];
        for (int pos = 0; pos < adjustments.length; pos++) {
            adjustments[pos] = offset;
        }

        final RelOptUtil.RexInputConverter converter = new RelOptUtil.RexInputConverter(joinRel.getCluster().getRexBuilder(), rightFields, joinRel.getRowType().getFieldList(), adjustments);
        return rightFilter.accept(converter);
    }

    /**
     * Adds the field references made by a join condition to the per-input
     * reference counts and returns the result as an immutable map.
     *
     * <p>The counts in {@code origJoinFieldRefCounts} are copied before being
     * updated, so the arrays passed in are left untouched.
     *
     * @param multiJoinInputs inputs of the new MultiJoin, in field order
     * @param nTotalFields number of fields in the join's row type
     * @param joinCondition condition whose input references are counted
     * @param origJoinFieldRefCounts existing reference counts, one array per input
     * @return map from input ordinal to the updated reference counts of that input
     * @throws NullPointerException if a referenced field falls in an input that
     *     has no entry in {@code origJoinFieldRefCounts}
     */
    private static ImmutableMap<Integer, ImmutableIntList> addOnJoinFieldRefCounts(List<RelNode> multiJoinInputs, int nTotalFields, RexNode joinCondition, List<int[]> origJoinFieldRefCounts) {
        // Count the references the join condition makes to each field of the join row.
        final int[] joinCondRefCounts = new int[nTotalFields];
        joinCondition.accept(new InputReferenceCounter(joinCondRefCounts));

        // Seed the map with copies of the existing counts, keyed by input ordinal.
        final Map<Integer, int[]> refCountsMap = new HashMap<>();
        final int nInputs = multiJoinInputs.size();
        int currInput = 0;
        for (int[] origRefCounts : origJoinFieldRefCounts) {
            refCountsMap.put(currInput, origRefCounts.clone());
            ++currInput;
        }

        // Walk the referenced fields, advancing to the input that owns each one.
        currInput = -1;
        int startField = 0;
        int nFields = 0;
        for (int i = 0; i < nTotalFields; ++i) {
            if (joinCondRefCounts[i] == 0) {
                continue;
            }
            while (i >= startField + nFields) {
                startField += nFields;
                // The increment must not live inside the assert: with assertions
                // disabled it would be skipped and currInput would stay at -1.
                ++currInput;
                assert currInput < nInputs;
                nFields = multiJoinInputs.get(currInput).getRowType().getFieldCount();
            }
            final int key = currInput;
            final int[] refCounts = Objects.requireNonNull(refCountsMap.get(key), () -> "refCountsMap.get(currInput) for " + key);
            refCounts[i - startField] += joinCondRefCounts[i];
        }

        final ImmutableMap.Builder<Integer, ImmutableIntList> builder = ImmutableMap.builder();
        for (Map.Entry<Integer, int[]> entry : refCountsMap.entrySet()) {
            builder.put(entry.getKey(), ImmutableIntList.of(entry.getValue()));
        }
        return builder.build();
    }

    /**
     * Collects the post-join filters of the join's inputs.
     *
     * <p>The right input's filter (shifted past the left input's fields) comes
     * first, followed by the left input's filter. Only inputs that are
     * {@link MultiJoin}s contribute; entries may be {@code null}.
     *
     * @param joinRel join being converted
     * @param left left input of the join
     * @param right right input of the join
     * @return the collected post-join filters
     */
    private static List<@Nullable RexNode> combinePostJoinFilters(Join joinRel, RelNode left, RelNode right) {
        final List<@Nullable RexNode> collected = new ArrayList<>();

        if (right instanceof MultiJoin) {
            final MultiJoin rightMultiJoin = (MultiJoin) right;
            final RexNode rightPostFilter = rightMultiJoin.getPostJoinFilter();
            collected.add(shiftRightFilter(joinRel, left, rightMultiJoin, rightPostFilter));
        }

        if (left instanceof MultiJoin) {
            final MultiJoin leftMultiJoin = (MultiJoin) left;
            collected.add(leftMultiJoin.getPostJoinFilter());
        }

        return collected;
    }

    /**
     * Rule configuration.
     */
    @Value.Immutable
    public interface Config
    extends RelRule.Config {
        /** Default configuration, whose root operand is {@link LogicalJoin}. */
        Config DEFAULT = ImmutableJoinToMultiJoinRule.Config.of().withOperandFor(LogicalJoin.class);

        /** Creates a {@link JoinToMultiJoinRule} that uses this configuration. */
        @Override
        default JoinToMultiJoinRule toRule() {
            return new JoinToMultiJoinRule(this);
        }

        /**
         * Returns a configuration whose operand is the given join class with two
         * {@link RelNode} inputs, each accepting any inputs of its own.
         *
         * @param joinClass join class to match
         */
        default Config withOperandFor(Class<? extends Join> joinClass) {
            return withOperandSupplier(joinOperand ->
                    joinOperand.operand(joinClass).inputs(
                        leftOperand -> leftOperand.operand(RelNode.class).anyInputs(),
                        rightOperand -> rightOperand.operand(RelNode.class).anyInputs()))
                .as(Config.class);
        }
    }

    /**
     * Visitor that increments, for every input reference it visits, the slot of
     * a caller-supplied array at the reference's index.
     */
    private static class InputReferenceCounter
    extends RexVisitorImpl<Void> {
        /** Array updated in place; indexed by input reference index. */
        private final int[] refCounts;

        /**
         * @param refCounts array that receives the counts; it is not copied
         */
        InputReferenceCounter(int[] refCounts) {
            super(true);
            this.refCounts = refCounts;
        }

        @Override
        public Void visitInputRef(RexInputRef inputRef) {
            refCounts[inputRef.getIndex()]++;
            return null;
        }
    }
}

