Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@

import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
Expand All @@ -35,6 +38,8 @@
import com.google.common.collect.ImmutableList;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -88,12 +93,15 @@ private Plan toAntiJoin(LogicalFilter<LogicalJoin<Plan, Plan>> filter, Map<ExprI
.filter(p -> TypeUtils.isNull(p).isPresent())
.flatMap(p -> p.getInputSlots().stream())
.collect(Collectors.toSet());
Set<Slot> leftAlwaysNullSlots = join.left().getOutputSet().stream()
.filter(s -> alwaysNullSlots.contains(s) && !s.nullable())
.collect(Collectors.toSet());
Set<Slot> rightAlwaysNullSlots = join.right().getOutputSet().stream()
.filter(s -> alwaysNullSlots.contains(s) && !s.nullable())
.collect(Collectors.toSet());
Set<Slot> nullRejectedJoinSlots = getNullRejectedJoinSlots(join);
List<Slot> leftOutput = join.left().getOutput();
List<Slot> rightOutput = join.right().getOutput();
List<Slot> joinOutput = join.getOutput();
Set<Slot> leftAlwaysNullSlots = getAlwaysNullSlots(leftOutput,
joinOutput.subList(0, leftOutput.size()), alwaysNullSlots, nullRejectedJoinSlots);
Set<Slot> rightAlwaysNullSlots = getAlwaysNullSlots(rightOutput,
joinOutput.subList(leftOutput.size(), leftOutput.size() + rightOutput.size()),
alwaysNullSlots, nullRejectedJoinSlots);

Plan newChild = null;
if (join.getJoinType().isLeftOuterJoin() && !rightAlwaysNullSlots.isEmpty()) {
Expand Down Expand Up @@ -128,4 +136,53 @@ private Plan toAntiJoin(LogicalFilter<LogicalJoin<Plan, Plan>> filter, Map<ExprI
return filter.withChildren(newChild);
}
}

private Set<Slot> getNullRejectedJoinSlots(LogicalJoin<Plan, Plan> join) {
Set<Slot> leftOutputSet = join.left().getOutputSet();
Set<Slot> rightOutputSet = join.right().getOutputSet();
Set<Slot> nullRejectedJoinSlots = new HashSet<>();
collectNullRejectedJoinSlots(join.getHashJoinConjuncts(), leftOutputSet, rightOutputSet,
nullRejectedJoinSlots);
collectNullRejectedJoinSlots(join.getOtherJoinConjuncts(), leftOutputSet, rightOutputSet,
nullRejectedJoinSlots);
return nullRejectedJoinSlots;
}

private void collectNullRejectedJoinSlots(List<Expression> conjuncts, Set<Slot> leftOutputSet,
Set<Slot> rightOutputSet, Set<Slot> nullRejectedJoinSlots) {
for (Expression conjunct : conjuncts) {
if (!(conjunct instanceof ComparisonPredicate) || conjunct instanceof NullSafeEqual) {
continue;
}
ComparisonPredicate predicate = (ComparisonPredicate) conjunct;
if (isSlotComparisonBetweenChildren(predicate, leftOutputSet, rightOutputSet)) {
nullRejectedJoinSlots.addAll(predicate.getInputSlots());
}
}
}

private Set<Slot> getAlwaysNullSlots(List<Slot> childOutput, List<Slot> joinOutput, Set<Slot> alwaysNullSlots,
Set<Slot> nullRejectedJoinSlots) {
Set<Slot> result = new HashSet<>();
for (int i = 0; i < childOutput.size(); i++) {
Slot childSlot = childOutput.get(i);
Slot outputSlot = joinOutput.get(i);
if ((alwaysNullSlots.contains(childSlot) || alwaysNullSlots.contains(outputSlot))
&& (!childSlot.nullable() || nullRejectedJoinSlots.contains(childSlot))) {
result.add(childSlot);
}
}
return result;
}

private boolean isSlotComparisonBetweenChildren(ComparisonPredicate predicate, Set<Slot> leftOutputSet,
Set<Slot> rightOutputSet) {
Expression left = predicate.left();
Expression right = predicate.right();
if (!(left instanceof Slot) || !(right instanceof Slot)) {
return false;
}
return (leftOutputSet.contains(left) && rightOutputSet.contains(right))
|| (leftOutputSet.contains(right) && rightOutputSet.contains(left));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,19 @@

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.HashDistributionInfo;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.PartitionInfo;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
Expand All @@ -32,12 +42,15 @@
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.thrift.TStorageType;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.util.List;

class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
private LogicalOlapScan scan1;
private LogicalOlapScan scan2;
Expand Down Expand Up @@ -66,6 +79,58 @@ void testEliminateLeftWithProject() {
.matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin()));
}

@Test
void testEliminateLeftWithNullableRightSlotInEqualJoinCondition() {
LogicalOlapScan left = newNullableLogicalOlapScan(10, "nullable_left");
LogicalOlapScan right = newNullableLogicalOlapScan(11, "nullable_right");
LogicalPlan plan = new LogicalPlanBuilder(left)
.join(right, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))
.filter(new IsNull(right.getOutput().get(0)))
.projectExprs(ImmutableList.copyOf(left.getOutput()))
.build();

PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyCustom(new ConvertOuterJoinToAntiJoin())
.printlnTree()
.matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin()));
}

@Test
void testEliminateLeftWithNullableRightSlotInOtherJoinCondition() {
LogicalOlapScan left = newNullableLogicalOlapScan(12, "nullable_left_other");
LogicalOlapScan right = newNullableLogicalOlapScan(13, "nullable_right_other");
LogicalPlan plan = new LogicalPlanBuilder(left)
.join(right, JoinType.LEFT_OUTER_JOIN, ImmutableList.of(),
ImmutableList.<Expression>of(new LessThan(left.getOutput().get(0), right.getOutput().get(0))))
.filter(new IsNull(right.getOutput().get(0)))
.projectExprs(ImmutableList.copyOf(left.getOutput()))
.build();

PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyCustom(new ConvertOuterJoinToAntiJoin())
.printlnTree()
.matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin()));
}

@Test
void testNoEliminateLeftWithNullableRightSlotInNullSafeEqualJoinCondition() {
LogicalOlapScan left = newNullableLogicalOlapScan(14, "nullable_left_null_safe");
LogicalOlapScan right = newNullableLogicalOlapScan(15, "nullable_right_null_safe");
LogicalPlan plan = new LogicalPlanBuilder(left)
.join(right, JoinType.LEFT_OUTER_JOIN,
ImmutableList.<Expression>of(new NullSafeEqual(left.getOutput().get(0),
right.getOutput().get(0))),
ImmutableList.of())
.filter(new IsNull(right.getOutput().get(0)))
.projectExprs(ImmutableList.copyOf(left.getOutput()))
.build();

PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyCustom(new ConvertOuterJoinToAntiJoin())
.printlnTree()
.matches(logicalJoin().when(join -> join.getJoinType().isLeftOuterJoin()));
}

@Test
void testEliminateRightWithProject() {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
Expand All @@ -81,6 +146,22 @@ void testEliminateRightWithProject() {
.matches(logicalJoin().when(join -> join.getJoinType().isRightAntiJoin()));
}

@Test
void testEliminateRightWithNullableLeftSlotInEqualJoinCondition() {
LogicalOlapScan left = newNullableLogicalOlapScan(16, "nullable_left_right_join");
LogicalOlapScan right = newNullableLogicalOlapScan(17, "nullable_right_right_join");
LogicalPlan plan = new LogicalPlanBuilder(left)
.join(right, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0))
.filter(new IsNull(left.getOutput().get(0)))
.projectExprs(ImmutableList.copyOf(right.getOutput()))
.build();

PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyCustom(new ConvertOuterJoinToAntiJoin())
.printlnTree()
.matches(logicalJoin().when(join -> join.getJoinType().isRightAntiJoin()));
}

@Test
void testNoEliminateAsofWithProject() {
testNoEliminateAsofWithProjectHelper(JoinType.ASOF_LEFT_OUTER_JOIN);
Expand Down Expand Up @@ -231,4 +312,20 @@ private void testEliminateLeftWithAndPredicateHelper(JoinType joinType) {
.printlnTree()
.matches(logicalJoin().when(join -> join.getJoinType() == joinType));
}

private LogicalOlapScan newNullableLogicalOlapScan(long tableId, String tableName) {
List<Column> columns = ImmutableList.of(
new Column("id", Type.INT, true, AggregateType.NONE, true, "0", ""),
new Column("name", Type.STRING, false, AggregateType.NONE, true, "", ""));
HashDistributionInfo hashDistributionInfo = new HashDistributionInfo(3, ImmutableList.of(columns.get(0)));
OlapTable table = new OlapTable(tableId, tableName, columns,
KeysType.DUP_KEYS, new PartitionInfo(), hashDistributionInfo);
table.setIndexMeta(-1,
tableName,
table.getFullSchema(),
0, 0, (short) 0,
TStorageType.COLUMN,
KeysType.DUP_KEYS);
return new LogicalOlapScan(PlanConstructor.getNextRelationId(), table, ImmutableList.of("db"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,21 @@ suite("transform_outer_join_to_anti") {
contains "OUTER JOIN"
}

explain {
sql("select eliminate_outer_join_A.* from eliminate_outer_join_A left outer join eliminate_outer_join_B on eliminate_outer_join_B.null_b = eliminate_outer_join_A.a where eliminate_outer_join_B.null_b is null")
contains "ANTI JOIN"
}

explain {
sql("select eliminate_outer_join_B.* from eliminate_outer_join_A right outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.null_a where eliminate_outer_join_A.null_a is null")
contains "ANTI JOIN"
}

explain {
sql("select eliminate_outer_join_A.* from eliminate_outer_join_A left outer join eliminate_outer_join_B on eliminate_outer_join_B.null_b <=> eliminate_outer_join_A.null_a where eliminate_outer_join_B.null_b is null")
contains "OUTER JOIN"
}

explain {
sql("select eliminate_outer_join_A.* from eliminate_outer_join_A left outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_B.b is null or eliminate_outer_join_A.null_a is null")
contains "OUTER JOIN"
Expand Down
Loading