Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
Expand Down Expand Up @@ -71,7 +71,7 @@ public Rule build() {
}

@VisibleForTesting
protected static boolean isBinaryArithmeticSlot(TreeNode<Expression> expr) {
protected static boolean isBinaryArithmeticSlot(Expression expr) {
if (expr instanceof Slot) {
return true;
}
Expand All @@ -81,7 +81,56 @@ protected static boolean isBinaryArithmeticSlot(TreeNode<Expression> expr) {
if (!supportedFunctions.contains(expr.getClass())) {
return false;
}
return ExpressionUtils.isSlotOrCastOnSlot(expr.child(0)).isPresent() && expr.child(1) instanceof Literal
|| ExpressionUtils.isSlotOrCastOnSlot(expr.child(1)).isPresent() && expr.child(0) instanceof Literal;

// Float/double arithmetic: precision loss for all operations
if (expr.child(0).getDataType().isFloatLikeType()
|| expr.child(1).getDataType().isFloatLikeType()) {
return false;
}

Expression slotExpr;
Literal literal;
if (expr.child(0) instanceof Literal) {
literal = (Literal) expr.child(0);
slotExpr = expr.child(1);
} else if (expr.child(1) instanceof Literal) {
literal = (Literal) expr.child(1);
slotExpr = expr.child(0);
} else {
return false;
}

if (!canExtractSlot(slotExpr)) {
return false;
}

return checkLiteral((BinaryArithmetic) expr, literal);
}

@VisibleForTesting
protected static boolean checkLiteral(BinaryArithmetic expr, Literal literal) {
if (literal.isNullLiteral()) {
return false;
}
if (expr instanceof Multiply || expr instanceof Divide) {
if (literal.isZero()) {
return false;
}
}
return true;
}

@VisibleForTesting
protected static boolean canExtractSlot(Expression expr) {
while (expr instanceof Cast) {
Cast cast = (Cast) expr;
Expression inner = cast.child();
if (!inner.getDataType().isInjectiveCastTo(cast.getDataType())) {
return false;
}
expr = inner;
}
return expr instanceof Slot;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
Expand Down Expand Up @@ -355,41 +354,6 @@ public static <S extends NamedExpression> S selectMinimumColumn(Collection<S> sl
return minSlot;
}

/**
* Check whether the input expression is a {@link org.apache.doris.nereids.trees.expressions.Slot}
* or at least one {@link Cast} on a {@link org.apache.doris.nereids.trees.expressions.Slot}
* <p>
* for example:
* - SlotReference to a column:
* col
* - Cast on SlotReference:
* cast(int_col as string)
* cast(cast(int_col as long) as string)
*
* @param expr input expression
* @return Return Optional[ExprId] of underlying slot reference if input expression is a slot or cast on slot.
* Otherwise, return empty optional result.
*/
public static Optional<ExprId> isSlotOrCastOnSlot(Expression expr) {
return extractSlotOrCastOnSlot(expr).map(Slot::getExprId);
}

/**
* Check whether the input expression is a {@link org.apache.doris.nereids.trees.expressions.Slot}
* or at least one {@link Cast} on a {@link org.apache.doris.nereids.trees.expressions.Slot}
*/
public static Optional<Slot> extractSlotOrCastOnSlot(Expression expr) {
while (expr instanceof Cast) {
expr = expr.child(0);
}

if (expr instanceof SlotReference) {
return Optional.of((Slot) expr);
} else {
return Optional.empty();
}
}

/**
* Generate replaceMap Slot -> Expression from NamedExpression[Expression as name]
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,30 @@
package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Mod;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
Expand All @@ -41,6 +53,7 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
import java.util.List;

class SimplifyAggGroupByTest implements MemoPatternMatchSupported {
Expand Down Expand Up @@ -156,4 +169,231 @@ void testisBinaryArithmeticSlot() {
Divide divide = new Divide(id, Literal.of(2));
Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(divide));
}

// ========== new tests for injectivity checks ==========

@Test
void testMultiplyByZero() {
Slot id = scan1.getOutput().get(0);
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Multiply(id, Literal.of(0))));
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Multiply(Literal.of(0), id)));
}

@Test
void testDivideZeroNumerator() {
Slot id = scan1.getOutput().get(0);
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Divide(Literal.of(0), id)));
}

@Test
void testDivideByZero() {
Slot id = scan1.getOutput().get(0);
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Divide(id, Literal.of(0))));
}

@Test
void testNullLiteral() {
Slot id = scan1.getOutput().get(0);
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Add(id, NullLiteral.INSTANCE)));
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Multiply(id, NullLiteral.INSTANCE)));
}

@Test
void testMultiplyWithDoubleLiteral() {
Slot id = scan1.getOutput().get(0);
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Multiply(id, new DoubleLiteral(0.1))));
}

@Test
void testDivideWithDoubleLiteral() {
Slot id = scan1.getOutput().get(0);
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Divide(id, new DoubleLiteral(2.0))));
}

@Test
void testMultiplyWithFloatSlot() {
Slot floatSlot = new SlotReference("f", FloatType.INSTANCE);
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Multiply(floatSlot, Literal.of(2))));
}

@Test
void testMultiplyDoubleSlotWithIntLiteral() {
Slot doubleSlot = new SlotReference("d", DoubleType.INSTANCE);
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Multiply(doubleSlot, Literal.of(2))));
}

@Test
void testAddWithDoubleLiteral() {
// Float/double arithmetic may be imprecise, reject for all ops
Slot id = scan1.getOutput().get(0);
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Add(id, new DoubleLiteral(1.0))));
}

@Test
void testAddWithFloatLiteral() {
Slot id = scan1.getOutput().get(0);
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Add(id, new FloatLiteral(1.0f))));
}

@Test
void testSubtractWithDoubleLiteral() {
Slot id = scan1.getOutput().get(0);
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Subtract(id, new DoubleLiteral(1.0))));
}

@Test
void testMultiplyWithDecimalLiteral() {
// Small decimal multiply should pass (precision fits)
Slot id = scan1.getOutput().get(0);
Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Multiply(id, new DecimalLiteral(new BigDecimal("2.0")))));
}

@Test
void testDivideWithDecimalLiteral() {
// Divide with decimal: precision overflow too extreme to worry about
Slot id = scan1.getOutput().get(0);
Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Divide(id, new DecimalLiteral(new BigDecimal("2.0")))));
}

@Test
void testAddWithDecimalLiteral() {
// Add/Subtract with decimal are exact, should pass
Slot id = scan1.getOutput().get(0);
Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(
new Add(id, new DecimalLiteral(new BigDecimal("1.0")))));
}

// ========== tests for isInjectiveCastTo ==========

@Test
void testIntegerWidening() {
Assertions.assertTrue(TinyIntType.INSTANCE.isInjectiveCastTo(IntegerType.INSTANCE));
Assertions.assertTrue(IntegerType.INSTANCE.isInjectiveCastTo(BigIntType.INSTANCE));
Assertions.assertFalse(IntegerType.INSTANCE.isInjectiveCastTo(TinyIntType.INSTANCE));
Assertions.assertFalse(BigIntType.INSTANCE.isInjectiveCastTo(IntegerType.INSTANCE));
}

@Test
void testDecimalWidening() {
Assertions.assertTrue(DecimalV3Type.createDecimalV3Type(5, 2)
.isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(10, 4)));
Assertions.assertFalse(DecimalV3Type.createDecimalV3Type(10, 4)
.isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(5, 2)));
}

@Test
void testIntegralToDecimalWidening() {
Assertions.assertTrue(TinyIntType.INSTANCE
.isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(10, 0)));
// BigInt has 19 digits, DECIMAL(5,0) only has 5 integer digits
Assertions.assertFalse(BigIntType.INSTANCE
.isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(5, 0)));
}

@Test
void testCrossFamilyRejected() {
Assertions.assertFalse(IntegerType.INSTANCE.isInjectiveCastTo(FloatType.INSTANCE));
Assertions.assertFalse(FloatType.INSTANCE.isInjectiveCastTo(IntegerType.INSTANCE));
Assertions.assertFalse(IntegerType.INSTANCE.isInjectiveCastTo(DoubleType.INSTANCE));
}

// ========== tests for canExtractSlot ==========

@Test
void testCanExtractSlotBare() {
Slot id = scan1.getOutput().get(0);
Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(id));
}

@Test
void testCanExtractSlotWidening() {
Slot id = scan1.getOutput().get(0);
// INT->BIGINT is lossless widening
Expression cast = new Cast(id, BigIntType.INSTANCE);
Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(cast));
}

@Test
void testCanExtractSlotExplicitCast() {
Slot id = scan1.getOutput().get(0);
// explicit cast should also be acceptable if lossless
Expression cast = new Cast(id, BigIntType.INSTANCE, true);
Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(cast));
}

@Test
void testCanExtractSlotNarrowing() {
Slot id = scan1.getOutput().get(0);
// INT -> TINYINT is narrowing, should be rejected
Expression cast = new Cast(id, TinyIntType.INSTANCE);
Assertions.assertFalse(SimplifyAggGroupBy.canExtractSlot(cast));
}

// ========== integration tests via PlanChecker ==========

@Test
void testMultiplyByZeroNotSimplified() {
Slot id = scan1.getOutput().get(0);
List<NamedExpression> output = ImmutableList.of(id, new Count().alias("cnt"));
List<Expression> groupBy = ImmutableList.of(id, new Multiply(id, Literal.of(0)));
LogicalPlan agg = new LogicalPlanBuilder(scan1)
.agg(groupBy, output)
.build();
ConnectContext connectContext = MemoTestUtils.createConnectContext();
connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false);
PlanChecker.from(connectContext, agg)
.applyTopDown(new SimplifyAggGroupBy())
.matchesFromRoot(
logicalAggregate().when(a -> a.equals(agg))
);
}

@Test
void testNullLiteralNotSimplified() {
Slot id = scan1.getOutput().get(0);
List<NamedExpression> output = ImmutableList.of(id, new Count().alias("cnt"));
List<Expression> groupBy = ImmutableList.of(id, new Add(id, NullLiteral.INSTANCE));
LogicalPlan agg = new LogicalPlanBuilder(scan1)
.agg(groupBy, output)
.build();
ConnectContext connectContext = MemoTestUtils.createConnectContext();
connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false);
PlanChecker.from(connectContext, agg)
.applyTopDown(new SimplifyAggGroupBy())
.matchesFromRoot(
logicalAggregate().when(a -> a.equals(agg))
);
}

@Test
void testMultiplyDoubleLiteralNotSimplified() {
Slot id = scan1.getOutput().get(0);
List<NamedExpression> output = ImmutableList.of(id, new Count().alias("cnt"));
List<Expression> groupBy = ImmutableList.of(id, new Multiply(id, new DoubleLiteral(0.1)));
LogicalPlan agg = new LogicalPlanBuilder(scan1)
.agg(groupBy, output)
.build();
ConnectContext connectContext = MemoTestUtils.createConnectContext();
connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false);
PlanChecker.from(connectContext, agg)
.applyTopDown(new SimplifyAggGroupBy())
.matchesFromRoot(
logicalAggregate().when(a -> a.equals(agg))
);
}
}
Loading