Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions python/benchmarks/bench_eval_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from pyspark.cloudpickle import dumps as cloudpickle_dumps
from pyspark.serializers import write_int, write_long, SpecialLengths
from pyspark.sql.types import (
ArrayType,
BinaryType,
BooleanType,
DoubleType,
Expand Down Expand Up @@ -251,6 +252,12 @@ class MockDataFactory:
"string": (lambda r: pa.array([f"s{j}" for j in range(r)]), StringType()),
"binary": (lambda r: pa.array([f"b{j}".encode() for j in range(r)]), BinaryType()),
"boolean": (lambda r: pa.array(np.random.choice([True, False], r)), BooleanType()),
"string_array": (
lambda r: pa.array(
[[f"s{j}", f"t{j}"] for j in range(r)], type=pa.list_(pa.string())
),
ArrayType(StringType()),
),
}

MIXED_TYPES = [
Expand All @@ -266,6 +273,7 @@ class MockDataFactory:
"pure_ints": [TYPE_REGISTRY["int"]],
"pure_floats": [TYPE_REGISTRY["double"]],
"pure_strings": [TYPE_REGISTRY["string"]],
"pure_string_arrays": [TYPE_REGISTRY["string_array"]],
"pure_ts": [
(
lambda r: pa.array(
Expand Down Expand Up @@ -480,6 +488,7 @@ class _ArrowBatchedBenchMixin:
"pure_ints": ("pure_ints", 50_000, 10, 5_000),
"pure_floats": ("pure_floats", 50_000, 10, 5_000),
"pure_strings": ("pure_strings", 50_000, 10, 5_000),
"pure_string_arrays": ("pure_string_arrays", 50_000, 10, 5_000),
"mixed_types": ("mixed", 50_000, 10, 5_000),
}

Expand All @@ -502,6 +511,16 @@ def _build_scenario(cls, name):
"identity_udf": (lambda x: x, None, [0]),
"stringify_udf": (lambda x: str(x), StringType(), [0]),
"nullcheck_udf": (lambda x: x is not None, BooleanType(), [0]),
# array<string> out. Exercises the output-side Python->Arrow conversion
# for a nested type. Type-agnostic on input (like the other UDFs here, so
# it runs across every scenario in the cross product); pair with the
# ``pure_string_arrays`` scenario for the array<string>-in / array<string>-out
# case that stresses the nested output conversion the most.
"to_string_array_udf": (
lambda x: [str(e)[::-1] for e in x] if isinstance(x, (list, tuple)) else [str(x)],
ArrayType(StringType()),
[0],
),
}
params = [list(_scenario_configs), list(_udfs)]
param_names = ["scenario", "udf"]
Expand Down
63 changes: 63 additions & 0 deletions python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
StringType,
StructField,
StructType,
TimestampType,
VarcharType,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
Expand Down Expand Up @@ -270,6 +271,68 @@ def f(v: float):
rounded = df.select(f("v").alias("d")).first().d
self.assertEqual(rounded, Decimal("1.233999999999999986"))

def test_array_string_output_fast_path(self):
# Regression test for the Arrow Python UDF output fast path: an
# array<string> UDF whose elements are already strings must produce the
# same result whether or not the per-element output converter is used.
df = self.spark.range(0, 5)

@udf(returnType=ArrayType(StringType()))
def reverse_each(i):
words = ["alpha", "beta", "gamma"]
return [w[::-1] for w in words[: (int(i) % 3) + 1]]

result = [r.res for r in df.select(reverse_each("id").alias("res")).collect()]
self.assertEqual(
result,
[
["ahpla"],
["ahpla", "ateb"],
["ahpla", "ateb", "ammag"],
["ahpla"],
["ahpla", "ateb"],
],
)

def test_array_string_output_requires_coercion(self):
# When the array<string> UDF returns non-string elements, the output must
# still be coerced to string (fast path must fall back to the converter).
# Scoped to the non-legacy conversion path, which this fast path targets;
# the legacy pandas path does not coerce non-string array elements.
with self.sql_conf(
{"spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled": False}
):
df = self.spark.range(0, 3)

@udf(returnType=ArrayType(StringType()))
def ints_as_strings(i):
return [int(i), int(i) + 1]

result = [r.res for r in df.select(ints_as_strings("id").alias("res")).collect()]
self.assertEqual(result, [["0", "1"], ["1", "2"], ["2", "3"]])

def test_array_timestamp_output_timezone(self):
# array<timestamp> is excluded from the fast path because the converter
# applies session-timezone truncation that raw pa.array would skip. Verify
# the timestamps come back correctly (i.e. the converter path was used).
import datetime

with self.sql_conf({"spark.sql.session.timeZone": "America/Los_Angeles"}):
df = self.spark.range(0, 2)

@udf(returnType=ArrayType(TimestampType()))
def make_ts(i):
return [datetime.datetime(2020, 1, 1, 12, 0, 0)]

result = [r.res for r in df.select(make_ts("id").alias("res")).collect()]
self.assertEqual(
result,
[
[datetime.datetime(2020, 1, 1, 12, 0, 0)],
[datetime.datetime(2020, 1, 1, 12, 0, 0)],
],
)

def test_err_return_type(self):
with self.assertRaises(PySparkNotImplementedError) as pe:
udf(lambda x: x, VarcharType(10), useArrow=True)
Expand Down
96 changes: 87 additions & 9 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,19 @@
ArrayType,
BinaryType,
DataType,
DecimalType,
GeographyType,
GeometryType,
MapType,
NullType,
Row,
StringType,
StructField,
StructType,
TimestampNTZType,
TimestampType,
UserDefinedType,
VariantType,
_create_row,
_parse_datatype_json_string,
)
Expand Down Expand Up @@ -3031,6 +3039,49 @@ def cogrouped_func(
):
import pyarrow as pa

def _output_fast_path_safe(dt: DataType) -> bool:
# True when building an Arrow array directly from raw UDF results
# (pa.array(results, type=...)) yields values identical to running the
# per-element LocalDataToArrowConversion converter first. This holds for
# numeric/bool/string/binary (and nested containers of them), where the
# converter is a no-op for already-correct-type values and pa.array
# raises for anything that would need coercion (so we fall back safely).
#
# It does NOT hold, and the type must be excluded, when the converter can
# transform a value that pa.array would ALSO accept (silently producing a
# different result) or when pa.array accepts an input the converter is
# meant to reject:
# - Timestamp/TimestampNTZ: converter truncates to session tz.
# - Decimal: pa.array coerces int->decimal, but the converter only does
# so when intToDecimalCoercionEnabled; skipping it would bypass that
# gate (and rescaling / Decimal('NaN')->None handling).
# - UDT/Variant/Geo: converter serializes to a storage form.
# - Null: trivial; keep on the safe/simple side.
if isinstance(dt, ArrayType):
return _output_fast_path_safe(dt.elementType)
elif isinstance(dt, MapType):
return _output_fast_path_safe(dt.keyType) and _output_fast_path_safe(
dt.valueType
)
elif isinstance(dt, StructType):
return all(_output_fast_path_safe(f.dataType) for f in dt.fields)
elif isinstance(
dt,
(
TimestampType,
TimestampNTZType,
DecimalType,
UserDefinedType,
VariantType,
GeographyType,
GeometryType,
NullType,
),
):
return False
else:
return True

# --- UDF preparation ---
udf_infos = []
for udf_func, udf_args_offsets, udf_kwargs_offsets, udf_return_type in udfs:
Expand All @@ -3053,6 +3104,7 @@ def cogrouped_func(
none_on_identity=True,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
),
_output_fast_path_safe(udf_return_type),
)
)
col_names = [f"_{i}" for i in range(len(udfs))]
Expand Down Expand Up @@ -3088,7 +3140,14 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.Record

