package org.apache.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlAvgAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.CompositeList;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Util;

/* loaded from: input_file:lib/calcite-core-1.13.0.jar:org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.class */
/**
 * Planner rule that rewrites certain aggregate functions of an {@link Aggregate}
 * in terms of simpler aggregate functions.
 *
 * <p>Rewrites performed:
 * <ul>
 * <li>{@code AVG(x)} becomes {@code CAST(SUM(x) / COUNT(x))}.</li>
 * <li>{@code VAR_POP(x)} becomes
 *     {@code CAST((SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x))}.</li>
 * <li>{@code VAR_SAMP(x)} is the same as {@code VAR_POP(x)}, except the divisor is
 *     {@code CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END}.</li>
 * <li>{@code STDDEV_POP(x)} and {@code STDDEV_SAMP(x)} are the matching variance
 *     expression wrapped in {@code POWER(..., 0.5)}.</li>
 * <li>{@code SUM(x)} becomes {@code SUM0(x)} when the original call's type is
 *     NOT NULL. Otherwise it becomes
 *     {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END}.</li>
 * </ul>
 * The AVG/VAR/STDDEV results are cast back to the original call's type. Any other
 * aggregate call is carried over unchanged.
 *
 * <p>The rewritten plan is: input (optionally extended with a Project that adds
 * helper expressions such as {@code x * x}) → new Aggregate → Project restoring
 * the original output row type and field names.
 */
public class AggregateReduceFunctionsRule extends RelOptRule {
    /** Default instance; matches any {@link LogicalAggregate}. */
    public static final AggregateReduceFunctionsRule INSTANCE =
        new AggregateReduceFunctionsRule(operand(LogicalAggregate.class, any()),
            RelFactories.LOGICAL_BUILDER);

    /**
     * Creates the rule.
     *
     * @param operand root operand; its first matched rel must be an {@link Aggregate}
     * @param relBuilderFactory factory for the {@link RelBuilder} used to build the result
     */
    protected AggregateReduceFunctionsRule(RelOptRuleOperand operand,
            RelBuilderFactory relBuilderFactory) {
        super(operand, relBuilderFactory, null);
    }

    /**
     * Accepts the match only if the superclass accepts it and the aggregate
     * contains at least one call this rule knows how to reduce.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        if (!super.matches(call)) {
            return false;
        }
        final Aggregate aggregate = (Aggregate) call.rels[0];
        return containsAvgStddevVarCall(aggregate.getAggCallList());
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        reduceAggs(call, (Aggregate) call.rels[0]);
    }

    /**
     * Returns whether any call is reducible by this rule.
     *
     * <p>Despite the name, this also returns true for SUM calls, which
     * {@link #reduceAgg} rewrites via SUM0. {@link SqlAvgAggFunction} covers
     * AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP and VAR_SAMP.
     */
    private boolean containsAvgStddevVarCall(List<AggregateCall> aggCalls) {
        for (AggregateCall aggCall : aggCalls) {
            if (aggCall.getAggregation() instanceof SqlAvgAggFunction
                || aggCall.getAggregation() instanceof SqlSumAggFunction) {
                return true;
            }
        }
        return false;
    }

