Skip to content

Commit 11a6e24

Browse files
committed
fix MapInPandas (SPARK-34319: analysis fails on self-join with MapInPandas)
1 parent 0ab2b7a commit 11a6e24

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

python/pyspark/sql/tests/test_pandas_map.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -112,6 +112,14 @@ def func(iterator):
112112
expected = df.collect()
113113
self.assertEqual(actual, expected)
114114

115+
def test_self_join(self):
116+
# SPARK-34319: self-join with MapInPandas
117+
df1 = self.spark.range(10)
118+
df2 = df1.mapInPandas(lambda iter: iter, 'id long')
119+
actual = df2.join(df2).collect()
120+
expected = df1.join(df1).collect()
121+
self.assertEqual(actual, expected)
122+
115123

116124
if __name__ == "__main__":
117125
from pyspark.sql.tests.test_pandas_map import * # noqa: F401

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1406,7 +1406,11 @@ class Analyzer(override val catalogManager: CatalogManager)
14061406
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
14071407

14081408
case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
1409-
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1409+
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1410+
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1411+
1412+
case oldVersion @ MapInPandas(_, output, _)
1413+
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
14101414
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
14111415

14121416
case oldVersion: Generate

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,25 @@ class AnalysisSuite extends AnalysisTest with Matchers {
654654
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
655655
}
656656

657+
test("SPARK-34319: analysis fails on self-join with MapInPandas") {
658+
val pythonUdf = PythonUDF("pyUDF", null,
659+
StructType(Seq(StructField("a", LongType))),
660+
Seq.empty,
661+
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
662+
true)
663+
val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes
664+
val project = Project(Seq(UnresolvedAttribute("a")), testRelation)
665+
val mapInPandas = MapInPandas(
666+
pythonUdf,
667+
output,
668+
project)
669+
val left = SubqueryAlias("temp0", mapInPandas)
670+
val right = SubqueryAlias("temp1", mapInPandas)
671+
val join = Join(left, right, Inner, None, JoinHint.NONE)
672+
assertAnalysisSuccess(
673+
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
674+
}
675+
657676
test("SPARK-24488 Generator with multiple aliases") {
658677
assertAnalysisSuccess(
659678
listRelation.select(Explode($"list").as("first_alias").as("second_alias")))

0 commit comments

Comments
 (0)