From 7e187804780852af79b7b1cd07c4f75c80d207a0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 1 Jul 2026 08:14:45 -0700 Subject: [PATCH 1/8] [SPARK-57851][SQL] Shuffle-free single-task execution for small queries ### What changes were proposed in this pull request? This adds a conservative optimizer rule `MarkSingleTaskExecution` that marks small single-partition scans, optionally with a shuffle-inducing operator on top (sort, aggregate, distinct, window, limit/offset, expand) or an in-memory `LocalRelation`, as candidates for single-task execution. Such a scan reports a `SinglePartition` output partitioning, allowing `EnsureRequirements` to elide the shuffle that would otherwise be inserted before the operator on top. The rule runs as the last optimizer batch and marks eligible `LogicalRelation`/`LocalRelation` nodes with a `TreeNodeTag`. The planning strategies propagate the mark to `FileSourceScanExec`/`LocalTableScanExec`. `FileSourceScanExec` additionally gates on file count and size thresholds using the generic `ScanFileListing`, reports `SinglePartition`, and coalesces its input RDD to a single partition as a correctness backstop. `ExpandExec` forwards `SinglePartition` from its child, since Expand never moves rows across partitions. The feature is controlled by new internal configs under `spark.sql.optimizer.singleTaskExecution.*` and is disabled by default. Join is intentionally left out for now; union is already covered by the existing `spark.sql.unionOutputPartitioning`. This is part of the SPIP umbrella SPARK-56978 (Faster queries in local laptop mode), covering the shuffle-free local execution category. ### Why are the changes needed? For small, low-latency queries the fixed cost of a shuffle (scheduling, serialization, network) dominates. When the input is already a single small partition, the shuffle before a sort/aggregate/window is unnecessary and can be removed to reduce latency. ### Does this PR introduce _any_ user-facing change? No. The optimization is behind internal configs and is disabled by default. ### How was this patch tested? New `MarkSingleTaskExecutionSuite` (14 tests) covering the marking decision, `SinglePartition` output with no shuffle, empty-scan correctness, disabled-flag negatives, join/subquery ineligibility, and the leaf-parallelism override. `SQLConfSuite` passes as a config-wiring regression check. Co-authored-by: Isaac --- .../apache/spark/sql/internal/SQLConf.scala | 95 +++++++++ .../execution/SparkConnectPlanExecution.scala | 2 +- .../sql/execution/DataSourceScanExec.scala | 58 +++++- .../spark/sql/execution/ExpandExec.scala | 15 +- .../sql/execution/LocalTableScanExec.scala | 21 +- .../spark/sql/execution/SparkOptimizer.scala | 7 +- .../spark/sql/execution/SparkStrategies.scala | 6 +- .../datasources/FileSourceStrategy.scala | 4 +- .../datasources/MarkSingleTaskExecution.scala | 161 +++++++++++++++ .../apache/spark/sql/DataFrameJoinSuite.scala | 2 +- .../org/apache/spark/sql/SubquerySuite.scala | 2 +- .../MarkSingleTaskExecutionSuite.scala | 187 ++++++++++++++++++ 12 files changed, 541 insertions(+), 19 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecution.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6776f88ed1ef8..e15ecaddbbb08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -7313,6 +7313,101 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SINGLE_TASK_EXECUTION_ENABLED = + buildConf("spark.sql.optimizer.singleTaskExecution.enabled") + .doc("When true, eligible query fragments that read a small single-partition scan can run " + + "in a single task, skipping the shuffle that would otherwise be inserted before an " + + "operator such as a sort or aggregation. This avoids the scheduling overhead of an " + + "unnecessary shuffle for small, low-latency queries.") + .version("4.3.0") + .booleanConf + .createWithDefault(false) + + val SINGLE_TASK_EXECUTION_AGGREGATION = + buildConf("spark.sql.optimizer.singleTaskExecution.aggregation") + .internal() + .doc("When true, and 'spark.sql.optimizer.singleTaskExecution.enabled' is also true, " + + "enable the single-task optimization for query plans with aggregation operators.") + .version("4.3.0") + .fallbackConf(SINGLE_TASK_EXECUTION_ENABLED) + + val SINGLE_TASK_EXECUTION_EXPAND = + buildConf("spark.sql.optimizer.singleTaskExecution.expand") + .internal() + .doc("When true, and 'spark.sql.optimizer.singleTaskExecution.enabled' is also true, " + + "enable the single-task optimization for query plans with expand operators.") + .version("4.3.0") + .fallbackConf(SINGLE_TASK_EXECUTION_ENABLED) + + val SINGLE_TASK_EXECUTION_LIMIT_OFFSET = + buildConf("spark.sql.optimizer.singleTaskExecution.limitOffset") + .internal() + .doc("When true, and 'spark.sql.optimizer.singleTaskExecution.enabled' is also true, " + + "enable the single-task optimization for query plans with limit or offset operators.") + .version("4.3.0") + .fallbackConf(SINGLE_TASK_EXECUTION_ENABLED) + + val SINGLE_TASK_EXECUTION_SORT = + buildConf("spark.sql.optimizer.singleTaskExecution.sort") + .internal() + .doc("When true, and 'spark.sql.optimizer.singleTaskExecution.enabled' is also true, " + + "enable the single-task optimization for query plans with sort operators.") + .version("4.3.0") + .fallbackConf(SINGLE_TASK_EXECUTION_ENABLED) + + val SINGLE_TASK_EXECUTION_WINDOW = + buildConf("spark.sql.optimizer.singleTaskExecution.window") + .internal() + .doc("When true, and 'spark.sql.optimizer.singleTaskExecution.enabled' is also true, " + + "enable the single-task optimization for query plans with window operators.") + .version("4.3.0") + .fallbackConf(SINGLE_TASK_EXECUTION_ENABLED) + + val SINGLE_TASK_EXECUTION_MAX_NUM_FILES = + buildConf("spark.sql.optimizer.singleTaskExecution.maxNumFiles") + .internal() + .doc("The maximum number of files that a file scan may have for the single-task " + + "optimization to apply to it.") + .version("4.3.0") + .intConf + .createWithDefault(1) + + val SINGLE_TASK_EXECUTION_MIN_NUM_FILES = + buildConf("spark.sql.optimizer.singleTaskExecution.minNumFiles") + .internal() + .doc("The minimum number of files that a file scan may have for the single-task " + + "optimization to apply to it.") + .version("4.3.0") + .intConf + .createWithDefault(1) + + val SINGLE_TASK_EXECUTION_MIN_NUM_BYTES = + buildConf("spark.sql.optimizer.singleTaskExecution.minNumBytes") + .internal() + .doc("The minimum total size in bytes that a file scan may have for the single-task " + + "optimization to apply to it.") + .version("4.3.0") + .longConf + .createWithDefault(1) + + val SINGLE_TASK_EXECUTION_LOCAL_TABLE_SCAN_MIN_ROWS = + buildConf("spark.sql.optimizer.singleTaskExecution.localTableScan.minRows") + .internal() + .doc("The minimum number of rows that a local in-memory relation may have for the " + + "single-task optimization to apply to it.") + .version("4.3.0") + .intConf + .createWithDefault(1) + + val SINGLE_TASK_EXECUTION_LOCAL_TABLE_SCAN_THRESHOLD = + buildConf("spark.sql.optimizer.singleTaskExecution.localTableScan.threshold") + .internal() + .doc("The maximum number of rows that a local in-memory relation may have for the " + + "single-task optimization to apply to it.") + .version("4.3.0") + .intConf + .createWithDefault(1000) + val LEGACY_PARSE_QUERY_WITHOUT_EOF = buildConf("spark.sql.legacy.parseQueryWithoutEof") .internal() .doc( diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index 5fdfd5d1ccd16..0c4ca9357e848 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -213,7 +213,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) } } dataframe.queryExecution.executedPlan match { - case LocalTableScanExec(_, rows, _) => + case LocalTableScanExec(_, rows, _, _) => executePlan.eventsManager.postFinished(Some(rows.length)) var offset = 0L converter(rows.iterator).foreach { case (bytes, count) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index be7013188f2f9..cb5846c3ae4cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.{FileSourceOptions, InternalRow, TableIdent import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition, UnknownPartitioning} import org.apache.spark.sql.catalyst.util.{truncatedString, CaseInsensitiveMap} import org.apache.spark.sql.connector.read.streaming.SparkDataStream import org.apache.spark.sql.errors.QueryExecutionErrors @@ -320,6 +320,33 @@ trait FileSourceScanLike extends DataSourceScanExec with SessionStateHelper { def requiredSchema: StructType // Identifier for the table in the metastore. def tableIdentifier: Option[TableIdentifier] + // When true, the `MarkSingleTaskExecution` optimizer rule has marked this scan's plan shape as a + // candidate for single-task execution. The scan is only actually executed in a single task when + // it additionally passes the file count and size thresholds (see `useSingleTaskExecution`). + def markedForSingleTaskExecution: Boolean + + /** + * Whether this file scan should run in a single task, reporting a `SinglePartition` output + * partitioning so that a following shuffle can be elided. This is true when the plan shape was + * marked eligible by the optimizer and the statically-selected files fall within the configured + * count and size bounds. It relies on `selectedPartitions`, so it must not be evaluated before + * the scan's file listing is available. + */ + lazy val useSingleTaskExecution: Boolean = { + if (!markedForSingleTaskExecution) { + false + } else { + val sqlConf = getSqlConf(relation.sparkSession) + val minNumFiles = sqlConf.getConf(SQLConf.SINGLE_TASK_EXECUTION_MIN_NUM_FILES) + val maxNumFiles = sqlConf.getConf(SQLConf.SINGLE_TASK_EXECUTION_MAX_NUM_FILES) + val minNumBytes = sqlConf.getConf(SQLConf.SINGLE_TASK_EXECUTION_MIN_NUM_BYTES) + val maxPartitionBytes = sqlConf.getConf(SQLConf.FILES_MAX_PARTITION_BYTES) + val numFiles = selectedPartitions.totalNumberOfFiles + val numBytes = selectedPartitions.totalFileSize + numFiles >= minNumFiles && numFiles <= maxNumFiles && + numBytes >= minNumBytes && numBytes <= maxPartitionBytes + } + } lazy val fileConstantMetadataColumns: Seq[AttributeReference] = output.collect { @@ -478,6 +505,8 @@ trait FileSourceScanLike extends DataSourceScanExec with SessionStateHelper { Nil } (partitioning, sortOrder) + } else if (useSingleTaskExecution) { + (SinglePartition, Nil) } else { (UnknownPartitioning(0), Nil) } @@ -696,7 +725,8 @@ case class FileSourceScanExec( override val optionalNumCoalescedBuckets: Option[Int], override val dataFilters: Seq[Expression], override val tableIdentifier: Option[TableIdentifier], - override val disableBucketedScan: Boolean = false) + override val disableBucketedScan: Boolean = false, + override val markedForSingleTaskExecution: Boolean = false) extends FileSourceScanLike { // Note that some vals referring the file-based relation are lazy intentionally @@ -744,10 +774,28 @@ case class FileSourceScanExec( inputRDD :: Nil } + /** + * The input RDD, coalesced to a single partition when this scan runs in single-task mode. This + * enforces the `SinglePartition` output partitioning reported by `outputPartitioning`, which is + * estimated from the statically-selected files and may not correspond exactly to the number of + * partitions the input RDD produces after dynamic pruning. Coalescing here keeps the query + * correct in either case. + */ + private[spark] lazy val maybeCoalesceInputRDD: RDD[InternalRow] = { + if (useSingleTaskExecution && inputRDD.getNumPartitions > 1) { + inputRDD.coalesce(1) + } else if (useSingleTaskExecution && inputRDD.getNumPartitions == 0) { + // All files were pruned away; produce a single empty partition to match `SinglePartition`. + sparkContext.parallelize[InternalRow](Nil, 1) + } else { + inputRDD + } + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") if (needsUnsafeRowConversion) { - inputRDD.mapPartitionsWithIndexInternal { (index, iter) => + maybeCoalesceInputRDD.mapPartitionsWithIndexInternal { (index, iter) => val toUnsafe = UnsafeProjection.create(schema) toUnsafe.initialize(index) iter.map { row => @@ -756,7 +804,7 @@ case class FileSourceScanExec( } } } else { - inputRDD.mapPartitionsInternal { iter => + maybeCoalesceInputRDD.mapPartitionsInternal { iter => iter.map { row => numOutputRows += 1 row @@ -768,7 +816,7 @@ case class FileSourceScanExec( protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("numOutputRows") val scanTime = longMetric("scanTime") - inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => + maybeCoalesceInputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => new Iterator[ColumnarBatch] { override def hasNext: Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 254772f73208d..3fffb68613976 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf @@ -43,8 +43,17 @@ case class ExpandExec( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) // The GroupExpressions can output data with arbitrary partitioning, so set it - // as UNKNOWN partitioning - override def outputPartitioning: Partitioning = UnknownPartitioning(0) + // as UNKNOWN partitioning. Expand only replicates rows within a partition and never moves rows + // across partitions, so when the single-task optimization is enabled and the child produces a + // single partition, we can forward the `SinglePartition` property to avoid an unneeded shuffle. + override def outputPartitioning: Partitioning = { + if (conf.getConf(SQLConf.SINGLE_TASK_EXECUTION_EXPAND) && + child.outputPartitioning == SinglePartition) { + SinglePartition + } else { + UnknownPartitioning(0) + } + } @transient override lazy val references: AttributeSet = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 2d5dbf8199599..930c27c8dff21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning} import org.apache.spark.sql.connector.read.streaming.SparkDataStream import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ArrayImplicits._ @@ -34,7 +35,10 @@ import org.apache.spark.util.ArrayImplicits._ case class LocalTableScanExec( output: Seq[Attribute], @transient rows: Seq[InternalRow], - @transient stream: Option[SparkDataStream]) + @transient stream: Option[SparkDataStream], + // When true, the relation is scanned in a single partition, so this node reports a + // `SinglePartition` output partitioning. Set by the `MarkSingleTaskExecution` optimizer rule. + useSingleTask: Boolean = false) extends LeafExecNode with StreamSourceAwareSparkPlan with InputRDDCodegen { @@ -55,12 +59,23 @@ case class LocalTableScanExec( if (rows.isEmpty) { sparkContext.emptyRDD } else { - val numSlices = math.min( - unsafeRows.length, session.leafNodeDefaultParallelism) + val numSlices = if (useSingleTask) { + 1 + } else { + math.min(unsafeRows.length, session.leafNodeDefaultParallelism) + } sparkContext.parallelize(unsafeRows.toImmutableArraySeq, numSlices) } } + override def outputPartitioning: Partitioning = { + if (useSingleTask) { + SinglePartition + } else { + UnknownPartitioning(0) + } + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") rdd.map { r => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 1b3b2d3efc72a..54158d5bb4a94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager -import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, PushVariantIntoScan, SchemaPruning, V1Writes} +import org.apache.spark.sql.execution.datasources.{MarkSingleTaskExecution, PruneFileSourcePartitions, PushVariantIntoScan, SchemaPruning, V1Writes} import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes} import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning, RowLevelOperationRuntimeGroupFiltering} import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs, ExtractPythonUDTFs} @@ -100,7 +100,10 @@ class SparkOptimizer( ConstantFolding, EliminateLimits), Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*), - Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition))) + Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition), + // Must run last: it inspects the final plan shape to mark scans that can run in a single task, + // and no subsequent rule should reshape the plan or copy the marked scan nodes. + Batch("MarkSingleTaskExecution", Once, MarkSingleTaskExecution))) override def nonExcludableRules: Seq[String] = super.nonExcludableRules ++ Seq( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6e761fbe07b27..9f8c818c38eb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -1161,8 +1161,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "TABLESAMPLE SYSTEM node was not properly handled by V2ScanRelationPushDown.") } execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil - case logical.LocalRelation(output, data, _, stream) => - LocalTableScanExec(output, data, stream) :: Nil + case r @ logical.LocalRelation(output, data, _, stream) => + val useSingleTask = r.getTagValue( + datasources.MarkSingleTaskExecution.markTag).getOrElse(false) + LocalTableScanExec(output, data, stream, useSingleTask) :: Nil case logical.EmptyRelation(l) => EmptyRelationExec(l) :: Nil case CommandResult(output, _, plan, data) => CommandResultExec(output, plan, data) :: Nil // We should match the combination of limit and offset first, to get the optimal physical diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 396375890c249..0a0ec315daa3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -333,7 +333,9 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging { bucketSet, None, rebindFileSourceMetadataAttributesInFilters(dataFilters), - table.map(_.identifier)) + table.map(_.identifier), + markedForSingleTaskExecution = + l.getTagValue(MarkSingleTaskExecution.markTag).getOrElse(false)) // extra Project node: wrap flat metadata columns to a metadata struct val withMetadataProjections = metadataStructOpt.map { metadataStruct => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecution.scala new file mode 100644 index 0000000000000..4824ca8fe5038 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecution.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.internal.SQLConf + +/** + * This optimizer rule marks eligible query plans for single-task execution. The optimization + * targets a conservative, specific query shape to ensure predictable and efficient behavior. + * + * The rule matches simple query plans with a single small file scan or a single small in-memory + * relation, optionally with a shuffle-inducing operator (sort, aggregation, window, expand, or + * limit/offset) on top. When it detects such a shape, it marks the underlying scan: + * + * - a [[LogicalRelation]] or [[LocalRelation]] is marked with the + * [[MarkSingleTaskExecution.markTag]] tag. + * + * The physical scan then reports a `SinglePartition` output partitioning, which allows + * [[org.apache.spark.sql.execution.exchange.EnsureRequirements]] to elide the shuffle that would + * otherwise be inserted before the operator on top. This shuffle is not required for correctness + * of the query, so removing it reduces scheduling overhead for small, low-latency queries. + * + * The matching is deliberately strict and conservative to minimize the risk of unintended + * performance regressions. It can be broadened in the future as needed. + * + * This rule is controlled by [[SQLConf.SINGLE_TASK_EXECUTION_ENABLED]] and the per-operator + * sub-flags in [[SQLConf]]. + */ +object MarkSingleTaskExecution extends Rule[LogicalPlan] { + + /** + * Tag placed on a [[LogicalRelation]] or [[LocalRelation]] that has been marked eligible for + * single-task execution. The planning strategies read this tag to propagate the decision to the + * physical [[org.apache.spark.sql.execution.FileSourceScanExec]] / + * [[org.apache.spark.sql.execution.LocalTableScanExec]]. + */ + val markTag: TreeNodeTag[Boolean] = TreeNodeTag[Boolean]("__single_task_execution") + + private def get[T](entry: org.apache.spark.internal.config.ConfigEntry[T]): T = + SQLConf.get.getConf(entry) + + /** + * Plan patterns that make a query ineligible for the optimization. These operators either + * require shuffles that we cannot safely elide, or run user code whose behavior we should not + * change (e.g. user-defined aggregations skip the final merge step when run in a single task). + */ + val unsupportedPatterns: Seq[TreePattern] = Seq( + EVAL_PYTHON_UDF, + EVAL_PYTHON_UDTF, + EXISTS_SUBQUERY, + FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION, + LATERAL_SUBQUERY, + LIST_SUBQUERY, + PYTHON_UDF, + SCALAR_SUBQUERY) + + /** + * The per-operator sub-flags, resolved once per invocation. Each field indicates whether the + * corresponding shuffle-inducing operator is allowed on top of a single small scan. + */ + private case class EnabledOperators( + aggregation: Boolean, + expand: Boolean, + limitOffset: Boolean, + sort: Boolean, + window: Boolean) + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!get(SQLConf.SINGLE_TASK_EXECUTION_ENABLED)) { + return plan + } + val enabled = EnabledOperators( + aggregation = get(SQLConf.SINGLE_TASK_EXECUTION_AGGREGATION), + expand = get(SQLConf.SINGLE_TASK_EXECUTION_EXPAND), + limitOffset = get(SQLConf.SINGLE_TASK_EXECUTION_LIMIT_OFFSET), + sort = get(SQLConf.SINGLE_TASK_EXECUTION_SORT), + window = get(SQLConf.SINGLE_TASK_EXECUTION_WINDOW)) + + if (plan.containsAnyPattern(unsupportedPatterns: _*)) { + plan + } else if (isSupportedShape(plan, enabled)) { + markSingleTaskExecution(plan) + } else { + plan + } + } + + /** + * Returns true if every operator in the plan is one that we support keeping on top of a single + * small scan. Only operators that either do not require a shuffle, or whose shuffle-inducing + * sub-flag is enabled, are allowed. Any other operator makes the plan ineligible. + */ + private def isSupportedShape(plan: LogicalPlan, enabled: EnabledOperators): Boolean = plan match { + case _: LogicalRelation | _: LocalRelation => true + // Operators that never introduce a shuffle by themselves. + case _: Project | _: Filter | _: SubqueryAlias | + _: DeserializeToObject | _: SerializeFromObject => + plan.children.forall(isSupportedShape(_, enabled)) + // Shuffle-inducing operators, allowed only when the matching sub-flag is enabled. + case _: Aggregate if enabled.aggregation => + plan.children.forall(isSupportedShape(_, enabled)) + case _: Distinct if enabled.aggregation => + plan.children.forall(isSupportedShape(_, enabled)) + case _: Expand if enabled.expand => + plan.children.forall(isSupportedShape(_, enabled)) + case (_: GlobalLimit | _: LocalLimit | _: Offset) if enabled.limitOffset => + plan.children.forall(isSupportedShape(_, enabled)) + case _: Sort if enabled.sort => + plan.children.forall(isSupportedShape(_, enabled)) + case _: Window if enabled.window => + plan.children.forall(isSupportedShape(_, enabled)) + case _ => false + } + + /** + * Marks each scan in the (already validated) plan for single-task execution and returns the + * updated plan. + */ + private def markSingleTaskExecution(plan: LogicalPlan): LogicalPlan = plan match { + case lr: LogicalRelation => + lr.setTagValue(markTag, true) + lr + case r: LocalRelation => + if (isLocalRelationEligible(r)) { + r.setTagValue(markTag, true) + } + r + case other => + other.withNewChildren(other.children.map(markSingleTaskExecution)) + } + + /** + * A local in-memory relation is eligible when its row count falls within the configured bounds + * and there is no explicit leaf-node parallelism override in effect. + */ + private def isLocalRelationEligible(r: LocalRelation): Boolean = { + val minRows = get(SQLConf.SINGLE_TASK_EXECUTION_LOCAL_TABLE_SCAN_MIN_ROWS) + val threshold = get(SQLConf.SINGLE_TASK_EXECUTION_LOCAL_TABLE_SCAN_THRESHOLD) + get(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM).isEmpty && + r.data.length >= minRows && r.data.length <= threshold + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 9733d51a91cba..129d6bb686762 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -447,7 +447,7 @@ class DataFrameJoinSuite extends SharedSparkSession } assert(broadcastExchanges.size == 1) val tables = broadcastExchanges.head.collect { - case FileSourceScanExec(_, _, _, _, _, _, _, _, Some(tableIdent), _) => tableIdent + case FileSourceScanExec(_, _, _, _, _, _, _, _, Some(tableIdent), _, _) => tableIdent } assert(tables.size == 1) assert(tables.head === diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index cd3e389d765d3..d8d885c9b927e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -1523,7 +1523,7 @@ class SubquerySuite extends SharedSparkSession // need to execute the query before we can examine fs.inputRDDs() assert(stripAQEPlan(df.queryExecution.executedPlan) match { case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter( - fs @ FileSourceScanExec(_, _, _, _, partitionFilters, _, _, _, _, _)))) => + fs @ FileSourceScanExec(_, _, _, _, partitionFilters, _, _, _, _, _, _)))) => partitionFilters.exists(ExecSubqueryExpression.hasSubquery) && fs.inputRDDs().forall( _.asInstanceOf[FileScanRDD].filePartitions.forall( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala new file mode 100644 index 0000000000000..75029297a1aac --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.execution.{FileSourceScanExec, LocalTableScanExec, SparkPlan} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Test suite for the [[MarkSingleTaskExecution]] optimizer rule and its physical effects. The rule + * marks small single-partition scans, optionally under a shuffle-inducing operator, so that the + * scan reports a `SinglePartition` output partitioning and the following shuffle can be elided. + */ +class MarkSingleTaskExecutionSuite extends QueryTest with SharedSparkSession + with AdaptiveSparkPlanHelper { + + private val t = "single_task_t" + private val t2 = "single_task_t2" + private val emptyTable = "single_task_empty" + + private def enabledConfs: Seq[(String, String)] = Seq( + SQLConf.SINGLE_TASK_EXECUTION_ENABLED.key -> "true", + // Force the optimization to also apply to zero-file / zero-byte scans so that we can exercise + // the empty-scan corner case created by dynamic pruning. + SQLConf.SINGLE_TASK_EXECUTION_MIN_NUM_FILES.key -> "0", + SQLConf.SINGLE_TASK_EXECUTION_MIN_NUM_BYTES.key -> "0", + SQLConf.SINGLE_TASK_EXECUTION_LOCAL_TABLE_SCAN_MIN_ROWS.key -> "0") + + override def beforeAll(): Unit = { + super.beforeAll() + // A single-file Parquet table with a small amount of data. + spark.range(0, 2).selectExpr("id as col", "cast(id as string) as col_str") + .repartition(1).write.mode("overwrite").saveAsTable(t) + spark.range(0, 3).selectExpr("id % 2 as col") + .repartition(1).write.mode("overwrite").saveAsTable(t2) + spark.range(0, 0).selectExpr("id as col").write.mode("overwrite").saveAsTable(emptyTable) + } + + override def afterAll(): Unit = { + try { + sql(s"drop table if exists $t") + sql(s"drop table if exists $t2") + sql(s"drop table if exists $emptyTable") + } finally { + super.afterAll() + } + } + + private def getFinalPhysicalPlan(df: org.apache.spark.sql.DataFrame): SparkPlan = { + df.queryExecution.executedPlan match { + case a: AdaptiveSparkPlanExec => a.finalPhysicalPlan + case other => other + } + } + + private def hasShuffle(plan: SparkPlan): Boolean = + collectWithSubqueries(plan) { case s: ShuffleExchangeLike => s }.nonEmpty + + private def isMarked(plan: LogicalPlan): Boolean = { + val marks = plan.collect { + case lr: LogicalRelation => lr.getTagValue(MarkSingleTaskExecution.markTag).getOrElse(false) + case lr: LocalRelation => lr.getTagValue(MarkSingleTaskExecution.markTag).getOrElse(false) + } + marks.nonEmpty && marks.forall(identity) + } + + private def checkMarked(query: String): Unit = withSQLConf(enabledConfs: _*) { + val plan = sql(query).queryExecution.optimizedPlan + assert(isMarked(plan), s"expected plan to be marked for single-task execution:\n$plan") + } + + private def checkNotMarked(query: String, confs: Seq[(String, String)] = enabledConfs): Unit = + withSQLConf(confs: _*) { + val plan = sql(query).queryExecution.optimizedPlan + assert(!isMarked(plan), s"expected plan NOT to be marked:\n$plan") + } + + private def checkSinglePartition( + query: String, + expected: Seq[Row], + confs: Seq[(String, String)] = enabledConfs): Unit = withSQLConf(confs: _*) { + val df = sql(query) + QueryTest.checkAnswer(df, expected) + val plan = getFinalPhysicalPlan(df) + assert(!hasShuffle(plan), s"expected no shuffle in:\n$plan") + val scans = collect(plan) { + case s: FileSourceScanExec => s.outputPartitioning + case s: LocalTableScanExec => s.outputPartitioning + } + assert(scans.nonEmpty, s"expected a scan in:\n$plan") + assert(scans.forall(_ == SinglePartition), + s"expected all scans to report SinglePartition, got $scans in:\n$plan") + } + + test("marks scan + sort") { + checkMarked(s"select col from $t order by col") + checkMarked(s"select col from (select col from $t where col = 0) order by col") + } + + test("marks scan + aggregation") { + checkMarked(s"select count(1) from $t group by col") + checkMarked(s"select sum(col) from (select col from $t where col < 42)") + } + + test("marks scan + window") { + checkMarked( + s"select col, row_number() over (partition by col order by col) from $t") + } + + test("does not mark when the feature is disabled") { + checkNotMarked( + s"select col from $t order by col", + Seq(SQLConf.SINGLE_TASK_EXECUTION_ENABLED.key -> "false")) + } + + test("does not mark when the per-operator flag is disabled") { + checkNotMarked( + s"select col from $t order by col", + enabledConfs :+ (SQLConf.SINGLE_TASK_EXECUTION_SORT.key -> "false")) + } + + test("does not mark unsupported plan shapes (join)") { + // Join is not a supported operator in this port, so the presence of a join makes the whole + // plan ineligible. + checkNotMarked(s"select a.col from $t a join $t b on a.col = b.col order by a.col") + } + + test("does not mark plans with subquery expressions") { + checkNotMarked(s"select col from $t where col = (select max(col) from $t2) order by col") + } + + test("output partitioning is SinglePartition, scan + sort") { + checkSinglePartition(s"select col from $t order by col", Seq(Row(0), Row(1))) + } + + test("output partitioning is SinglePartition, scan + aggregation with group by") { + checkSinglePartition( + s"select count(1) as c from $t2 group by col", + Seq(Row(1), Row(2))) + } + + test("output partitioning is SinglePartition, scan + aggregation without group by") { + checkSinglePartition(s"select sum(col) from $t", Seq(Row(1))) + } + + test("output partitioning is SinglePartition, scan + distinct") { + checkSinglePartition(s"select distinct col from $t2", Seq(Row(0), Row(1))) + } + + test("empty table scan + aggregation is correct and single-partition") { + // Without single-task execution eliding the shuffle before the aggregation, an empty scan + // could incorrectly return zero rows instead of a single NULL row for a global aggregation. + checkSinglePartition(s"select sum(col) from $emptyTable", Seq(Row(null))) + } + + test("in-memory local relation is scanned in a single partition") { + checkSinglePartition( + "select col from values (0), (1) as tab(col) order by col", + Seq(Row(0), Row(1))) + } + + test("does not mark when a leaf-node parallelism override is set") { + checkNotMarked( + "select col from values (0), (1) as tab(col) order by col", + enabledConfs :+ (SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key -> "4")) + } +} From eec31105d70b6ad28ae4218ff1dd6960348ea8eb Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 1 Jul 2026 17:05:12 -0700 Subject: [PATCH 2/8] Set NOT_APPLICABLE binding policy on singleTaskExecution configs Co-authored-by: Claude Code --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e15ecaddbbb08..59f56f733d1f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -7320,6 +7320,7 @@ object SQLConf { "operator such as a sort or aggregation. This avoids the scheduling overhead of an " + "unnecessary shuffle for small, low-latency queries.") .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) .booleanConf .createWithDefault(false) @@ -7329,6 +7330,7 @@ object SQLConf { .doc("When true, and 'spark.sql.optimizer.singleTaskExecution.enabled' is also true, " + "enable the single-task optimization for query plans with aggregation operators.") .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) .fallbackConf(SINGLE_TASK_EXECUTION_ENABLED) val SINGLE_TASK_EXECUTION_EXPAND = @@ -7337,6 +7339,7 @@ object SQLConf { .doc("When true, and 'spark.sql.optimizer.singleTaskExecution.enabled' is also true, " + "enable the single-task optimization for query plans with expand operators.") .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) .fallbackConf(SINGLE_TASK_EXECUTION_ENABLED) val SINGLE_TASK_EXECUTION_LIMIT_OFFSET = @@ -7345,6 +7348,7 @@ object SQLConf { .doc("When true, and 'spark.sql.optimizer.singleTaskExecution.enabled' is also true, " + "enable the single-task optimization for query plans with limit or offset operators.") .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) .fallbackConf(SINGLE_TASK_EXECUTION_ENABLED) val SINGLE_TASK_EXECUTION_SORT = @@ -7353,6 +7357,7 @@ object SQLConf { .doc("When true, and 'spark.sql.optimizer.singleTaskExecution.enabled' is also true, " + "enable the single-task optimization for query plans with sort operators.") .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) .fallbackConf(SINGLE_TASK_EXECUTION_ENABLED) val SINGLE_TASK_EXECUTION_WINDOW = @@ -7361,6 +7366,7 @@ object SQLConf { .doc("When true, and 'spark.sql.optimizer.singleTaskExecution.enabled' is also true, " + "enable the single-task optimization for query plans with window operators.") .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) .fallbackConf(SINGLE_TASK_EXECUTION_ENABLED) val SINGLE_TASK_EXECUTION_MAX_NUM_FILES = @@ -7369,6 +7375,7 @@ object SQLConf { .doc("The maximum number of files that a file scan may have for the single-task " + "optimization to apply to it.") .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) .intConf .createWithDefault(1) @@ -7378,6 +7385,7 @@ object SQLConf { .doc("The minimum number of files that a file scan may have for the single-task " + "optimization to apply to it.") .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) .intConf .createWithDefault(1) @@ -7387,6 +7395,7 @@ object SQLConf { .doc("The minimum total size in bytes that a file scan may have for the single-task " + "optimization to apply to it.") .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) .longConf .createWithDefault(1) @@ -7396,6 +7405,7 @@ object SQLConf { .doc("The minimum number of rows that a local in-memory relation may have for the " + "single-task optimization to apply to it.") .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) .intConf .createWithDefault(1) @@ -7405,6 +7415,7 @@ object SQLConf { .doc("The maximum number of rows that a local in-memory relation may have for the " + "single-task optimization to apply to it.") .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) .intConf .createWithDefault(1000) From 98796e2525aaa3f1669f43f63cde238310e54713 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 2 Jul 2026 00:04:24 -0700 Subject: [PATCH 3/8] Address review comments - LocalTableScanExec: produce one empty partition for an empty marked relation to match the advertised SinglePartition; a zero-partition RDD with the shuffle elided returns no rows for a global aggregation. - FileSourceScanLike: exclude bucketed scans from single-task execution; coalescing would invalidate their HashPartitioning. - ExpandExec: decide SinglePartition forwarding at planning time from the markTag instead of reading the session conf at execution time, and only within marked plans. - MarkSingleTaskExecution: skip the whole rule when a leaf-node parallelism override is set, so file scans respect it too; use the inherited conf.getConf instead of a private helper. - FileSourceScanExec.doCanonicalize: retain markedForSingleTaskExecution. Co-authored-by: Claude Code --- .../sql/execution/DataSourceScanExec.scala | 10 ++-- .../spark/sql/execution/ExpandExec.scala | 14 ++++-- .../sql/execution/LocalTableScanExec.scala | 10 +++- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../datasources/MarkSingleTaskExecution.scala | 46 +++++++++++-------- .../MarkSingleTaskExecutionSuite.scala | 43 +++++++++++++++++ 6 files changed, 97 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index cb5846c3ae4cb..a727ccf565063 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -329,11 +329,12 @@ trait FileSourceScanLike extends DataSourceScanExec with SessionStateHelper { * Whether this file scan should run in a single task, reporting a `SinglePartition` output * partitioning so that a following shuffle can be elided. This is true when the plan shape was * marked eligible by the optimizer and the statically-selected files fall within the configured - * count and size bounds. It relies on `selectedPartitions`, so it must not be evaluated before - * the scan's file listing is available. + * count and size bounds. Bucketed scans are excluded: they report a `HashPartitioning` over the + * bucket columns, which coalescing to a single partition would invalidate. It relies on + * `selectedPartitions`, so it must not be evaluated before the scan's file listing is available. */ lazy val useSingleTaskExecution: Boolean = { - if (!markedForSingleTaskExecution) { + if (!markedForSingleTaskExecution || bucketedScan) { false } else { val sqlConf = getSqlConf(relation.sparkSession) @@ -969,7 +970,8 @@ case class FileSourceScanExec( optionalNumCoalescedBuckets, QueryPlan.normalizePredicates(dataFilters, output), None, - disableBucketedScan) + disableBucketedScan, + markedForSingleTaskExecution) } override def getStream: Option[SparkDataStream] = stream diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 3fffb68613976..6b262cbb24c00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -36,7 +36,11 @@ import org.apache.spark.sql.internal.SQLConf case class ExpandExec( projections: Seq[Seq[Expression]], output: Seq[Attribute], - child: SparkPlan) + child: SparkPlan, + // When true, this Expand is part of a plan marked for single-task execution by the + // `MarkSingleTaskExecution` optimizer rule, and forwards the child's `SinglePartition` + // output partitioning (see `outputPartitioning`). + useSingleTask: Boolean = false) extends UnaryExecNode with CodegenSupport { override lazy val metrics = Map( @@ -44,11 +48,11 @@ case class ExpandExec( // The GroupExpressions can output data with arbitrary partitioning, so set it // as UNKNOWN partitioning. Expand only replicates rows within a partition and never moves rows - // across partitions, so when the single-task optimization is enabled and the child produces a - // single partition, we can forward the `SinglePartition` property to avoid an unneeded shuffle. + // across partitions, so when this Expand is part of a plan marked for single-task execution + // and the child produces a single partition, we can forward the `SinglePartition` property to + // avoid an unneeded shuffle. override def outputPartitioning: Partitioning = { - if (conf.getConf(SQLConf.SINGLE_TASK_EXECUTION_EXPAND) && - child.outputPartitioning == SinglePartition) { + if (useSingleTask && child.outputPartitioning == SinglePartition) { SinglePartition } else { UnknownPartitioning(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 930c27c8dff21..21c9a43710fac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -57,7 +57,15 @@ case class LocalTableScanExec( @transient private lazy val rdd: RDD[InternalRow] = { if (rows.isEmpty) { - sparkContext.emptyRDD + if (useSingleTask) { + // Produce a single empty partition to match the `SinglePartition` reported by + // `outputPartitioning`. `emptyRDD` has zero partitions, and running e.g. a global + // aggregation on a zero-partition RDD with the shuffle elided would return no rows + // instead of the single row expected on empty input. + sparkContext.parallelize(Seq.empty[InternalRow], 1) + } else { + sparkContext.emptyRDD + } } else { val numSlices = if (useSingleTask) { 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9f8c818c38eb9..d89f7a919269a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -1151,7 +1151,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case f: logical.TypedFilter => execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil case e @ logical.Expand(_, _, child) => - execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil + val useSingleTask = e.getTagValue( + datasources.MarkSingleTaskExecution.markTag).getOrElse(false) + execution.ExpandExec(e.projections, e.output, planLater(child), useSingleTask) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child, sampleMethod) => if (sampleMethod == logical.SampleMethod.System) { // V2ScanRelationPushDown is non-excludable and always handles SYSTEM samples diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecution.scala index 4824ca8fe5038..175387f0dab33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecution.scala @@ -32,7 +32,8 @@ import org.apache.spark.sql.internal.SQLConf * limit/offset) on top. When it detects such a shape, it marks the underlying scan: * * - a [[LogicalRelation]] or [[LocalRelation]] is marked with the - * [[MarkSingleTaskExecution.markTag]] tag. + * [[MarkSingleTaskExecution.markTag]] tag, as is any [[Expand]] in the plan so that the + * physical Expand can forward the child's `SinglePartition` output partitioning. * * The physical scan then reports a `SinglePartition` output partitioning, which allows * [[org.apache.spark.sql.execution.exchange.EnsureRequirements]] to elide the shuffle that would @@ -49,15 +50,14 @@ object MarkSingleTaskExecution extends Rule[LogicalPlan] { /** * Tag placed on a [[LogicalRelation]] or [[LocalRelation]] that has been marked eligible for - * single-task execution. The planning strategies read this tag to propagate the decision to the - * physical [[org.apache.spark.sql.execution.FileSourceScanExec]] / - * [[org.apache.spark.sql.execution.LocalTableScanExec]]. + * single-task execution, and on any [[Expand]] in such a plan. The planning strategies read + * this tag to propagate the decision to the physical + * [[org.apache.spark.sql.execution.FileSourceScanExec]] / + * [[org.apache.spark.sql.execution.LocalTableScanExec]] / + * [[org.apache.spark.sql.execution.ExpandExec]]. */ val markTag: TreeNodeTag[Boolean] = TreeNodeTag[Boolean]("__single_task_execution") - private def get[T](entry: org.apache.spark.internal.config.ConfigEntry[T]): T = - SQLConf.get.getConf(entry) - /** * Plan patterns that make a query ineligible for the optimization. These operators either * require shuffles that we cannot safely elide, or run user code whose behavior we should not @@ -85,15 +85,18 @@ object MarkSingleTaskExecution extends Rule[LogicalPlan] { window: Boolean) override def apply(plan: LogicalPlan): LogicalPlan = { - if (!get(SQLConf.SINGLE_TASK_EXECUTION_ENABLED)) { + // An explicit leaf-node parallelism override expresses the user's intent about how many + // partitions leaf scans should produce, so do not force scans into a single partition. + if (!conf.getConf(SQLConf.SINGLE_TASK_EXECUTION_ENABLED) || + conf.getConf(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM).isDefined) { return plan } val enabled = EnabledOperators( - aggregation = get(SQLConf.SINGLE_TASK_EXECUTION_AGGREGATION), - expand = get(SQLConf.SINGLE_TASK_EXECUTION_EXPAND), - limitOffset = get(SQLConf.SINGLE_TASK_EXECUTION_LIMIT_OFFSET), - sort = get(SQLConf.SINGLE_TASK_EXECUTION_SORT), - window = get(SQLConf.SINGLE_TASK_EXECUTION_WINDOW)) + aggregation = conf.getConf(SQLConf.SINGLE_TASK_EXECUTION_AGGREGATION), + expand = conf.getConf(SQLConf.SINGLE_TASK_EXECUTION_EXPAND), + limitOffset = conf.getConf(SQLConf.SINGLE_TASK_EXECUTION_LIMIT_OFFSET), + sort = conf.getConf(SQLConf.SINGLE_TASK_EXECUTION_SORT), + window = conf.getConf(SQLConf.SINGLE_TASK_EXECUTION_WINDOW)) if (plan.containsAnyPattern(unsupportedPatterns: _*)) { plan @@ -144,18 +147,23 @@ object MarkSingleTaskExecution extends Rule[LogicalPlan] { r.setTagValue(markTag, true) } r + case e: Expand => + // Also mark the Expand itself: the physical `ExpandExec` reads this tag to forward the + // child's `SinglePartition` output partitioning, which it must only do within a plan + // marked for single-task execution. + val marked = e.withNewChildren(e.children.map(markSingleTaskExecution)) + marked.setTagValue(markTag, true) + marked case other => other.withNewChildren(other.children.map(markSingleTaskExecution)) } /** - * A local in-memory relation is eligible when its row count falls within the configured bounds - * and there is no explicit leaf-node parallelism override in effect. + * A local in-memory relation is eligible when its row count falls within the configured bounds. */ private def isLocalRelationEligible(r: LocalRelation): Boolean = { - val minRows = get(SQLConf.SINGLE_TASK_EXECUTION_LOCAL_TABLE_SCAN_MIN_ROWS) - val threshold = get(SQLConf.SINGLE_TASK_EXECUTION_LOCAL_TABLE_SCAN_THRESHOLD) - get(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM).isEmpty && - r.data.length >= minRows && r.data.length <= threshold + val minRows = conf.getConf(SQLConf.SINGLE_TASK_EXECUTION_LOCAL_TABLE_SCAN_MIN_ROWS) + val threshold = conf.getConf(SQLConf.SINGLE_TASK_EXECUTION_LOCAL_TABLE_SCAN_THRESHOLD) + r.data.length >= minRows && r.data.length <= threshold } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala index 75029297a1aac..d92575e938e75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.execution.{FileSourceScanExec, LocalTableScanExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper} import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike +import org.apache.spark.sql.functions.{count, sum} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -122,6 +123,10 @@ class MarkSingleTaskExecutionSuite extends QueryTest with SharedSparkSession checkMarked(s"select sum(col) from (select col from $t where col < 42)") } + test("marks scan + expand (grouping sets)") { + checkMarked(s"select col, count(1) from $t group by rollup(col)") + } + test("marks scan + window") { checkMarked( s"select col, row_number() over (partition by col order by col) from $t") @@ -167,6 +172,30 @@ class MarkSingleTaskExecutionSuite extends QueryTest with SharedSparkSession checkSinglePartition(s"select distinct col from $t2", Seq(Row(0), Row(1))) } + test("output partitioning is SinglePartition, scan + expand") { + checkSinglePartition( + s"select col, count(1) as c from $t group by rollup(col)", + Seq(Row(0, 1), Row(1, 1), Row(null, 2))) + } + + test("bucketed scan does not run in a single task") { + val bucketed = "single_task_bucketed" + withTable(bucketed) { + spark.range(0, 2).selectExpr("id as col").write.bucketBy(2, "col").saveAsTable(bucketed) + // Raise the file count bound so that only the bucketing makes the scan ineligible. + val confs = enabledConfs :+ (SQLConf.SINGLE_TASK_EXECUTION_MAX_NUM_FILES.key -> "4") + withSQLConf(confs: _*) { + val df = sql(s"select col, count(1) as c from $bucketed group by col") + checkAnswer(df, Seq(Row(0, 1), Row(1, 1))) + val scans = collect(getFinalPhysicalPlan(df)) { case s: FileSourceScanExec => s } + assert(scans.nonEmpty) + assert(scans.forall(!_.useSingleTaskExecution), + "a bucketed scan must not run in a single task as that would invalidate its " + + "HashPartitioning over the bucket columns") + } + } + } + test("empty table scan + aggregation is correct and single-partition") { // Without single-task execution eliding the shuffle before the aggregation, an empty scan // could incorrectly return zero rows instead of a single NULL row for a global aggregation. @@ -179,9 +208,23 @@ class MarkSingleTaskExecutionSuite extends QueryTest with SharedSparkSession Seq(Row(0), Row(1))) } + test("empty local relation + global aggregation returns one row") { + withSQLConf(enabledConfs: _*) { + import testImplicits._ + val df = Seq.empty[Int].toDF("col").agg(count($"col"), sum($"col")) + assert(isMarked(df.queryExecution.optimizedPlan), + "expected the empty local relation to be marked for single-task execution") + // A global aggregation over an empty input must still return a single row. + checkAnswer(df, Row(0, null)) + } + } + test("does not mark when a leaf-node parallelism override is set") { checkNotMarked( "select col from values (0), (1) as tab(col) order by col", enabledConfs :+ (SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key -> "4")) + checkNotMarked( + s"select col from $t order by col", + enabledConfs :+ (SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key -> "4")) } } From 99c646ad26bb36b249573e8ec1a02e0724adf3e0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 2 Jul 2026 02:21:11 -0700 Subject: [PATCH 4/8] Keep ExpandExec explain output unchanged when useSingleTask is not set The new useSingleTask constructor parameter appended ", false" to every Expand node's explain arguments, breaking the TPC-DS plan stability golden files for rollup queries. Show the flag only when it is set. Co-authored-by: Claude Code --- .../org/apache/spark/sql/execution/ExpandExec.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 6b262cbb24c00..ba1238564348e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -59,6 +59,16 @@ case class ExpandExec( } } + // Show `useSingleTask` in the string representation only when it is set, so that plans not + // using single-task execution (the default) keep their existing explain output. + override protected def stringArgs: Iterator[Any] = { + if (useSingleTask) { + super.stringArgs + } else { + Iterator(projections, output, child) + } + } + @transient override lazy val references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) From e61ed42fd23e3b0582855c5e7ff3ec938d20d342 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 2 Jul 2026 08:53:07 -0700 Subject: [PATCH 5/8] Exclude user-defined aggregations and drop dead matches in MarkSingleTaskExecution - Add the USER_DEFINED_AGGREGATION tree pattern, tag ScalaUDAF, ScalaAggregator and the typed aggregate expressions with it, and add it to the rule's unsupported patterns. This is a defensive guard: an optimization that collapses partial and final aggregates separated by no exchange would skip the user's merge step. - Drop the Distinct and SubqueryAlias cases from isSupportedShape: both are rewritten away by non-excludable rules (ReplaceDistinctWithAggregate, EliminateSubqueryAliases) before this last optimizer batch runs. Co-authored-by: Claude Code --- .../sql/catalyst/trees/TreePatterns.scala | 1 + .../aggregate/TypedAggregateExpression.scala | 5 ++++ .../spark/sql/execution/aggregate/udaf.scala | 5 ++++ .../datasources/MarkSingleTaskExecution.scala | 15 ++++++---- .../MarkSingleTaskExecutionSuite.scala | 30 +++++++++++++++++-- 5 files changed, 48 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 94b4666a88a8d..dfb815414dd35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -107,6 +107,7 @@ object TreePattern extends Enumeration { val TIME_WINDOW: Value = Value val TIME_ZONE_AWARE_EXPRESSION: Value = Value val TRUE_OR_FALSE_LITERAL: Value = Value + val USER_DEFINED_AGGREGATION: Value = Value val VARIANT_GET: Value = Value val WINDOW_EXPRESSION: Value = Value val WINDOW_TIME: Value = Value diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index d958790dd09b1..df0addad7861a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, USER_DEFINED_AGGREGATION} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -125,6 +126,8 @@ case class SimpleTypedAggregateExpression( nullable: Boolean) extends DeclarativeAggregate with TypedAggregateExpression with NonSQLExpression { + final override val nodePatterns: Seq[TreePattern] = Seq(USER_DEFINED_AGGREGATION) + override lazy val deterministic: Boolean = true override def children: Seq[Expression] = { @@ -223,6 +226,8 @@ case class ComplexTypedAggregateExpression( inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[Any] with TypedAggregateExpression with NonSQLExpression { + final override val nodePatterns: Seq[TreePattern] = Seq(USER_DEFINED_AGGREGATION) + override lazy val deterministic: Boolean = true override def children: Seq[Expression] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 492f11607ce6d..203ee2d89b7b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, USER_DEFINED_AGGREGATION} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, UserDefinedAggregator} import org.apache.spark.sql.types._ @@ -358,6 +359,8 @@ case class ScalaUDAF( with ImplicitCastInputTypes with UserDefinedExpression { + final override val nodePatterns: Seq[TreePattern] = Seq(USER_DEFINED_AGGREGATION) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -500,6 +503,8 @@ case class ScalaAggregator[IN, BUF, OUT]( with ImplicitCastInputTypes with Logging { + final override val nodePatterns: Seq[TreePattern] = Seq(USER_DEFINED_AGGREGATION) + // input and buffer encoders are resolved by ResolveEncodersInScalaAgg @transient private[this] lazy val inputDeserializer = inputEncoder.createDeserializer() @transient private[this] lazy val bufferSerializer = bufferEncoder.createSerializer() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecution.scala index 175387f0dab33..ed4d1a1f7fa51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecution.scala @@ -61,7 +61,9 @@ object MarkSingleTaskExecution extends Rule[LogicalPlan] { /** * Plan patterns that make a query ineligible for the optimization. These operators either * require shuffles that we cannot safely elide, or run user code whose behavior we should not - * change (e.g. user-defined aggregations skip the final merge step when run in a single task). + * change. User-defined aggregations are excluded defensively: an optimization that collapses + * the partial and final aggregates when no exchange separates them would skip the user's merge + * step, so single-task plans must never be assumed safe for them. */ val unsupportedPatterns: Seq[TreePattern] = Seq( EVAL_PYTHON_UDF, @@ -71,7 +73,8 @@ object MarkSingleTaskExecution extends Rule[LogicalPlan] { LATERAL_SUBQUERY, LIST_SUBQUERY, PYTHON_UDF, - SCALAR_SUBQUERY) + SCALAR_SUBQUERY, + USER_DEFINED_AGGREGATION) /** * The per-operator sub-flags, resolved once per invocation. Each field indicates whether the @@ -114,15 +117,15 @@ object MarkSingleTaskExecution extends Rule[LogicalPlan] { */ private def isSupportedShape(plan: LogicalPlan, enabled: EnabledOperators): Boolean = plan match { case _: LogicalRelation | _: LocalRelation => true - // Operators that never introduce a shuffle by themselves. - case _: Project | _: Filter | _: SubqueryAlias | + // Operators that never introduce a shuffle by themselves. Note that `Distinct` and + // `SubqueryAlias` need no cases here: they are rewritten away by non-excludable rules + // (`ReplaceDistinctWithAggregate` and `EliminateSubqueryAliases`) long before this rule runs. + case _: Project | _: Filter | _: DeserializeToObject | _: SerializeFromObject => plan.children.forall(isSupportedShape(_, enabled)) // Shuffle-inducing operators, allowed only when the matching sub-flag is enabled. case _: Aggregate if enabled.aggregation => plan.children.forall(isSupportedShape(_, enabled)) - case _: Distinct if enabled.aggregation => - plan.children.forall(isSupportedShape(_, enabled)) case _: Expand if enabled.expand => plan.children.forall(isSupportedShape(_, enabled)) case (_: GlobalLimit | _: LocalLimit | _: Offset) if enabled.limitOffset => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala index d92575e938e75..cdf2d2ba1f119 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{Encoder, Encoders, QueryTest, Row} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.execution.{FileSourceScanExec, LocalTableScanExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper} import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike -import org.apache.spark.sql.functions.{count, sum} +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.functions.{count, sum, udaf} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -154,6 +155,31 @@ class MarkSingleTaskExecutionSuite extends QueryTest with SharedSparkSession checkNotMarked(s"select col from $t where col = (select max(col) from $t2) order by col") } + test("does not mark plans with user-defined aggregations") { + val strLen = new Aggregator[String, Long, Long] { + override def zero: Long = 0L + override def reduce(b: Long, a: String): Long = b + a.length + override def merge(b1: Long, b2: Long): Long = b1 + b2 + override def finish(reduction: Long): Long = reduction + override def bufferEncoder: Encoder[Long] = Encoders.scalaLong + override def outputEncoder: Encoder[Long] = Encoders.scalaLong + } + // `functions.udaf` produces a `ScalaAggregator` expression. + spark.udf.register("test_str_len_agg", udaf(strLen)) + try { + checkNotMarked(s"select test_str_len_agg(col_str) from $t") + } finally { + spark.sessionState.catalog.dropTempFunction("test_str_len_agg", ignoreIfNotExists = true) + } + // A typed Dataset aggregation produces a `TypedAggregateExpression`. + withSQLConf(enabledConfs: _*) { + import testImplicits._ + val ds = spark.table(t).select($"col_str").as[String].select(strLen.toColumn) + assert(!isMarked(ds.queryExecution.optimizedPlan), + s"expected plan with typed aggregation NOT to be marked:\n${ds.queryExecution.optimizedPlan}") + } + } + test("output partitioning is SinglePartition, scan + sort") { checkSinglePartition(s"select col from $t order by col", Seq(Row(0), Row(1))) } From 51e86f51acc442f53ffb4e89e8dd48dd4b4739a4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 2 Jul 2026 10:11:36 -0700 Subject: [PATCH 6/8] Fix scalastyle line length in MarkSingleTaskExecutionSuite Co-authored-by: Claude Code --- .../execution/datasources/MarkSingleTaskExecutionSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala index cdf2d2ba1f119..82d29f8b33c8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala @@ -175,8 +175,9 @@ class MarkSingleTaskExecutionSuite extends QueryTest with SharedSparkSession withSQLConf(enabledConfs: _*) { import testImplicits._ val ds = spark.table(t).select($"col_str").as[String].select(strLen.toColumn) - assert(!isMarked(ds.queryExecution.optimizedPlan), - s"expected plan with typed aggregation NOT to be marked:\n${ds.queryExecution.optimizedPlan}") + val optimized = ds.queryExecution.optimizedPlan + assert(!isMarked(optimized), + s"expected plan with typed aggregation NOT to be marked:\n$optimized") } } From 4590187c29eecd193593fafd8b0e32c8770c310b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 2 Jul 2026 11:32:51 -0700 Subject: [PATCH 7/8] Tag HiveUDAFFunction and V2Aggregator with USER_DEFINED_AGGREGATION These are the remaining user-defined aggregate expressions whose merge step runs user code, so plans containing them must also be excluded from single-task execution. PythonUDAF is already covered by the PYTHON_UDF pattern. Note V2Aggregator does not mix in UserDefinedExpression, so it would be missed by a trait-based scan. Co-authored-by: Claude Code --- .../expressions/aggregate/V2Aggregator.scala | 3 +++ .../MarkSingleTaskExecutionSuite.scala | 24 +++++++++++++++++++ .../org/apache/spark/sql/hive/hiveUDFs.scala | 3 +++ 3 files changed, 30 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala index 49ba2ec8b904e..e92e612452014 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/V2Aggregator.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, UnsafeProjection} +import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, USER_DEFINED_AGGREGATION} import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction} import org.apache.spark.sql.types.{AbstractDataType, DataType} import org.apache.spark.util.ArrayImplicits._ @@ -31,6 +32,8 @@ case class V2Aggregator[BUF <: java.io.Serializable, OUT]( inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[BUF] with ImplicitCastInputTypes { + final override val nodePatterns: Seq[TreePattern] = Seq(USER_DEFINED_AGGREGATION) + private[this] lazy val inputProjection = UnsafeProjection.create(children) override def nullable: Boolean = aggrFunc.isResultNullable diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala index 82d29f8b33c8b..29d78fcb715f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala @@ -18,14 +18,20 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.{Encoder, Encoders, QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.expressions.aggregate.V2Aggregator import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.catalyst.trees.TreePattern.USER_DEFINED_AGGREGATION +import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction} import org.apache.spark.sql.execution.{FileSourceScanExec, LocalTableScanExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper} import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.functions.{count, sum, udaf} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, LongType} import org.apache.spark.sql.test.SharedSparkSession /** @@ -181,6 +187,24 @@ class MarkSingleTaskExecutionSuite extends QueryTest with SharedSparkSession } } + test("V2Aggregator carries the USER_DEFINED_AGGREGATION pattern") { + // ScalaAggregator and the typed aggregate expressions are covered by the end-to-end test + // above; V2Aggregator has no operator-level pattern of its own, so verify it directly. + // HiveUDAFFunction lives in the hive module and is covered there. + val v2Func = new V2AggregateFunction[java.lang.Long, java.lang.Long] { + override def newAggregationState(): java.lang.Long = 0L + override def update(state: java.lang.Long, input: InternalRow): java.lang.Long = + state + input.getLong(0) + override def merge(l: java.lang.Long, r: java.lang.Long): java.lang.Long = l + r + override def produceResult(state: java.lang.Long): java.lang.Long = state + override def name(): String = "test_v2_sum" + override def inputTypes(): Array[DataType] = Array(LongType) + override def resultType(): DataType = LongType + } + val agg = V2Aggregator(v2Func, Seq(BoundReference(0, LongType, nullable = false))) + assert(agg.containsPattern(USER_DEFINED_AGGREGATION)) + } + test("output partitioning is SinglePartition, scan + sort") { checkSinglePartition(s"select col from $t order by col", Seq(Row(0), Row(1))) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index bf708eecf0c0c..129c5e2cc053b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper +import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, USER_DEFINED_AGGREGATION} import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types._ @@ -337,6 +338,8 @@ private[hive] case class HiveUDAFFunction( with HiveInspectors with UserDefinedExpression { + final override val nodePatterns: Seq[TreePattern] = Seq(USER_DEFINED_AGGREGATION) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) From 4de70310f1e8e31656fb4674734d2bc5345ac9c9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 2 Jul 2026 12:59:35 -0700 Subject: [PATCH 8/8] Fix scalastyle import ordering in MarkSingleTaskExecutionSuite Co-authored-by: Claude Code --- .../execution/datasources/MarkSingleTaskExecutionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala index 29d78fcb715f0..6e60e9a03e8dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/MarkSingleTaskExecutionSuite.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.functions.{count, sum, udaf} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, LongType} import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DataType, LongType} /** * Test suite for the [[MarkSingleTaskExecution]] optimizer rule and its physical effects. The rule