Commit b05e630

fix via partitioning restriction

1 parent cde64ad commit b05e630

7 files changed: +135 / -6 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 18 additions & 0 deletions

@@ -171,6 +171,16 @@ sealed trait Partitioning {
   * produced by `A` could have also been produced by `B`.
   */
  def guarantees(other: Partitioning): Boolean = this == other
+
+  /**
+   * Returns the partitioning scheme that is valid under restriction to a given set of output
+   * attributes. If the partitioning is an [[Expression]], then the attributes it depends on
+   * must all be in the outputSet; otherwise those attributes would leak.
+   */
+  def restrict(outputSet: AttributeSet): Partitioning = this match {
+    case p: Expression if !p.references.subsetOf(outputSet) => UnknownPartitioning(numPartitions)
+    case _ => this
+  }
 }

 object Partitioning {
@@ -356,6 +366,14 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
   override def guarantees(other: Partitioning): Boolean =
     partitionings.exists(_.guarantees(other))

+  override def restrict(outputSet: AttributeSet): Partitioning = {
+    partitionings.map(_.restrict(outputSet)).filter(!_.isInstanceOf[UnknownPartitioning]) match {
+      case Nil => UnknownPartitioning(numPartitions)
+      case singlePartitioning :: Nil => singlePartitioning
+      case more => PartitioningCollection(more)
+    }
+  }
+
   override def toString: String = {
     partitionings.map(_.toString).mkString("(", " or ", ")")
   }
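
To see the new contract in isolation: a minimal sketch using the Catalyst classes from this patch (the attribute names are illustrative; this mirrors the new PartitioningSuite test below).

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
import org.apache.spark.sql.types.DataTypes

val a = AttributeReference("a", DataTypes.IntegerType)()
val b = AttributeReference("b", DataTypes.IntegerType)()

// A child operator hash-partitioned by (a, b).
val childPartitioning = HashPartitioning(Seq(a, b), 8)

// A parent that outputs only `a` must not keep claiming HashPartitioning(a, b),
// since the reference to `b` would leak; restrict collapses the claim.
assert(childPartitioning.restrict(AttributeSet(Seq(a))) === UnknownPartitioning(8))

// When every referenced attribute survives, the partitioning is preserved.
assert(childPartitioning.restrict(AttributeSet(Seq(a, b))) === childPartitioning)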

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala

Lines changed: 41 additions & 2 deletions

@@ -18,8 +18,10 @@
 package org.apache.spark.sql.catalyst

 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal}
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}
+
+import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, AttributeSet, InterpretedMutableProjection, Literal, NullsFirst, SortOrder}
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.types.DataTypes

 class PartitioningSuite extends SparkFunSuite {
   test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") {
@@ -52,4 +54,41 @@ class PartitioningSuite extends SparkFunSuite {
     assert(partitioningA.guarantees(partitioningA))
     assert(partitioningA.compatibleWith(partitioningA))
   }
+
+  test("restriction of Partitioning works") {
+    val n = 5
+
+    val a1 = AttributeReference("a1", DataTypes.IntegerType)()
+    val a2 = AttributeReference("a2", DataTypes.IntegerType)()
+    val a3 = AttributeReference("a3", DataTypes.IntegerType)()
+
+    val hashPartitioning = HashPartitioning(Seq(a1, a2), n)
+
+    assert(hashPartitioning.restrict(AttributeSet(Seq())) === UnknownPartitioning(n))
+    assert(hashPartitioning.restrict(AttributeSet(Seq(a1))) === UnknownPartitioning(n))
+    assert(hashPartitioning.restrict(AttributeSet(Seq(a1, a2))) === hashPartitioning)
+    assert(hashPartitioning.restrict(AttributeSet(Seq(a1, a2, a3))) === hashPartitioning)
+
+    val so1 = SortOrder(a1, Ascending)
+    val so2 = SortOrder(a2, Ascending)
+
+    val rangePartitioning1 = RangePartitioning(Seq(so1), n)
+    val rangePartitioning2 = RangePartitioning(Seq(so1, so2), n)
+
+    assert(rangePartitioning2.restrict(AttributeSet(Seq())) === UnknownPartitioning(n))
+    assert(rangePartitioning2.restrict(AttributeSet(Seq(a1))) === UnknownPartitioning(n))
+    assert(rangePartitioning2.restrict(AttributeSet(Seq(a1, a2))) === rangePartitioning2)
+    assert(rangePartitioning2.restrict(AttributeSet(Seq(a1, a2, a3))) === rangePartitioning2)
+
+    assert(SinglePartition.restrict(AttributeSet(a1)) === SinglePartition)
+
+    val all = Seq(hashPartitioning, rangePartitioning1, rangePartitioning2)
+    val partitioningCollection = PartitioningCollection(all)
+
+    assert(partitioningCollection.restrict(AttributeSet(Seq())) === UnknownPartitioning(n))
+    assert(partitioningCollection.restrict(AttributeSet(Seq(a1))) === rangePartitioning1)
+    assert(partitioningCollection.restrict(AttributeSet(Seq(a1, a2))) === partitioningCollection)
+    assert(partitioningCollection.restrict(AttributeSet(Seq(a1, a2, a3))) === partitioningCollection)
+
+  }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 4 additions & 0 deletions

@@ -65,6 +65,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     false
   }

+  override def verboseStringWithSuffix: String = {
+    s"$verboseString $outputPartitioning"
+  }
+
   /** Overridden make copy also propagates sqlContext to copied plan. */
   override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
     SparkSession.setActiveSession(sqlContext.sparkSession)
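
With this suffix in place (plus the addSuffix plumbing in the next two files), a physical plan can be dumped with each node's outputPartitioning appended to its line. A debugging sketch, assuming any DataFrame `df` in scope; this is the same call that appears commented out in the JoinSuite test below.

// Print the physical plan tree with each node's outputPartitioning as a suffix.
val plan = df.queryExecution.executedPlan
println(plan.treeString(verbose = true, addSuffix = true))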

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 2 additions & 2 deletions

@@ -273,7 +273,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
       verbose: Boolean,
       prefix: String = "",
       addSuffix: Boolean = false): StringBuilder = {
-    child.generateTreeString(depth, lastChildren, builder, verbose, "")
+    child.generateTreeString(depth, lastChildren, builder, verbose, "", addSuffix)
   }
 }

@@ -448,7 +448,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
       verbose: Boolean,
       prefix: String = "",
       addSuffix: Boolean = false): StringBuilder = {
-    child.generateTreeString(depth, lastChildren, builder, verbose, "*")
+    child.generateTreeString(depth, lastChildren, builder, verbose, "*", addSuffix)
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 1 addition & 1 deletion

@@ -64,7 +64,7 @@ case class HashAggregateExec(

   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputPartitioning: Partitioning = child.outputPartitioning.restrict(outputSet)

   override def producedAttributes: AttributeSet =
     AttributeSet(aggregateAttributes) ++
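
The aggregate's output is defined by resultExpressions, so grouping attributes can vanish from its output while the child exchange is still hash-partitioned on them. A hedged sketch of that scenario, assuming a SparkSession `spark` and a registered view `t(a, b)` (both hypothetical here):

// The child exchange is hash-partitioned on the grouping keys (a, b), but the
// aggregate outputs only `cnt`; with this fix, the inherited partitioning is
// restricted to UnknownPartitioning rather than leaking references to a and b.
val agg = spark.sql("SELECT count(*) AS cnt FROM t GROUP BY a, b")
println(agg.queryExecution.executedPlan.outputPartitioning)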

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 1 addition & 1 deletion

@@ -80,7 +80,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)

   override def outputOrdering: Seq[SortOrder] = child.outputOrdering

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputPartitioning: Partitioning = child.outputPartitioning.restrict(outputSet)
 }
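
The same restriction applies to projections that drop a column the child's partitioning references. A hedged DataFrame-level sketch (assumes a SparkSession `spark`; the column names are illustrative):

import spark.implicits._

// Hash-partition by (a, b), then project `b` away.
val df = Seq((1, 10), (2, 20)).toDF("a", "b").repartition(4, $"a", $"b")
val projected = df.select($"a")

// Without restrict, ProjectExec would inherit HashPartitioning(a, b) even though
// `b` is no longer in its output; with this change it reports UnknownPartitioning(4).
println(projected.queryExecution.executedPlan.outputPartitioning)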

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 68 additions & 0 deletions

@@ -26,6 +26,10 @@ import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
+import org.apache.spark.sql.execution.WholeStageCodegenExec
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.exchange.Exchange

 class JoinSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -36,6 +40,70 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     df.queryExecution.optimizedPlan.stats.sizeInBytes
   }

+  test("SPARK-16683 Repeated joins to same table can leak attributes via partitioning") {
+    val hier = sqlContext.sparkSession.sparkContext.parallelize(Seq(
+      ("A10", "A1"),
+      ("A11", "A1"),
+      ("A20", "A2"),
+      ("A21", "A2"),
+      ("B10", "B1"),
+      ("B11", "B1"),
+      ("B20", "B2"),
+      ("B21", "B2"),
+      ("A1", "A"),
+      ("A2", "A"),
+      ("B1", "B"),
+      ("B2", "B")
+    )).toDF("son", "parent").cache() // passes if the cache is removed but a count is put on dist1
+    hier.createOrReplaceTempView("hier")
+    hier.count() // if this count is removed, the test passes
+
+    val base = sqlContext.sparkSession.sparkContext.parallelize(Seq(
+      Tuple1("A10"),
+      Tuple1("A11"),
+      Tuple1("A20"),
+      Tuple1("A21"),
+      Tuple1("B10"),
+      Tuple1("B11"),
+      Tuple1("B20"),
+      Tuple1("B21")
+    )).toDF("id")
+    base.createOrReplaceTempView("base")
+
+    val dist1 = spark.sql("""
+      SELECT parent level1
+      FROM base INNER JOIN hier h1 ON base.id = h1.son
+      GROUP BY parent""")
+
+    dist1.createOrReplaceTempView("dist1")
+    // dist1.count() // or put a count here
+
+    val dist2 = spark.sql("""
+      SELECT parent level2
+      FROM dist1 INNER JOIN hier h2 ON dist1.level1 = h2.son
+      GROUP BY parent""")
+
+    val plan = dist2.queryExecution.executedPlan
+    // For debugging, print the tree string with the partitioning suffix:
+    // println(plan.treeString(verbose = true, addSuffix = true))

+    dist2.createOrReplaceTempView("dist2")
+    checkAnswer(dist2, Row("A") :: Row("B") :: Nil)
+
+    assert(plan.isInstanceOf[WholeStageCodegenExec])
+    assert(plan.outputPartitioning === UnknownPartitioning(5))
+
+    val agg = plan.children.head
+
+    assert(agg.isInstanceOf[HashAggregateExec])
+    assert(agg.outputPartitioning === UnknownPartitioning(5))
+
+    // Skip the InputAdapter
+    val exchange = agg.children.head.children.head
+    assert(exchange.isInstanceOf[Exchange])
+    assert(exchange.outputPartitioning.isInstanceOf[HashPartitioning])
+  }
+
   test("equi-join is hash-join") {
     val x = testData2.as("x")
     val y = testData2.as("y")