    /**
     * Rewrites {@code oldAggRel} and registers the equivalent expression with
     * {@code ruleCall}.
     */
    private void reduceAggs(RelOptRuleCall ruleCall, Aggregate oldAggRel) {
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final int groupCount = oldAggRel.getGroupCount();
        final int indicatorCount = oldAggRel.getIndicatorCount();

        // Aggregate calls of the new Aggregate, plus a dedup map from call to the
        // expression referencing its output column.
        final List<AggregateCall> newCalls = new ArrayList<>();
        final Map<AggregateCall, RexNode> aggCallMapping = Maps.newHashMap();

        // Expressions of the final Project. Group keys and indicator columns are
        // passed through first, then one expression per original aggregate call.
        final List<RexNode> projects = new ArrayList<>();
        for (int i = 0; i < groupCount + indicatorCount; i++) {
            projects.add(rexBuilder.makeInputRef(getFieldType(oldAggRel, i), i));
        }

        final RelBuilder relBuilder = ruleCall.builder();
        relBuilder.push(oldAggRel.getInput());

        // Starts as the input's fields. Reductions may append helper expressions
        // (e.g. x * x) that the new aggregate calls reference by ordinal.
        final List<RexNode> inputExprs = new ArrayList<>(relBuilder.fields());
        for (AggregateCall oldCall : oldAggRel.getAggCallList()) {
            projects.add(
                reduceAgg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
        }

        // If helper expressions were appended, add a Project below the new
        // Aggregate that computes them. Their field names are left null.
        final int extraArgCount =
            inputExprs.size() - relBuilder.peek().getRowType().getFieldCount();
        if (extraArgCount > 0) {
            relBuilder.project(inputExprs,
                CompositeList.of(relBuilder.peek().getRowType().getFieldNames(),
                    Collections.<String>nCopies(extraArgCount, null)));
        }
        newAggregateRel(relBuilder, oldAggRel, newCalls);
        relBuilder.project(projects, oldAggRel.getRowType().getFieldNames());
        ruleCall.transformTo(relBuilder.build());
    }

    /**
     * Returns an expression, over the output of the new aggregate, that is
     * equivalent to {@code oldCall}. Any aggregate calls it needs are registered
     * in {@code newCalls}/{@code aggCallMapping}.
     *
     * @throws AssertionError (via {@link Util#unexpected}) if a
     *     {@link SqlAvgAggFunction} has an unsupported kind
     */
    private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall,
            List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping,
            List<RexNode> inputExprs) {
        if (oldCall.getAggregation() instanceof SqlSumAggFunction) {
            return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
        }
        if (oldCall.getAggregation() instanceof SqlAvgAggFunction) {
            final SqlKind kind = oldCall.getAggregation().getKind();
            switch (kind) {
                case AVG:
                    return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
                case STDDEV_POP:
                    return reduceStddev(oldAggRel, oldCall, true, true,
                        newCalls, aggCallMapping, inputExprs);
                case STDDEV_SAMP:
                    return reduceStddev(oldAggRel, oldCall, false, true,
                        newCalls, aggCallMapping, inputExprs);
                case VAR_POP:
                    return reduceStddev(oldAggRel, oldCall, true, false,
                        newCalls, aggCallMapping, inputExprs);
                case VAR_SAMP:
                    return reduceStddev(oldAggRel, oldCall, false, false,
                        newCalls, aggCallMapping, inputExprs);
                default:
                    throw Util.unexpected(kind);
            }
        }
        // Anything else: keep the original call.
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final List<RelDataType> oldArgTypes = SqlTypeUtil.projectTypes(
            oldAggRel.getInput().getRowType(), oldCall.getArgList());
        return rexBuilder.addAggCall(oldCall, oldAggRel.getGroupCount(),
            oldAggRel.indicator, newCalls, aggCallMapping, oldArgTypes);
    }

    /** Rewrites {@code AVG(x)} as {@code CAST(SUM(x) / COUNT(x) AS <original type>)}. */
    private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall,
            List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
        final int groupCount = oldAggRel.getGroupCount();
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final RelDataType avgInputType =
            getFieldType(oldAggRel.getInput(), oldCall.getArgList().get(0));

        final AggregateCall sumCall = AggregateCall.create(SqlStdOperatorTable.SUM,
            oldCall.isDistinct(), oldCall.getArgList(), oldCall.filterArg,
            groupCount, oldAggRel.getInput(), null, null);
        final AggregateCall countCall = AggregateCall.create(SqlStdOperatorTable.COUNT,
            oldCall.isDistinct(), oldCall.getArgList(), oldCall.filterArg,
            groupCount, oldAggRel.getInput(), null, null);

        // These references are relative to the output of the new aggregate.
        // SUM is registered before COUNT; the order fixes the new output layout.
        final RexNode numeratorRef = rexBuilder.addAggCall(sumCall, groupCount,
            oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(avgInputType));
        final RexNode denominatorRef = rexBuilder.addAggCall(countCall, groupCount,
            oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(avgInputType));

