diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java index 8791bb693f9b06..d737d37f29c46c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java @@ -19,8 +19,11 @@ import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.NullSafeEqual; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.plans.JoinType; @@ -35,6 +38,8 @@ import com.google.common.collect.ImmutableList; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -88,12 +93,15 @@ private Plan toAntiJoin(LogicalFilter> filter, Map TypeUtils.isNull(p).isPresent()) .flatMap(p -> p.getInputSlots().stream()) .collect(Collectors.toSet()); - Set leftAlwaysNullSlots = join.left().getOutputSet().stream() - .filter(s -> alwaysNullSlots.contains(s) && !s.nullable()) - .collect(Collectors.toSet()); - Set rightAlwaysNullSlots = join.right().getOutputSet().stream() - .filter(s -> alwaysNullSlots.contains(s) && !s.nullable()) - .collect(Collectors.toSet()); + Set nullRejectedJoinSlots = getNullRejectedJoinSlots(join); + List leftOutput = join.left().getOutput(); + List rightOutput = join.right().getOutput(); + List joinOutput = join.getOutput(); + Set leftAlwaysNullSlots = getAlwaysNullSlots(leftOutput, + joinOutput.subList(0, leftOutput.size()), alwaysNullSlots, nullRejectedJoinSlots); + Set rightAlwaysNullSlots = getAlwaysNullSlots(rightOutput, + joinOutput.subList(leftOutput.size(), leftOutput.size() + rightOutput.size()), + alwaysNullSlots, nullRejectedJoinSlots); Plan newChild = null; if (join.getJoinType().isLeftOuterJoin() && !rightAlwaysNullSlots.isEmpty()) { @@ -128,4 +136,53 @@ private Plan toAntiJoin(LogicalFilter> filter, Map getNullRejectedJoinSlots(LogicalJoin join) { + Set leftOutputSet = join.left().getOutputSet(); + Set rightOutputSet = join.right().getOutputSet(); + Set nullRejectedJoinSlots = new HashSet<>(); + collectNullRejectedJoinSlots(join.getHashJoinConjuncts(), leftOutputSet, rightOutputSet, + nullRejectedJoinSlots); + collectNullRejectedJoinSlots(join.getOtherJoinConjuncts(), leftOutputSet, rightOutputSet, + nullRejectedJoinSlots); + return nullRejectedJoinSlots; + } + + private void collectNullRejectedJoinSlots(List conjuncts, Set leftOutputSet, + Set rightOutputSet, Set nullRejectedJoinSlots) { + for (Expression conjunct : conjuncts) { + if (!(conjunct instanceof ComparisonPredicate) || conjunct instanceof NullSafeEqual) { + continue; + } + ComparisonPredicate predicate = (ComparisonPredicate) conjunct; + if (isSlotComparisonBetweenChildren(predicate, leftOutputSet, rightOutputSet)) { + nullRejectedJoinSlots.addAll(predicate.getInputSlots()); + } + } + } + + private Set getAlwaysNullSlots(List childOutput, List joinOutput, Set alwaysNullSlots, + Set nullRejectedJoinSlots) { + Set result = new HashSet<>(); + for (int i = 0; i < childOutput.size(); i++) { + Slot childSlot = childOutput.get(i); + Slot outputSlot = joinOutput.get(i); + if ((alwaysNullSlots.contains(childSlot) || alwaysNullSlots.contains(outputSlot)) + && (!childSlot.nullable() || nullRejectedJoinSlots.contains(childSlot))) { + result.add(childSlot); + } + } + return result; + } + + private boolean isSlotComparisonBetweenChildren(ComparisonPredicate predicate, Set leftOutputSet, + Set rightOutputSet) { + Expression left = predicate.left(); + Expression right = predicate.right(); + if (!(left instanceof Slot) || !(right instanceof Slot)) { + return false; + } + return (leftOutputSet.contains(left) && rightOutputSet.contains(right)) + || (leftOutputSet.contains(right) && rightOutputSet.contains(left)); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java index 229459d7824046..a039c5dae90de5 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java @@ -17,9 +17,19 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.catalog.AggregateType; +import org.apache.doris.catalog.Column; +import org.apache.doris.catalog.HashDistributionInfo; +import org.apache.doris.catalog.KeysType; +import org.apache.doris.catalog.OlapTable; +import org.apache.doris.catalog.PartitionInfo; +import org.apache.doris.catalog.Type; import org.apache.doris.common.Pair; import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.IsNull; +import org.apache.doris.nereids.trees.expressions.LessThan; +import org.apache.doris.nereids.trees.expressions.NullSafeEqual; import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; @@ -32,12 +42,15 @@ import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; import org.apache.doris.qe.ConnectContext; +import org.apache.doris.thrift.TStorageType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Sets; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import java.util.List; + class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported { private LogicalOlapScan scan1; private LogicalOlapScan scan2; @@ -66,6 +79,58 @@ void testEliminateLeftWithProject() { .matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin())); } + @Test + void testEliminateLeftWithNullableRightSlotInEqualJoinCondition() { + LogicalOlapScan left = newNullableLogicalOlapScan(10, "nullable_left"); + LogicalOlapScan right = newNullableLogicalOlapScan(11, "nullable_right"); + LogicalPlan plan = new LogicalPlanBuilder(left) + .join(right, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) + .filter(new IsNull(right.getOutput().get(0))) + .projectExprs(ImmutableList.copyOf(left.getOutput())) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyCustom(new ConvertOuterJoinToAntiJoin()) + .printlnTree() + .matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin())); + } + + @Test + void testEliminateLeftWithNullableRightSlotInOtherJoinCondition() { + LogicalOlapScan left = newNullableLogicalOlapScan(12, "nullable_left_other"); + LogicalOlapScan right = newNullableLogicalOlapScan(13, "nullable_right_other"); + LogicalPlan plan = new LogicalPlanBuilder(left) + .join(right, JoinType.LEFT_OUTER_JOIN, ImmutableList.of(), + ImmutableList.of(new LessThan(left.getOutput().get(0), right.getOutput().get(0)))) + .filter(new IsNull(right.getOutput().get(0))) + .projectExprs(ImmutableList.copyOf(left.getOutput())) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyCustom(new ConvertOuterJoinToAntiJoin()) + .printlnTree() + .matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin())); + } + + @Test + void testNoEliminateLeftWithNullableRightSlotInNullSafeEqualJoinCondition() { + LogicalOlapScan left = newNullableLogicalOlapScan(14, "nullable_left_null_safe"); + LogicalOlapScan right = newNullableLogicalOlapScan(15, "nullable_right_null_safe"); + LogicalPlan plan = new LogicalPlanBuilder(left) + .join(right, JoinType.LEFT_OUTER_JOIN, + ImmutableList.of(new NullSafeEqual(left.getOutput().get(0), + right.getOutput().get(0))), + ImmutableList.of()) + .filter(new IsNull(right.getOutput().get(0))) + .projectExprs(ImmutableList.copyOf(left.getOutput())) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyCustom(new ConvertOuterJoinToAntiJoin()) + .printlnTree() + .matches(logicalJoin().when(join -> join.getJoinType().isLeftOuterJoin())); + } + @Test void testEliminateRightWithProject() { LogicalPlan plan = new LogicalPlanBuilder(scan1) @@ -81,6 +146,22 @@ void testEliminateRightWithProject() { .matches(logicalJoin().when(join -> join.getJoinType().isRightAntiJoin())); } + @Test + void testEliminateRightWithNullableLeftSlotInEqualJoinCondition() { + LogicalOlapScan left = newNullableLogicalOlapScan(16, "nullable_left_right_join"); + LogicalOlapScan right = newNullableLogicalOlapScan(17, "nullable_right_right_join"); + LogicalPlan plan = new LogicalPlanBuilder(left) + .join(right, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0)) + .filter(new IsNull(left.getOutput().get(0))) + .projectExprs(ImmutableList.copyOf(right.getOutput())) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyCustom(new ConvertOuterJoinToAntiJoin()) + .printlnTree() + .matches(logicalJoin().when(join -> join.getJoinType().isRightAntiJoin())); + } + @Test void testNoEliminateAsofWithProject() { testNoEliminateAsofWithProjectHelper(JoinType.ASOF_LEFT_OUTER_JOIN); @@ -231,4 +312,20 @@ private void testEliminateLeftWithAndPredicateHelper(JoinType joinType) { .printlnTree() .matches(logicalJoin().when(join -> join.getJoinType() == joinType)); } + + private LogicalOlapScan newNullableLogicalOlapScan(long tableId, String tableName) { + List columns = ImmutableList.of( + new Column("id", Type.INT, true, AggregateType.NONE, true, "0", ""), + new Column("name", Type.STRING, false, AggregateType.NONE, true, "", "")); + HashDistributionInfo hashDistributionInfo = new HashDistributionInfo(3, ImmutableList.of(columns.get(0))); + OlapTable table = new OlapTable(tableId, tableName, columns, + KeysType.DUP_KEYS, new PartitionInfo(), hashDistributionInfo); + table.setIndexMeta(-1, + tableName, + table.getFullSchema(), + 0, 0, (short) 0, + TStorageType.COLUMN, + KeysType.DUP_KEYS); + return new LogicalOlapScan(PlanConstructor.getNextRelationId(), table, ImmutableList.of("db")); + } } diff --git a/regression-test/suites/nereids_syntax_p0/transform_outer_join_to_anti.groovy b/regression-test/suites/nereids_syntax_p0/transform_outer_join_to_anti.groovy index f806f4ce5c7a5e..9c73fefc24c8c2 100644 --- a/regression-test/suites/nereids_syntax_p0/transform_outer_join_to_anti.groovy +++ b/regression-test/suites/nereids_syntax_p0/transform_outer_join_to_anti.groovy @@ -65,6 +65,21 @@ suite("transform_outer_join_to_anti") { contains "OUTER JOIN" } + explain { + sql("select eliminate_outer_join_A.* from eliminate_outer_join_A left outer join eliminate_outer_join_B on eliminate_outer_join_B.null_b = eliminate_outer_join_A.a where eliminate_outer_join_B.null_b is null") + contains "ANTI JOIN" + } + + explain { + sql("select eliminate_outer_join_B.* from eliminate_outer_join_A right outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.null_a where eliminate_outer_join_A.null_a is null") + contains "ANTI JOIN" + } + + explain { + sql("select eliminate_outer_join_A.* from eliminate_outer_join_A left outer join eliminate_outer_join_B on eliminate_outer_join_B.null_b <=> eliminate_outer_join_A.null_a where eliminate_outer_join_B.null_b is null") + contains "OUTER JOIN" + } + explain { sql("select eliminate_outer_join_A.* from eliminate_outer_join_A left outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_B.b is null or eliminate_outer_join_A.null_a is null") contains "OUTER JOIN"