
Commit 9edc5e9

Return UDF from udf.register
1 parent 90d77e9 · commit 9edc5e9

4 files changed: +39 −16 lines

python/pyspark/sql/catalog.py

Lines changed: 7 additions & 3 deletions
@@ -238,22 +238,26 @@ def registerFunction(self, name, f, returnType=StringType()):
         :param f: python function
         :param returnType: a :class:`pyspark.sql.types.DataType` object

-        >>> spark.catalog.registerFunction("stringLengthString", lambda x: len(x))
+        >>> strlen = spark.catalog.registerFunction("stringLengthString", len)
         >>> spark.sql("SELECT stringLengthString('test')").collect()
         [Row(stringLengthString(test)=u'4')]

+        >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
+        [Row(stringLengthString(text)=u'3')]
+
         >>> from pyspark.sql.types import IntegerType
-        >>> spark.catalog.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> strlen = spark.catalog.registerFunction("stringLengthInt", len, IntegerType())
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]

         >>> from pyspark.sql.types import IntegerType
-        >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> strlen = spark.udf.register("stringLengthInt", len, IntegerType())
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
         """
         udf = UserDefinedFunction(f, returnType, name)
         self._jsparkSession.udf().registerPython(name, udf._judf)
+        return udf._wrapped()

     @since(2.0)
     def isCached(self, tableName):
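
With this change, registration hands the UDF back to the caller, so one call serves both SQL and the DataFrame API. A minimal sketch of the new usage, mirroring the updated doctests above (the `strlen` name is illustrative; assumes an active SparkSession bound to `spark`):

    from pyspark.sql.types import IntegerType

    # Register once; the returned wrapper is a plain callable.
    strlen = spark.catalog.registerFunction("strlen", len, IntegerType())

    # Callable from SQL by its registered name...
    spark.sql("SELECT strlen('test')").collect()
    # [Row(strlen(test)=4)]

    # ...and from the DataFrame API via the returned wrapper.
    spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
    # [Row(strlen(text)=3)]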

python/pyspark/sql/context.py

Lines changed: 8 additions & 4 deletions
@@ -185,22 +185,26 @@ def registerFunction(self, name, f, returnType=StringType()):
         :param name: name of the UDF
         :param f: python function
         :param returnType: a :class:`pyspark.sql.types.DataType` object
+        :return a wrapped :class:`UserDefinedFunction`

-        >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
+        >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x))
         >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
         [Row(stringLengthString(test)=u'4')]

+        >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
+        [Row(stringLengthString(text)=u'3')]
+
         >>> from pyspark.sql.types import IntegerType
-        >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> strlen = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]

         >>> from pyspark.sql.types import IntegerType
-        >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> strlen = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
         """
-        self.sparkSession.catalog.registerFunction(name, f, returnType)
+        return self.sparkSession.catalog.registerFunction(name, f, returnType)

     @ignore_unicode_prefix
     @since(2.1)
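
The new `:return` note documents that callers get the wrapped function, which also carries the `func` and `returnType` attributes set by `_wrapped()`. For instance (hypothetical names; assumes a `sqlContext` as in the doctests):

    from pyspark.sql.types import IntegerType

    strlen = sqlContext.registerFunction("strlen", lambda x: len(x), IntegerType())
    strlen.func("test")   # 4 -- the original Python callable, attached by _wrapped()
    strlen.returnType     # the IntegerType passed at registration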

python/pyspark/sql/functions.py

Lines changed: 14 additions & 9 deletions
@@ -1917,6 +1917,19 @@ def __call__(self, *cols):
         sc = SparkContext._active_spark_context
         return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))

+    def _wrapped(self):
+        """
+        Wrap this udf with a function and attach docstring from func
+        """
+        @functools.wraps(self.func)
+        def wrapper(*args):
+            return self(*args)
+
+        wrapper.func = self.func
+        wrapper.returnType = self.returnType
+
+        return wrapper
+

 @since(1.3)
 def udf(f=None, returnType=StringType()):
@@ -1951,15 +1964,7 @@ def udf(f=None, returnType=StringType()):
     """
     def _udf(f, returnType=StringType()):
         udf_obj = UserDefinedFunction(f, returnType)
-
-        @functools.wraps(f)
-        def wrapper(*args):
-            return udf_obj(*args)
-
-        wrapper.func = udf_obj.func
-        wrapper.returnType = udf_obj.returnType
-
-        return wrapper
+        return udf_obj._wrapped()

     # decorator @udf, @udf() or @udf(dataType())
     if f is None or isinstance(f, (str, DataType)):
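
This refactor moves the wrapping logic from the `udf()` helper onto the `UserDefinedFunction` class so the register methods can share it. A standalone sketch of the pattern, with a hypothetical `FakeUDF` stand-in for `UserDefinedFunction` (pure Python, no Spark required):

    import functools

    class FakeUDF(object):
        """Hypothetical stand-in for UserDefinedFunction, for illustration only."""
        def __init__(self, func, returnType):
            self.func = func
            self.returnType = returnType

        def __call__(self, *args):
            # A real UDF would build a Column expression here.
            return self.func(*args)

        def _wrapped(self):
            # Same pattern as the method above: delegate calls to the UDF
            # object while presenting the original function's metadata.
            @functools.wraps(self.func)
            def wrapper(*args):
                return self(*args)
            wrapper.func = self.func
            wrapper.returnType = self.returnType
            return wrapper

    def strlen(s):
        """Return the length of s."""
        return len(s)

    wrapped = FakeUDF(strlen, "string")._wrapped()
    assert wrapped.__doc__ == "Return the length of s."  # docstring carried over
    assert wrapped("test") == 4                           # call goes through the UDF object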

python/pyspark/sql/tests.py

Lines changed: 10 additions & 0 deletions
@@ -436,6 +436,15 @@ def test_udf_with_order_by_and_limit(self):
         res.explain(True)
         self.assertEqual(res.collect(), [Row(id=0, copy=0)])

+    def test_udf_registration_returns_udf(self):
+        df = self.spark.range(10)
+        add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
+
+        self.assertListEqual(
+            df.selectExpr("add_three(id) AS plus_three").collect(),
+            df.select(add_three("id").alias("plus_three")).collect()
+        )
+
     def test_wholefile_json(self):
         people1 = self.spark.read.json("python/test_support/sql/people.json")
         people_array = self.spark.read.json("python/test_support/sql/people_array.json",
@@ -615,6 +624,7 @@ def f(x):
         self.assertEqual(f, f_.func)
         self.assertEqual(return_type, f_.returnType)

+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
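
Because `udf()` now also routes through `_wrapped()`, the decorator forms (`@udf`, `@udf()`, `@udf(dataType())`) keep attaching the original function's metadata. A small illustration with hypothetical names (depending on the Spark version, an active SparkSession may be required before defining the UDF):

    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType

    @udf(IntegerType())
    def add_three(x):
        """Add three to x."""
        return x + 3

    print(add_three.__doc__)     # 'Add three to x.' -- kept by functools.wraps
    print(add_three.returnType)  # the IntegerType instance attached by _wrapped()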