# --- Process: evaluate each UDF row-by-row ---
output_arrays = []
for udf_func, offsets, zero_arg, arrow_return_type, result_conv in udf_infos:
for (
udf_func,
offsets,
zero_arg,
arrow_return_type,
result_conv,
fast_ok,
) in udf_infos:
rows = (
[() for _ in range(num_rows)]
if zero_arg
Expand All @@ -3098,15 +3157,34 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.Record
verify_result_row_count(len(results), num_rows)

# --- Output: Python -> Arrow ---
converted = (
[result_conv(r) for r in results] if result_conv is not None else results
)
try:
arr = pa.array(converted, type=arrow_return_type)
except pa.lib.ArrowInvalid:
arr = pa.array(converted).cast(
target_type=arrow_return_type, safe=runner_conf.safecheck
# Fast path: when the return type has no value-transforming
            # coercion (see _output_fast_path_safe), try building the Arrow
# array directly from the raw UDF results and let PyArrow (C++)
# do the work, skipping the per-element Python converter, which
# is pure overhead when results already match the declared type.
# Only fall back to the converter if PyArrow rejects the raw
# results (an element genuinely needs coercion, e.g. int->string).
# NOTE: types like timestamp are excluded because raw pa.array
# succeeds but yields DIFFERENT values than the converter (tz
# truncation), so a try/except gate alone is unsafe for them.
arr = None
if result_conv is not None and fast_ok:
try:
arr = pa.array(results, type=arrow_return_type)
except pa.lib.ArrowException:
arr = None
if arr is None:
converted = (
[result_conv(r) for r in results]
if result_conv is not None
else results
)
try:
arr = pa.array(converted, type=arrow_return_type)
except pa.lib.ArrowInvalid:
arr = pa.array(converted).cast(
target_type=arrow_return_type, safe=runner_conf.safecheck
)
output_arrays.append(arr)

yield pa.RecordBatch.from_arrays(output_arrays, col_names)
Expand Down