diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala index a468153b57c5a..e5e798495c199 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala @@ -67,7 +67,7 @@ case class ApproxCountDistinctForIntervals( override def inputTypes: Seq[AbstractDataType] = { Seq(TypeCollection(NumericType, TimestampType, DateType, TimestampNTZType, - YearMonthIntervalType, DayTimeIntervalType), ArrayType) + YearMonthIntervalType, DayTimeIntervalType, AnyTimeType), ArrayType) } // Mark as lazy so that endpointsExpression is not evaluated during tree transformation. @@ -90,7 +90,7 @@ case class ApproxCountDistinctForIntervals( } else { endpointsExpression.dataType match { case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType | - _: AnsiIntervalType, _) => + _: AnsiIntervalType | _: AnyTimeType, _) => if (endpoints.length < 2) { DataTypeMismatch( errorSubClass = "WRONG_NUM_ENDPOINTS", @@ -100,7 +100,7 @@ case class ApproxCountDistinctForIntervals( } case inputType => val requiredElemTypes = toSQLType(TypeCollection( - NumericType, DateType, TimestampType, TimestampNTZType, AnsiIntervalType)) + NumericType, DateType, TimestampType, TimestampNTZType, AnsiIntervalType, AnyTimeType)) DataTypeMismatch( errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( @@ -144,7 +144,7 @@ case class ApproxCountDistinctForIntervals( .toDouble(value.asInstanceOf[PhysicalNumericType#InternalType]) case _: DateType | _: YearMonthIntervalType => value.asInstanceOf[Int].toDouble - case TimestampType | TimestampNTZType | _: DayTimeIntervalType => + case TimestampType | TimestampNTZType | 
_: DayTimeIntervalType | _: AnyTimeType => value.asInstanceOf[Long].toDouble } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala index 656f8b161e17f..b7eb0d26c0b46 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import java.sql.{Date, Timestamp} -import java.time.LocalDateTime +import java.time.{LocalDateTime, LocalTime} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow @@ -44,7 +44,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { messageParameters = Map( "paramIndex" -> ordinalNumber(0), "requiredType" -> ("(\"NUMERIC\" or \"TIMESTAMP\" or \"DATE\" or \"TIMESTAMP_NTZ\"" + - " or \"INTERVAL YEAR TO MONTH\" or \"INTERVAL DAY TO SECOND\")"), + " or \"INTERVAL YEAR TO MONTH\" or \"INTERVAL DAY TO SECOND\" or \"TIME\")"), "inputSql" -> "\"a\"", "inputType" -> toSQLType(dataType) ) @@ -92,7 +92,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( "paramIndex" -> ordinalNumber(1), - "requiredType" -> "ARRAY OF (\"NUMERIC\" or \"DATE\" or \"TIMESTAMP\" or \"TIMESTAMP_NTZ\" or \"ANSI INTERVAL\")", + "requiredType" -> "ARRAY OF (\"NUMERIC\" or \"DATE\" or \"TIMESTAMP\" or \"TIMESTAMP_NTZ\" or \"ANSI INTERVAL\" or \"TIME\")", "inputSql" -> "\"array(foobar)\"", "inputType" -> "\"ARRAY\""))) // scalastyle:on line.size.limit @@ -230,7 +230,9 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { 
(intRecords.map(DateTimeUtils.toJavaTimestamp(_)), intEndpoints.map(DateTimeUtils.toJavaTimestamp(_)), TimestampType), (intRecords.map(DateTimeUtils.microsToLocalDateTime(_)), - intEndpoints.map(DateTimeUtils.microsToLocalDateTime(_)), TimestampNTZType) + intEndpoints.map(DateTimeUtils.microsToLocalDateTime(_)), TimestampNTZType), + (intRecords.map(i => LocalTime.ofNanoOfDay(i.toLong)), + intEndpoints.map(i => LocalTime.ofNanoOfDay(i.toLong)), TimeType()) ) inputs.foreach { case (records, endpoints, dataType) => @@ -241,6 +243,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { case d: Date => DateTimeUtils.fromJavaDate(d) case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) case ldt: LocalDateTime => DateTimeUtils.localDateTimeToMicros(ldt) + case lt: LocalTime => DateTimeUtils.localTimeToNanos(lt) case _ => r } input.update(0, value) @@ -253,6 +256,44 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { } } + test("TIME type with realistic nanos-of-day magnitudes") { + // Realistic time-of-day endpoints in nanos: midnight, 06:00, 12:00, 18:00, near max + // LocalTime.MAX is 23:59:59.999999999 = 86_399_999_999_999 nanos + val midnight = 0L + val sixAm = 6L * 3600L * 1000000000L // 21_600_000_000_000 + val noon = 12L * 3600L * 1000000000L // 43_200_000_000_000 + val sixPm = 18L * 3600L * 1000000000L // 64_800_000_000_000 + val nearMax = 86399999999999L // 23:59:59.999999999 + + val endpoints = Array(midnight, sixAm, noon, sixPm, nearMax) + .map(n => LocalTime.ofNanoOfDay(n)) + + // Generate distinct values per interval using minute-granularity nanos. + // [midnight, 6AM]: 100 distinct minutes (00:00 .. 01:39) + // (6AM, noon]: 80 distinct minutes (06:01 .. 07:20) + // (noon, 6PM]: 60 distinct minutes (12:01 ..
13:00) + // (6PM, nearMax]: 50 distinct values (18:01 .. 18:49) plus the upper edge nearMax + val minuteNanos = 60L * 1000000000L + val interval1 = (0 until 100).map(i => midnight + i * minuteNanos) + val interval2 = (1 to 80).map(i => sixAm + i * minuteNanos) + val interval3 = (1 to 60).map(i => noon + i * minuteNanos) + val interval4 = (1 to 49).map(i => sixPm + i * minuteNanos) :+ nearMax + + val allNanos = interval1 ++ interval2 ++ interval3 ++ interval4 + + val (aggFunc, input, buffer) = createEstimator(endpoints, TimeType()) + allNanos.foreach { n => + input.update(0, n) + aggFunc.update(buffer, input) + } + + // 4 intervals: [midnight,6AM], (6AM,noon], (noon,6PM], (6PM,nearMax] + checkNDVs( + ndvs = aggFunc.eval(buffer).asInstanceOf[ArrayData].toLongArray(), + expectedNdvs = Array(100, 80, 60, 50), + rsd = aggFunc.relativeSD) + } + private def checkNDVs(ndvs: Array[Long], expectedNdvs: Array[Long], rsd: Double): Unit = { assert(ndvs.length == expectedNdvs.length) for (i <- ndvs.indices) {