Skip to content

Commit 248e3cc

Browse files
committed
Add support another real case
1 parent bfa6039 commit 248e3cc

File tree

2 files changed

+41
-20
lines changed

2 files changed

+41
-20
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -122,25 +122,33 @@ trait ConstraintHelper {
122122

123123
greaterThans.foreach {
124124
case gt @ GreaterThan(l: Attribute, r: Attribute) =>
125-
inferredConstraints ++= inferInequalityConstraints(greaterThans - gt, r, l, gt)
125+
inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt)
126126
case gt @ GreaterThanOrEqual(l: Attribute, r: Attribute) =>
127-
inferredConstraints ++= inferInequalityConstraints(greaterThans - gt, r, l, gt)
127+
inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt)
128128
case gt @ GreaterThan(l @ Cast(_: Attribute, _, _), r: Attribute) =>
129-
inferredConstraints ++= inferInequalityConstraints(greaterThans - gt, r, l, gt)
129+
inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt)
130130
case gt @ GreaterThanOrEqual(l @ Cast(_: Attribute, _, _), r: Attribute) =>
131-
inferredConstraints ++= inferInequalityConstraints(greaterThans - gt, r, l, gt)
131+
inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt)
132+
case gt @ GreaterThan(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
133+
inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt)
134+
case gt @ GreaterThanOrEqual(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
135+
inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt)
132136
case _ => // No inference
133137
}
134138

135139
lessThans.foreach {
136140
case lt @ LessThan(l: Attribute, r: Attribute) =>
137-
inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt)
141+
inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt)
138142
case lt @ LessThanOrEqual(l: Attribute, r: Attribute) =>
139-
inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt)
143+
inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt)
140144
case lt @ LessThan(l @ Cast(_: Attribute, _, _), r: Attribute) =>
141-
inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt)
145+
inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt)
142146
case lt @ LessThanOrEqual(l @ Cast(_: Attribute, _, _), r: Attribute) =>
143-
inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt)
147+
inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt)
148+
case lt @ LessThan(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
149+
inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt)
150+
case lt @ LessThanOrEqual(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
151+
inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt)
144152
case _ => // No inference
145153
}
146154
(inferredConstraints -- constraints -- greaterThans -- lessThans)
@@ -154,21 +162,16 @@ trait ConstraintHelper {
154162
case e: Expression if e.semanticEquals(source) => destination
155163
})
156164

157-
private def inferInequalityConstraints(
165+
private def replaceInequalityConstraints(
158166
constraints: Set[Expression],
159167
source: Expression,
160168
destination: Expression,
161-
binaryComparison: BinaryComparison): Set[Expression] = constraints.map {
162-
case EqualTo(l, r) if l.semanticEquals(source) =>
163-
binaryComparison.makeCopy(Array(destination, r))
164-
case EqualTo(l, r) if r.semanticEquals(source) =>
165-
binaryComparison.makeCopy(Array(destination, l))
166-
case gt @ GreaterThan(l, r) if l.semanticEquals(source) =>
167-
gt.makeCopy(Array(destination, r))
168-
case lt @ LessThan(l, r) if l.semanticEquals(source) =>
169-
lt.makeCopy(Array(destination, r))
170-
case BinaryComparison(l, r) if l.semanticEquals(source) =>
171-
binaryComparison.makeCopy(Array(destination, r))
169+
op: BinaryComparison): Set[Expression] = (constraints - op).map {
170+
case EqualTo(l, r) if l.semanticEquals(source) => op.makeCopy(Array(destination, r))
171+
case EqualTo(l, r) if r.semanticEquals(source) => op.makeCopy(Array(destination, l))
172+
case gt @ GreaterThan(l, r) if l.semanticEquals(source) => gt.makeCopy(Array(destination, r))
173+
case lt @ LessThan(l, r) if l.semanticEquals(source) => lt.makeCopy(Array(destination, r))
174+
case BinaryComparison(l, r) if l.semanticEquals(source) => op.makeCopy(Array(destination, r))
172175
case other => other
173176
}
174177

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,4 +421,22 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
421421
val optimized = optimizedLeft.join(optimizedRight, Inner, condition)
422422
comparePlans(Optimize.execute(original.analyze), optimized.analyze)
423423
}
424+
425+
test("Constraints inferred from inequality attributes: case4") {
426+
val testRelation1 = LocalRelation('a.long, 'b.long, 'c.long).as("x")
427+
val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int).as("y")
428+
429+
// y.b < 13 inferred from y.b < x.b && x.b <= 13
430+
val left = testRelation1.where('b <= 13L).as("x")
431+
val right = testRelation2.as("y")
432+
433+
val optimizedLeft = testRelation1.where(IsNotNull('a) && IsNotNull('b) && 'b <= 13L).as("x")
434+
val optimizedRight = testRelation2.where(IsNotNull('a) && IsNotNull('b)
435+
&& 'b.cast(LongType) < 13L).as("y")
436+
437+
val condition = Some("x.a".attr === "y.a".attr && "y.b".attr < "x.b".attr)
438+
val original = left.join(right, Inner, condition)
439+
val optimized = optimizedLeft.join(optimizedRight, Inner, condition)
440+
comparePlans(Optimize.execute(original.analyze), optimized.analyze)
441+
}
424442
}

0 commit comments

Comments
 (0)