
Commit 1f4e4c8

somani and wangyum authored and committed
[SPARK-32268][SQL] Row-level Runtime Filtering
### What changes were proposed in this pull request?

This PR proposes row-level runtime filters in Spark to reduce intermediate data volume for operators like shuffle, join and aggregate, and hence improve performance. We propose two mechanisms to do this: semi-join filters or bloom filters, and both mechanisms are proposed to co-exist side-by-side behind feature configs.

[Design Doc](https://docs.google.com/document/d/16IEuyLeQlubQkH8YuVuXWKo2-grVIoDJqQpHZrE7q04/edit?usp=sharing) with more details.

### Why are the changes needed?

With semi-join, we see 9 queries improve for the TPC-DS 3TB benchmark, and no regressions. With Bloom filter, we see 10 queries improve for the TPC-DS 3TB benchmark, and no regressions.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added tests.

Closes apache#35789 from somani/rf.

Lead-authored-by: Abhishek Somani <[email protected]>
Co-authored-by: Abhishek Somani <[email protected]>
Co-authored-by: Yuming Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 4e60638 commit 1f4e4c8
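
For intuition, the Bloom filter mechanism can be hand-rolled with existing public APIs. The sketch below is illustrative only — the PR performs this rewrite automatically as an optimizer rule on the logical plan. The `spark` session and the `fact`/`dim` DataFrames with a Long join key `k` are assumed names:

    import org.apache.spark.sql.functions.udf
    import org.apache.spark.util.sketch.BloomFilter
    import spark.implicits._

    // Build a filter over the join keys of the small, selectively filtered side.
    val dimKeys = dim.filter($"d" < 10)
    val bf: BloomFilter = dimKeys.stat.bloomFilter("k", 1000000L, 0.03)

    // Probe it on the large side before the shuffle: rows whose key is definitely
    // absent from `dim` are dropped early and never reach the join.
    val bfRef = spark.sparkContext.broadcast(bf)
    val mightMatch = udf((k: Long) => bfRef.value.mightContainLong(k))
    val joined = fact.filter(mightMatch($"k")).join(dimKeys, "k")

The semi-join variant conceptually replaces the Bloom filter probe with a left-semi join of `fact` against the distinct filtered dimension keys.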

File tree

14 files changed: +1432 −16 lines


common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java

Lines changed: 7 additions & 0 deletions
@@ -163,6 +163,13 @@ int getVersionNumber() {
    */
   public abstract void writeTo(OutputStream out) throws IOException;
 
+  /**
+   * @return the number of set bits in this {@link BloomFilter}.
+   */
+  public long cardinality() {
+    throw new UnsupportedOperationException("Not implemented");
+  }
+
   /**
    * Reads in a {@link BloomFilter} from an input stream. It is the caller's responsibility to close
    * the stream.
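
The new `cardinality()` method reports how many bits are set; the aggregate further down uses it to detect an empty filter. A minimal sketch against the sketch library's existing API (the item count is an illustrative choice):

    import org.apache.spark.util.sketch.BloomFilter

    val bf = BloomFilter.create(1000)   // expect up to ~1000 distinct items
    assert(bf.cardinality() == 0L)      // nothing inserted yet: no bits set
    bf.putLong(42L)
    assert(bf.cardinality() > 0L)       // each insert sets up to k hash-function bits
    assert(bf.mightContainLong(42L))    // inserted values always "might" be present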

common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java

Lines changed: 5 additions & 0 deletions
@@ -207,6 +207,11 @@ public BloomFilter intersectInPlace(BloomFilter other) throws IncompatibleMergeE
     return this;
   }
 
+  @Override
+  public long cardinality() {
+    return this.bits.cardinality();
+  }
+
   private BloomFilterImpl checkCompatibilityForMerge(BloomFilter other)
       throws IncompatibleMergeException {
     // Duplicates the logic of `isCompatible` here to provide better error message.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import java.io.ByteArrayInputStream

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE
import org.apache.spark.sql.types._
import org.apache.spark.util.sketch.BloomFilter

/**
 * An internal scalar function that returns the membership check result (either true or false)
 * for values of `valueExpression` in the Bloom filter represented by `bloomFilterExpression`.
 * Note that since the function is "might contain", always returning true regardless is not
 * wrong.
 * Note that this expression requires that `bloomFilterExpression` is either a constant value or
 * an uncorrelated scalar subquery. This is sufficient for the Bloom filter join rewrite.
 *
 * @param bloomFilterExpression the Binary data of Bloom filter.
 * @param valueExpression the Long value to be tested for the membership of `bloomFilterExpression`.
 */
case class BloomFilterMightContain(
    bloomFilterExpression: Expression,
    valueExpression: Expression) extends BinaryExpression {

  override def nullable: Boolean = true
  override def left: Expression = bloomFilterExpression
  override def right: Expression = valueExpression
  override def prettyName: String = "might_contain"
  override def dataType: DataType = BooleanType

  override def checkInputDataTypes(): TypeCheckResult = {
    (left.dataType, right.dataType) match {
      case (BinaryType, NullType) | (NullType, LongType) | (NullType, NullType) |
           (BinaryType, LongType) =>
        bloomFilterExpression match {
          case e: Expression if e.foldable => TypeCheckResult.TypeCheckSuccess
          case subquery: PlanExpression[_] if !subquery.containsPattern(OUTER_REFERENCE) =>
            TypeCheckResult.TypeCheckSuccess
          case _ =>
            TypeCheckResult.TypeCheckFailure(s"The Bloom filter binary input to $prettyName " +
              "should be either a constant value or a scalar subquery expression")
        }
      case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
        s"been ${BinaryType.simpleString} followed by a value with ${LongType.simpleString}, " +
        s"but it's [${left.dataType.catalogString}, ${right.dataType.catalogString}].")
    }
  }

  override protected def withNewChildrenInternal(
      newBloomFilterExpression: Expression,
      newValueExpression: Expression): BloomFilterMightContain =
    copy(bloomFilterExpression = newBloomFilterExpression,
      valueExpression = newValueExpression)

  // The bloom filter created from `bloomFilterExpression`.
  @transient private lazy val bloomFilter = {
    val bytes = bloomFilterExpression.eval().asInstanceOf[Array[Byte]]
    if (bytes == null) null else deserialize(bytes)
  }

  override def eval(input: InternalRow): Any = {
    if (bloomFilter == null) {
      null
    } else {
      val value = valueExpression.eval(input)
      if (value == null) null else bloomFilter.mightContainLong(value.asInstanceOf[Long])
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    if (bloomFilter == null) {
      ev.copy(isNull = TrueLiteral, value = JavaCode.defaultLiteral(dataType))
    } else {
      val bf = ctx.addReferenceObj("bloomFilter", bloomFilter, classOf[BloomFilter].getName)
      val valueEval = valueExpression.genCode(ctx)
      ev.copy(code = code"""
      ${valueEval.code}
      boolean ${ev.isNull} = ${valueEval.isNull};
      ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
      if (!${ev.isNull}) {
        ${ev.value} = $bf.mightContainLong((Long)${valueEval.value});
      }""")
    }
  }

  final def deserialize(bytes: Array[Byte]): BloomFilter = {
    val in = new ByteArrayInputStream(bytes)
    val bloomFilter = BloomFilter.readFrom(in)
    in.close()
    bloomFilter
  }
}
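
A minimal sketch of the expression's null semantics, assuming a byte array produced by `BloomFilter.writeTo` (mirroring what `BloomFilterAggregate` below emits); the names and values are illustrative:

    import java.io.ByteArrayOutputStream
    import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, Literal}
    import org.apache.spark.sql.types.BinaryType
    import org.apache.spark.util.sketch.BloomFilter

    val bf = BloomFilter.create(100)
    bf.putLong(42L)
    val out = new ByteArrayOutputStream()
    bf.writeTo(out)
    val bytes = out.toByteArray

    BloomFilterMightContain(Literal(bytes), Literal(42L)).eval(null)  // true
    BloomFilterMightContain(Literal(bytes), Literal(7L)).eval(null)   // false, or true on a false positive
    BloomFilterMightContain(Literal(null, BinaryType), Literal(42L)).eval(null)  // null filter => null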
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions.aggregate

import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TernaryLike
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.sketch.BloomFilter

/**
 * An internal aggregate function that creates a Bloom filter from input values.
 *
 * @param child Child expression of Long values for creating a Bloom filter.
 * @param estimatedNumItemsExpression The number of estimated distinct items (optional).
 * @param numBitsExpression The number of bits to use (optional).
 */
case class BloomFilterAggregate(
    child: Expression,
    estimatedNumItemsExpression: Expression,
    numBitsExpression: Expression,
    override val mutableAggBufferOffset: Int,
    override val inputAggBufferOffset: Int)
  extends TypedImperativeAggregate[BloomFilter] with TernaryLike[Expression] {

  def this(child: Expression, estimatedNumItemsExpression: Expression,
      numBitsExpression: Expression) = {
    this(child, estimatedNumItemsExpression, numBitsExpression, 0, 0)
  }

  def this(child: Expression, estimatedNumItemsExpression: Expression) = {
    this(child, estimatedNumItemsExpression,
      // 1 byte per item.
      Multiply(estimatedNumItemsExpression, Literal(8L)))
  }

  def this(child: Expression) = {
    this(child, Literal(SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_EXPECTED_NUM_ITEMS)),
      Literal(SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_NUM_BITS)))
  }

  override def checkInputDataTypes(): TypeCheckResult = {
    (first.dataType, second.dataType, third.dataType) match {
      case (_, NullType, _) | (_, _, NullType) =>
        TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as size arguments")
      case (LongType, LongType, LongType) =>
        if (!estimatedNumItemsExpression.foldable) {
          TypeCheckFailure("The estimated number of items provided must be a constant literal")
        } else if (estimatedNumItems <= 0L) {
          TypeCheckFailure("The estimated number of items must be a positive value " +
            s"(current value = $estimatedNumItems)")
        } else if (!numBitsExpression.foldable) {
          TypeCheckFailure("The number of bits provided must be a constant literal")
        } else if (numBits <= 0L) {
          TypeCheckFailure("The number of bits must be a positive value " +
            s"(current value = $numBits)")
        } else {
          require(estimatedNumItems <=
            SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))
          require(numBits <= SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))
          TypeCheckSuccess
        }
      case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
        s"been a ${LongType.simpleString} value followed by two ${LongType.simpleString} size " +
        s"arguments, but it's [${first.dataType.catalogString}, " +
        s"${second.dataType.catalogString}, ${third.dataType.catalogString}]")
    }
  }

  override def nullable: Boolean = true

  override def dataType: DataType = BinaryType

  override def prettyName: String = "bloom_filter_agg"

  // Mark as lazy so that `estimatedNumItems` is not evaluated during tree transformation.
  private lazy val estimatedNumItems: Long =
    Math.min(estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue,
      SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))

  // Mark as lazy so that `numBits` is not evaluated during tree transformation.
  private lazy val numBits: Long =
    Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue,
      SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))

  override def first: Expression = child

  override def second: Expression = estimatedNumItemsExpression

  override def third: Expression = numBitsExpression

  override protected def withNewChildrenInternal(
      newChild: Expression,
      newEstimatedNumItemsExpression: Expression,
      newNumBitsExpression: Expression): BloomFilterAggregate = {
    copy(child = newChild, estimatedNumItemsExpression = newEstimatedNumItemsExpression,
      numBitsExpression = newNumBitsExpression)
  }

  override def createAggregationBuffer(): BloomFilter = {
    BloomFilter.create(estimatedNumItems, numBits)
  }

  override def update(buffer: BloomFilter, inputRow: InternalRow): BloomFilter = {
    val value = child.eval(inputRow)
    // Ignore null values.
    if (value == null) {
      return buffer
    }
    buffer.putLong(value.asInstanceOf[Long])
    buffer
  }

  override def merge(buffer: BloomFilter, other: BloomFilter): BloomFilter = {
    buffer.mergeInPlace(other)
  }

  override def eval(buffer: BloomFilter): Any = {
    if (buffer.cardinality() == 0) {
      // No bit is set in the Bloom filter, so no non-null value was processed.
      return null
    }
    serialize(buffer)
  }

  override def withNewMutableAggBufferOffset(newOffset: Int): BloomFilterAggregate =
    copy(mutableAggBufferOffset = newOffset)

  override def withNewInputAggBufferOffset(newOffset: Int): BloomFilterAggregate =
    copy(inputAggBufferOffset = newOffset)

  override def serialize(obj: BloomFilter): Array[Byte] = {
    BloomFilterAggregate.serialize(obj)
  }

  override def deserialize(bytes: Array[Byte]): BloomFilter = {
    BloomFilterAggregate.deserialize(bytes)
  }
}

object BloomFilterAggregate {
  final def serialize(obj: BloomFilter): Array[Byte] = {
    // BloomFilterImpl.writeTo() writes 2 integers (version number and num hash functions), hence
    // the +8.
    val size = (obj.bitSize() / 8) + 8
    require(size <= Integer.MAX_VALUE, s"actual number of bits is too large $size")
    val out = new ByteArrayOutputStream(size.intValue())
    obj.writeTo(out)
    out.close()
    out.toByteArray
  }

  final def deserialize(bytes: Array[Byte]): BloomFilter = {
    val in = new ByteArrayInputStream(bytes)
    val bloomFilter = BloomFilter.readFrom(in)
    in.close()
    bloomFilter
  }
}
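
For intuition on the defaults: the two-argument constructor sizes the filter at `estimatedNumItems * 8` bits, i.e. 1 byte per expected item, so an estimate of 1,000,000 items yields 8,000,000 bits ≈ 1 MB, plus the 8-byte header (two 4-byte integers: version and number of hash functions) accounted for in `serialize`. A minimal sketch of driving the aggregate by hand, with the input bound as a single Long column (test-style usage; outside a session the one-argument constructor falls back to SQLConf defaults):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.BoundReference
    import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
    import org.apache.spark.sql.types.LongType

    val agg = new BloomFilterAggregate(BoundReference(0, LongType, nullable = true))
    var buf = agg.createAggregationBuffer()
    buf = agg.update(buf, InternalRow(1L))
    buf = agg.update(buf, InternalRow(null))             // null inputs are ignored
    val bytes = agg.eval(buf).asInstanceOf[Array[Byte]]  // serialized filter; null if no bits set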

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 2 additions & 0 deletions
@@ -360,6 +360,8 @@ case class Invoke(
 
   lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(INVOKE)
+
   override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
   override def children: Seq[Expression] = targetObject +: arguments
   override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic)
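
This `nodePatterns` addition (like the ones in regexpExpressions.scala below) lets transformation rules prune entire subtrees via precomputed pattern bits instead of traversing every node. A small sketch of the consumer side, assuming some plan or expression tree `expr`:

    import org.apache.spark.sql.catalyst.trees.TreePattern.INVOKE

    // Only descend into the tree if it can contain an Invoke node at all.
    if (expr.containsPattern(INVOKE)) {
      // ... run the Invoke-specific rewrite ...
    }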

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 16 additions & 0 deletions
@@ -287,6 +287,22 @@ trait PredicateHelper extends AliasHelper with Logging {
       }
     }
   }
+
+  /**
+   * Returns whether an expression is likely to be selective.
+   */
+  def isLikelySelective(e: Expression): Boolean = e match {
+    case Not(expr) => isLikelySelective(expr)
+    case And(l, r) => isLikelySelective(l) || isLikelySelective(r)
+    case Or(l, r) => isLikelySelective(l) && isLikelySelective(r)
+    case _: StringRegexExpression => true
+    case _: BinaryComparison => true
+    case _: In | _: InSet => true
+    case _: StringPredicate => true
+    case BinaryPredicate(_) => true
+    case _: MultiLikeBase => true
+    case _ => false
+  }
 }
 
 @ExpressionDescription(
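
A few concrete cases following the logic above (`And` needs only one selective side, `Or` needs both); a minimal sketch, assuming we are inside a class mixing in `PredicateHelper`:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.LongType

    val a = AttributeReference("a", LongType)()
    isLikelySelective(EqualTo(a, Literal(1L)))                     // true: a BinaryComparison
    isLikelySelective(IsNotNull(a))                                // false: unlikely to filter much
    isLikelySelective(And(EqualTo(a, Literal(1L)), IsNotNull(a)))  // true: one selective conjunct suffices
    isLikelySelective(Or(EqualTo(a, Literal(1L)), IsNotNull(a)))   // false: every disjunct must be selective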

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala

Lines changed: 4 additions & 1 deletion
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure,
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.BinaryLike
-import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern}
 import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._

@@ -627,6 +627,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
   @transient private var lastReplacementInUTF8: UTF8String = _
   // result buffer write by Matcher
   @transient private lazy val result: StringBuffer = new StringBuffer
+  final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_REPLACE)
 
   override def nullSafeEval(s: Any, p: Any, r: Any, i: Any): Any = {
     if (!p.equals(lastRegex)) {

@@ -751,6 +752,8 @@ abstract class RegExpExtractBase
   // last regex pattern, we cache it for performance concern
   @transient private var pattern: Pattern = _
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY)
+
   override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
   override def first: Expression = subject
   override def second: Expression = regexp