        final RexNode divideRef =
            rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef);
        return rexBuilder.makeCast(oldCall.getType(), divideRef);
    }

    /**
     * Rewrites {@code SUM(x)} in terms of {@code SUM0(x)}.
     *
     * <p>If the original call's type is NOT NULL, {@code SUM0(x)} is returned
     * directly. Otherwise the result is
     * {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END}.
     */
    private RexNode reduceSum(Aggregate oldAggRel, AggregateCall oldCall,
            List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
        final int groupCount = oldAggRel.getGroupCount();
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final RelDataType argType =
            getFieldType(oldAggRel.getInput(), oldCall.getArgList().get(0));

        final AggregateCall sumZeroCall = AggregateCall.create(SqlStdOperatorTable.SUM0,
            oldCall.isDistinct(), oldCall.getArgList(), oldCall.filterArg,
            groupCount, oldAggRel.getInput(), null, oldCall.name);
        final RexNode sumZeroRef = rexBuilder.addAggCall(sumZeroCall, groupCount,
            oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
        if (!oldCall.getType().isNullable()) {
            return sumZeroRef;
        }

        // The argument ordinals index the aggregate's *input* columns, so the
        // COUNT call's type must be derived from the input, not the Aggregate.
        final AggregateCall countCall = AggregateCall.create(SqlStdOperatorTable.COUNT,
            oldCall.isDistinct(), oldCall.getArgList(), oldCall.filterArg,
            groupCount, oldAggRel.getInput(), null, null);
        final RexNode countRef = rexBuilder.addAggCall(countCall, groupCount,
            oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
        return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
            rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
                countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
            rexBuilder.makeCast(sumZeroRef.getType(), rexBuilder.constantNull()),
            sumZeroRef);
    }

    /**
     * Rewrites VAR_POP/VAR_SAMP/STDDEV_POP/STDDEV_SAMP of a single argument
     * {@code x}. The core expression is
     * {@code (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / divisor}, and the result is
     * cast to the original call's type.
     *
     * @param biased true for the population variant (divisor {@code COUNT(x)});
     *     false for the sample variant (divisor
     *     {@code CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END})
     * @param sqrt true to wrap the variance in {@code POWER(..., 0.5)} (STDDEV)
     * @param inputExprs input expressions; {@code x * x} is appended if absent
     */
    private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall,
            boolean biased, boolean sqrt, List<AggregateCall> newCalls,
            Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        final int groupCount = oldAggRel.getGroupCount();
        final RelOptCluster cluster = oldAggRel.getCluster();
        final RexBuilder rexBuilder = cluster.getRexBuilder();
        final RelDataTypeFactory typeFactory = cluster.getTypeFactory();

        assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
        final int argOrdinal = oldCall.getArgList().get(0);
        final RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);
        final RexNode argRef = inputExprs.get(argOrdinal);

        // SUM(x * x). The product is evaluated below the aggregate, so it is
        // added to the input expressions (reused if already present).
        final RexNode argSquared =
            rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
        final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
        final RelDataType sumArgSquaredType = SqlStdOperatorTable.SUM.inferReturnType(
            new Aggregate.AggCallBinding(typeFactory, SqlStdOperatorTable.SUM,
                ImmutableList.of(argRef.getType()), groupCount, oldCall.filterArg >= 0));
        final AggregateCall sumArgSquaredCall = AggregateCall.create(
            SqlStdOperatorTable.SUM, oldCall.isDistinct(),
            ImmutableIntList.of(argSquaredOrdinal), oldCall.filterArg,
            sumArgSquaredType, null);
        final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredCall,
            groupCount, oldAggRel.indicator, newCalls, aggCallMapping,
            ImmutableList.of(argType));

        // SUM(x), then SUM(x) * SUM(x).
        final AggregateCall sumArgCall = AggregateCall.create(SqlStdOperatorTable.SUM,
            oldCall.isDistinct(), ImmutableIntList.of(argOrdinal), oldCall.filterArg,
            groupCount, oldAggRel.getInput(), null, null);
        final RexNode sumArg = rexBuilder.addAggCall(sumArgCall, groupCount,
            oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
        final RexNode sumSquared =
            rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);

        // COUNT(x).
        final AggregateCall countArgCall = AggregateCall.create(SqlStdOperatorTable.COUNT,
            oldCall.isDistinct(), oldCall.getArgList(), oldCall.filterArg,
            groupCount, oldAggRel.getInput(), null, null);
        final RexNode countArg = rexBuilder.addAggCall(countArgCall, groupCount,
            oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));

        // SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)
        final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS,
            sumArgSquared,
            rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquared, countArg));

        final RexNode denominator;
        if (biased) {
            denominator = countArg;
        } else {
            // Sample variance of a single row is NULL rather than a division by zero.
            final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
            denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE,
                rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one),
                rexBuilder.makeCast(countArg.getType(), rexBuilder.constantNull()),
                rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one));
        }

        final RexNode variance =
            rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
        final RexNode result = sqrt
            ? rexBuilder.makeCall(SqlStdOperatorTable.POWER, variance,
                rexBuilder.makeExactLiteral(new BigDecimal("0.5")))
            : variance;
        return rexBuilder.makeCast(oldCall.getType(), result);
    }

    /**
     * Returns the index of {@code element} in {@code list}, appending it first if
     * it is not already present (equality via {@link List#indexOf}).
     */
    private static <T> int lookupOrAdd(List<T> list, T element) {
        int ordinal = list.indexOf(element);
        if (ordinal == -1) {
            ordinal = list.size();
            list.add(element);
        }
        return ordinal;
    }

    /**
     * Pushes onto {@code relBuilder} an aggregate with the same grouping as
     * {@code oldAggregate} and the given calls. Subclasses may override this to
     * build a different aggregate implementation.
     */
    protected void newAggregateRel(RelBuilder relBuilder, Aggregate oldAggregate,
            List<AggregateCall> newCalls) {
        relBuilder.aggregate(
            relBuilder.groupKey(oldAggregate.getGroupSet(), oldAggregate.indicator,
                oldAggregate.getGroupSets()),
            newCalls);
    }

    /** Returns the type of field {@code ordinal} in {@code rel}'s row type. */
    private static RelDataType getFieldType(RelNode rel, int ordinal) {
        return rel.getRowType().getFieldList().get(ordinal).getType();
    }
}
