Skip to content

Commit d307cee

Browse files
committed
[SPARK-22966][PYTHON][SQL] Python UDFs with returnType=StringType should treat return values of datetime.date or datetime.datetime as unconvertible
Add conversion to PySpark to mark Python UDFs that declared returnType=StringType() but actually returned a datetime.date or datetime.datetime as unconvertible, i.e. converting it to null. Also added a new unit test to pyspark/sql/tests.py to reflect current semantics of Python UDFs returning a value of mismatched type with the declared returnType.
1 parent 186bf8f commit d307cee

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

python/pyspark/sql/tests.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,79 @@ def test_udf_with_array_type(self):
485485
self.assertEqual(list(range(3)), l1)
486486
self.assertEqual(1, l2)
487487

488+
def test_udf_returning_date_time(self):
    """Verify conversion semantics of Python UDFs returning date/datetime.

    A UDF declared with the default returnType=StringType() that actually
    returns a datetime.date or datetime.datetime is treated as
    unconvertible, so its result is null. A UDF whose declared returnType
    matches the runtime value (DateType / 'timestamp') converts normally,
    including timezone-aware datetimes (which come back naive, in the
    local timezone).
    """
    # NOTE(review): `lit` was used below but not imported in the original
    # local import; importing it here makes the test self-contained
    # regardless of the module-level imports (not visible in this chunk).
    from pyspark.sql.functions import udf, lit
    from pyspark.sql.types import DateType

    data = self.spark.createDataFrame([(2017, 10, 30)], ['year', 'month', 'day'])

    expected_date = datetime.date(2017, 10, 30)
    expected_datetime = datetime.datetime(2017, 10, 30)

    # test Python UDF with default returnType=StringType()
    # Returning a date or datetime object at runtime with such returnType declaration
    # is a mismatch, which results in a null, as PySpark treats it as unconvertible.
    py_date_str, py_datetime_str = udf(datetime.date), udf(datetime.datetime)
    query = data.select(
        py_date_str(data.year, data.month, data.day).isNull(),
        py_datetime_str(data.year, data.month, data.day).isNull())
    [row] = query.collect()
    self.assertEqual(row[0], True)
    self.assertEqual(row[1], True)

    # Selecting the raw UDF column (not isNull) must also yield None.
    query = data.select(
        py_date_str(data.year, data.month, data.day),
        py_datetime_str(data.year, data.month, data.day))
    [row] = query.collect()
    self.assertEqual(row[0], None)
    self.assertEqual(row[1], None)

    # test Python UDF with specific returnType matching actual result
    py_date, py_datetime = udf(datetime.date, DateType()), udf(datetime.datetime, 'timestamp')
    query = data.select(
        py_date(data.year, data.month, data.day) == lit(expected_date),
        py_datetime(data.year, data.month, data.day) == lit(expected_datetime))
    [row] = query.collect()
    self.assertEqual(row[0], True)
    self.assertEqual(row[1], True)

    query = data.select(
        py_date(data.year, data.month, data.day),
        py_datetime(data.year, data.month, data.day))
    [row] = query.collect()
    self.assertEqual(row[0], expected_date)
    self.assertEqual(row[1], expected_datetime)

    # test semantic matching of datetime with timezone
    # class in __main__ is not serializable, so import the module-level one
    from pyspark.sql.tests import UTCOffsetTimezone
    datetime_with_utc0 = datetime.datetime(2017, 10, 30, tzinfo=UTCOffsetTimezone(0))
    datetime_with_utc1 = datetime.datetime(2017, 10, 30, tzinfo=UTCOffsetTimezone(1))
    test_udf = udf(lambda: datetime_with_utc0, 'timestamp')
    # Same wall-clock time at UTC+0 is a later instant than at UTC+1,
    # hence the `>` comparison below.
    query = data.select(
        test_udf() == lit(datetime_with_utc0),
        test_udf() > lit(datetime_with_utc1),
        test_udf()
    )
    [row] = query.collect()
    self.assertEqual(row[0], True)
    self.assertEqual(row[1], True)
    # Note: datetime returned from PySpark is always naive (timezone unaware).
    # It currently respects Python's current local timezone.
    self.assertEqual(row[2].tzinfo, None)

    # tzinfo=None is really the same as not specifying it: a naive datetime object
    # Just adding a test case for it here for completeness
    datetime_with_null_timezone = datetime.datetime(2017, 10, 30, tzinfo=None)
    test_udf = udf(lambda: datetime_with_null_timezone, 'timestamp')
    query = data.select(
        test_udf() == lit(datetime_with_null_timezone),
        test_udf()
    )
    [row] = query.collect()
    self.assertEqual(row[0], True)
    self.assertEqual(row[1], datetime_with_null_timezone)
560+
488561
def test_broadcast_in_udf(self):
489562
bar = {"a": "aa", "b": "bb", "c": "abc"}
490563
foo = self.sc.broadcast(bar)

sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
1919

2020
import java.io.OutputStream
2121
import java.nio.charset.StandardCharsets
22+
import java.util.Calendar
2223

2324
import scala.collection.JavaConverters._
2425

@@ -144,6 +145,7 @@ object EvaluatePython {
144145
}
145146

146147
case StringType => (obj: Any) => nullSafeConvert(obj) {
148+
case _: Calendar => null
147149
case _ => UTF8String.fromString(obj.toString)
148150
}
149151

0 commit comments

Comments
 (0)