Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ case class ApproxCountDistinctForIntervals(

override def inputTypes: Seq[AbstractDataType] = {
Seq(TypeCollection(NumericType, TimestampType, DateType, TimestampNTZType,
YearMonthIntervalType, DayTimeIntervalType), ArrayType)
YearMonthIntervalType, DayTimeIntervalType, AnyTimeType), ArrayType)
}

// Mark as lazy so that endpointsExpression is not evaluated during tree transformation.
Expand All @@ -90,7 +90,7 @@ case class ApproxCountDistinctForIntervals(
} else {
endpointsExpression.dataType match {
case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType |
_: AnsiIntervalType, _) =>
_: AnsiIntervalType | _: AnyTimeType, _) =>
if (endpoints.length < 2) {
DataTypeMismatch(
errorSubClass = "WRONG_NUM_ENDPOINTS",
Expand All @@ -100,7 +100,7 @@ case class ApproxCountDistinctForIntervals(
}
case inputType =>
val requiredElemTypes = toSQLType(TypeCollection(
NumericType, DateType, TimestampType, TimestampNTZType, AnsiIntervalType))
NumericType, DateType, TimestampType, TimestampNTZType, AnsiIntervalType, AnyTimeType))
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
Expand Down Expand Up @@ -144,7 +144,7 @@ case class ApproxCountDistinctForIntervals(
.toDouble(value.asInstanceOf[PhysicalNumericType#InternalType])
case _: DateType | _: YearMonthIntervalType =>
value.asInstanceOf[Int].toDouble
case TimestampType | TimestampNTZType | _: DayTimeIntervalType =>
case TimestampType | TimestampNTZType | _: DayTimeIntervalType | _: AnyTimeType =>
value.asInstanceOf[Long].toDouble
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate

import java.sql.{Date, Timestamp}
import java.time.LocalDateTime
import java.time.{LocalDateTime, LocalTime}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
Expand All @@ -44,7 +44,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
messageParameters = Map(
"paramIndex" -> ordinalNumber(0),
"requiredType" -> ("(\"NUMERIC\" or \"TIMESTAMP\" or \"DATE\" or \"TIMESTAMP_NTZ\"" +
" or \"INTERVAL YEAR TO MONTH\" or \"INTERVAL DAY TO SECOND\")"),
" or \"INTERVAL YEAR TO MONTH\" or \"INTERVAL DAY TO SECOND\" or \"TIME\")"),
"inputSql" -> "\"a\"",
"inputType" -> toSQLType(dataType)
)
Expand Down Expand Up @@ -92,7 +92,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(1),
"requiredType" -> "ARRAY OF (\"NUMERIC\" or \"DATE\" or \"TIMESTAMP\" or \"TIMESTAMP_NTZ\" or \"ANSI INTERVAL\")",
"requiredType" -> "ARRAY OF (\"NUMERIC\" or \"DATE\" or \"TIMESTAMP\" or \"TIMESTAMP_NTZ\" or \"ANSI INTERVAL\" or \"TIME\")",
"inputSql" -> "\"array(foobar)\"",
"inputType" -> "\"ARRAY<STRING>\"")))
// scalastyle:on line.size.limit
Expand Down Expand Up @@ -230,7 +230,9 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
(intRecords.map(DateTimeUtils.toJavaTimestamp(_)),
intEndpoints.map(DateTimeUtils.toJavaTimestamp(_)), TimestampType),
(intRecords.map(DateTimeUtils.microsToLocalDateTime(_)),
intEndpoints.map(DateTimeUtils.microsToLocalDateTime(_)), TimestampNTZType)
intEndpoints.map(DateTimeUtils.microsToLocalDateTime(_)), TimestampNTZType),
(intRecords.map(i => LocalTime.ofNanoOfDay(i.toLong)),
intEndpoints.map(i => LocalTime.ofNanoOfDay(i.toLong)), TimeType())
)

inputs.foreach { case (records, endpoints, dataType) =>
Expand All @@ -241,6 +243,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
case d: Date => DateTimeUtils.fromJavaDate(d)
case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t)
case ldt: LocalDateTime => DateTimeUtils.localDateTimeToMicros(ldt)
case lt: LocalTime => DateTimeUtils.localTimeToNanos(lt)
case _ => r
}
input.update(0, value)
Expand All @@ -253,6 +256,44 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
}
}

test("TIME type with realistic nanos-of-day magnitudes") {
// Realistic time-of-day values in nanos: midnight, 06:00, 12:00, 18:00, near max
// LocalTime.MAX is 23:59:59.999999999 = 86_399_999_999_999 nanos
val midnight = 0L
val sixAm = 6L * 3600L * 1000000000L // 21_600_000_000_000
val noon = 12L * 3600L * 1000000000L // 43_200_000_000_000
val sixPm = 18L * 3600L * 1000000000L // 64_800_000_000_000
val nearMax = 86399999999999L // 23:59:59.999999999

val endpoints = Array(midnight, sixAm, noon, sixPm, nearMax)
.map(n => LocalTime.ofNanoOfDay(n))

// Generate distinct values per interval using minute-granularity nanos.
// [midnight, 6AM): 100 distinct minutes (00:00 .. 01:39)
// [6AM, noon): 80 distinct minutes (06:00 .. 07:19)
// [noon, 6PM): 60 distinct minutes (12:00 .. 12:59)
// [6PM, nearMax]: 50 distinct values including edge nearMax
val minuteNanos = 60L * 1000000000L
val interval1 = (0 until 100).map(i => midnight + i * minuteNanos)
val interval2 = (0 until 80).map(i => sixAm + i * minuteNanos)
val interval3 = (0 until 60).map(i => noon + i * minuteNanos)
val interval4 = (0 until 49).map(i => sixPm + i * minuteNanos) :+ nearMax

val allNanos = interval1 ++ interval2 ++ interval3 ++ interval4

val (aggFunc, input, buffer) = createEstimator(endpoints, TimeType())
allNanos.foreach { n =>
input.update(0, n)
aggFunc.update(buffer, input)
}

// 4 intervals: [midnight,6AM), [6AM,noon), [noon,6PM), [6PM,nearMax]
checkNDVs(
ndvs = aggFunc.eval(buffer).asInstanceOf[ArrayData].toLongArray(),
expectedNdvs = Array(100, 80, 60, 50),
rsd = aggFunc.relativeSD)
}

private def checkNDVs(ndvs: Array[Long], expectedNdvs: Array[Long], rsd: Double): Unit = {
assert(ndvs.length == expectedNdvs.length)
for (i <- ndvs.indices) {
Expand Down