@@ -485,6 +485,79 @@ def test_udf_with_array_type(self):
485485 self .assertEqual (list (range (3 )), l1 )
486486 self .assertEqual (1 , l2 )
487487
def test_udf_returning_date_time(self):
    # Exercises Python UDFs whose runtime result is a date / datetime object,
    # first with a declared returnType that disagrees with that result and then
    # with one that agrees, and finally with timezone-aware and naive datetimes.
    from pyspark.sql.functions import udf
    from pyspark.sql.types import DateType

    df = self.spark.createDataFrame([(2017, 10, 30)], ['year', 'month', 'day'])
    # The same three columns are fed to every date-building UDF below.
    ymd = (df.year, df.month, df.day)

    want_date = datetime.date(2017, 10, 30)
    want_datetime = datetime.datetime(2017, 10, 30)

    # Case 1: no returnType is passed, so the UDF is declared with the default
    # StringType(). Handing back a date / datetime object under that declaration
    # is a type mismatch, and the value is expected to surface as null.
    date_as_str = udf(datetime.date)
    datetime_as_str = udf(datetime.datetime)

    # 1a: the null is visible from the SQL side through isNull().
    (row,) = df.select(
        date_as_str(*ymd).isNull(),
        datetime_as_str(*ymd).isNull()).collect()
    for idx in (0, 1):
        self.assertEqual(row[idx], True)

    # 1b: the null is visible from the Python side as None after collect().
    (row,) = df.select(
        date_as_str(*ymd),
        datetime_as_str(*ymd)).collect()
    for idx in (0, 1):
        self.assertEqual(row[idx], None)

    # Case 2: the declared returnType agrees with what the function produces.
    typed_date = udf(datetime.date, DateType())
    typed_datetime = udf(datetime.datetime, 'timestamp')

    # 2a: the UDF output compares equal to the matching literal inside the query.
    (row,) = df.select(
        typed_date(*ymd) == lit(want_date),
        typed_datetime(*ymd) == lit(want_datetime)).collect()
    for idx in (0, 1):
        self.assertEqual(row[idx], True)

    # 2b: the collected Python objects equal the expected date / datetime.
    (row,) = df.select(
        typed_date(*ymd),
        typed_datetime(*ymd)).collect()
    self.assertEqual(row[0], want_date)
    self.assertEqual(row[1], want_datetime)

    # Case 3: timezone-aware datetimes are compared by the instant they denote.
    # The tzinfo class is imported from the package module because a class
    # defined in __main__ is not serializable.
    from pyspark.sql.tests import UTCOffsetTimezone
    aware_utc0 = datetime.datetime(2017, 10, 30, tzinfo=UTCOffsetTimezone(0))
    aware_utc1 = datetime.datetime(2017, 10, 30, tzinfo=UTCOffsetTimezone(1))
    aware_udf = udf(lambda: aware_utc0, 'timestamp')
    (row,) = df.select(
        aware_udf() == lit(aware_utc0),
        aware_udf() > lit(aware_utc1),
        aware_udf()).collect()
    self.assertEqual(row[0], True)
    self.assertEqual(row[1], True)
    # A datetime handed back by PySpark is always naive (carries no tzinfo);
    # at present it is expressed in Python's current local timezone.
    self.assertEqual(row[2].tzinfo, None)

    # Case 4: tzinfo=None is just another spelling of a naive datetime;
    # covered here only for completeness.
    naive_value = datetime.datetime(2017, 10, 30, tzinfo=None)
    naive_udf = udf(lambda: naive_value, 'timestamp')
    (row,) = df.select(
        naive_udf() == lit(naive_value),
        naive_udf()).collect()
    self.assertEqual(row[0], True)
    self.assertEqual(row[1], naive_value)
560+
488561 def test_broadcast_in_udf (self ):
489562 bar = {"a" : "aa" , "b" : "bb" , "c" : "abc" }
490563 foo = self .sc .broadcast (bar )
0 commit comments