Skip to content

Commit 344b516

Browse files
committed
Add test for PullOutPythonUDFInJoinCondition
1 parent 3ed91c9 commit 344b516

File tree

1 file changed

+128
-0
lines changed

1 file changed

+128
-0
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.scalatest.Matchers._
21+
22+
import org.apache.spark.api.python.PythonEvalType
23+
import org.apache.spark.sql.AnalysisException
24+
import org.apache.spark.sql.catalyst.dsl.expressions._
25+
import org.apache.spark.sql.catalyst.dsl.plans._
26+
import org.apache.spark.sql.catalyst.expressions.PythonUDF
27+
import org.apache.spark.sql.catalyst.plans._
28+
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
29+
import org.apache.spark.sql.catalyst.rules.RuleExecutor
30+
import org.apache.spark.sql.internal.SQLConf._
31+
import org.apache.spark.sql.types.BooleanType
32+
33+
/**
 * Tests for the [[PullOutPythonUDFInJoinCondition]] rule, which extracts Python UDFs
 * out of join conditions (they cannot be evaluated inside the join operator) and
 * re-applies them as a Filter above the join.
 */
class PullOutPythonUDFInJoinConditionSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Extract PythonUDF From JoinCondition", Once,
        PullOutPythonUDFInJoinCondition) ::
      Batch("Check Cartesian Products", Once,
        CheckCartesianProducts) :: Nil
  }

  val testRelationLeft = LocalRelation('a.int, 'b.int)
  val testRelationRight = LocalRelation('c.int, 'd.int)

  // Dummy python UDF for testing. Unable to execute.
  val pythonUDF = PythonUDF("pythonUDF", null,
    BooleanType,
    Seq.empty,
    PythonEvalType.SQL_BATCHED_UDF,
    udfDeterministic = true)

  val notSupportJoinTypes = Seq(LeftOuter, RightOuter, FullOuter, LeftAnti)

  /**
   * Shared check for queries whose join condition becomes empty after the UDF is
   * pulled out: optimizing must fail the cartesian-product check while
   * spark.sql.crossJoin.enabled=false, and must produce `expected` once it is enabled.
   */
  private def assertCrossJoinGuardedResult(
      analyzedQuery: LogicalPlan,
      expected: LogicalPlan): Unit = {
    // AnalysisException thrown by CheckCartesianProducts while spark.sql.crossJoin.enabled=false
    val exception = the [AnalysisException] thrownBy {
      Optimize.execute(analyzedQuery)
    }
    assert(exception.message.startsWith("Detected implicit cartesian product"))

    // pull out the python udf while set spark.sql.crossJoin.enabled=true
    withSQLConf(CROSS_JOINS_ENABLED.key -> "true") {
      comparePlans(Optimize.execute(analyzedQuery), expected)
    }
  }

  test("inner join condition with python udf only") {
    // Condition is the UDF alone, so pulling it out leaves a conditionless inner join.
    val query = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = Some(pythonUDF))
    val expected = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = None).where(pythonUDF).analyze
    assertCrossJoinGuardedResult(query.analyze, expected)
  }

  test("left semi join condition with python udf only") {
    // LeftSemi is rewritten to Inner + Filter + Project back to the left-side columns.
    val query = testRelationLeft.join(
      testRelationRight,
      joinType = LeftSemi,
      condition = Some(pythonUDF))
    val expected = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = None).where(pythonUDF).select('a, 'b).analyze
    assertCrossJoinGuardedResult(query.analyze, expected)
  }

  test("python udf with other common condition") {
    // Only the UDF moves to the Filter; the evaluable predicate stays in the join,
    // so no cartesian product is introduced and no config toggle is needed.
    val query = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = Some(pythonUDF && 'a.attr === 'c.attr))
    val expected = testRelationLeft.join(
      testRelationRight,
      joinType = Inner,
      condition = Some('a.attr === 'c.attr)).where(pythonUDF).analyze
    comparePlans(Optimize.execute(query.analyze), expected)
  }

  test("throw an exception for not support join type") {
    // Outer/anti joins cannot have the UDF pulled into a Filter without changing
    // semantics, so the rule must reject them.
    notSupportJoinTypes.foreach { joinType =>
      val thrownException = the [AnalysisException] thrownBy {
        val query = testRelationLeft.join(
          testRelationRight,
          joinType,
          condition = Some(pythonUDF))
        Optimize.execute(query.analyze)
      }
      assert(thrownException.message.contentEquals(
        s"Using PythonUDF in join condition of join type $joinType is not supported."))
    }
  }
}
128+

0 commit comments

Comments
 (0)