
Commit 62338ed

ivoson authored and hvanhovell committed
[SPARK-42626][CONNECT] Add Destructive Iterator for SparkResult
### What changes were proposed in this pull request?
Add a destructive iterator to SparkResult and change `Dataset.toLocalIterator` to use the destructive iterator. With the destructive iterator, we will:
1. Close the `ColumnarBatch` once its data has been consumed;
2. Remove the `ColumnarBatch` from `SparkResult.batches`.

### Why are the changes needed?
Instead of keeping everything in memory for the lifetime of the SparkResult object, clean it up as soon as we know we are done with it.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
UT added.

Closes #40610 from ivoson/SPARK-42626.

Authored-by: Tengfei Huang <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
1 parent fc3489d commit 62338ed
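To make the change concrete, here is a minimal, self-contained Scala sketch of the destructive-iterator pattern this commit introduces. It is not the actual `SparkResult` implementation: `ToyResult` and `Batch` are made-up stand-ins for `SparkResult` and the Arrow-backed `ColumnarBatch`, used only to illustrate how each batch is dropped from the index map and closed as soon as its rows have been consumed.

```scala
import scala.collection.mutable

// Stand-in for ColumnarBatch: something that pins memory until it is closed.
final class Batch(val rows: Seq[Long]) extends AutoCloseable {
  var closed = false
  override def close(): Unit = closed = true
}

// Toy version of the destructive-iterator idea: batches are keyed by index and,
// once a batch has been fully consumed, it is removed from the map and closed,
// so memory is released while iterating instead of only when the result is closed.
final class ToyResult(batchData: Seq[Seq[Long]]) extends AutoCloseable {
  private val idxToBatches = mutable.Map.empty[Int, Batch]
  batchData.zipWithIndex.foreach { case (rows, i) => idxToBatches.put(i, new Batch(rows)) }

  def destructiveIterator: Iterator[Long] = new Iterator[Long] {
    private var batchIndex = -1
    private var current: Iterator[Long] = Iterator.empty

    override def hasNext: Boolean = {
      while (!current.hasNext) {
        // Drop and close the batch we just finished before moving to the next one.
        idxToBatches.remove(batchIndex).foreach(_.close())
        batchIndex += 1
        idxToBatches.get(batchIndex) match {
          case Some(batch) => current = batch.rows.iterator
          case None => return false
        }
      }
      true
    }

    override def next(): Long = current.next()
  }

  // Closing the result releases whatever has not been consumed yet.
  override def close(): Unit = idxToBatches.values.foreach(_.close())
}

object ToyResultDemo {
  def main(args: Array[String]): Unit = {
    val result = new ToyResult(Seq(Seq(1L, 2L), Seq(3L), Seq(4L, 5L)))
    try {
      println(result.destructiveIterator.toList) // List(1, 2, 3, 4, 5)
    } finally {
      result.close() // only still-unconsumed batches remain to be closed here
    }
  }
}
```

The real code follows the same shape: the diff below keys batches by index in `idxToBatches`, and the destructive iterator removes and closes the previous entry before moving on to the next batch.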

File tree

3 files changed, +81 -17 lines

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 1 addition & 2 deletions

```diff
@@ -2768,8 +2768,7 @@ class Dataset[T] private[sql] (
    * @since 3.4.0
    */
   def toLocalIterator(): java.util.Iterator[T] = {
-    // TODO make this a destructive iterator.
-    collectResult().iterator
+    collectResult().destructiveIterator
   }

   /**
```
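For illustration only, a hypothetical caller of the updated API (assuming a Spark Connect session named `spark`); with the destructive iterator now behind `toLocalIterator()`, only the batch currently being read stays pinned in client memory:

```scala
// Hypothetical usage sketch: stream a result to the client without retaining all batches.
val it = spark.range(0, 1000000).toLocalIterator()
var sum = 0L
while (it.hasNext) {
  sum += it.next() // consumed batches are removed and closed as iteration advances
}
println(sum)
```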

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala

Lines changed: 30 additions & 13 deletions

```diff
@@ -46,7 +46,8 @@ private[sql] class SparkResult[T](
   private[this] var numRecords: Int = 0
   private[this] var structType: StructType = _
   private[this] var boundEncoder: ExpressionEncoder[T] = _
-  private[this] val batches = mutable.Buffer.empty[ColumnarBatch]
+  private[this] var nextBatchIndex: Int = 0
+  private val idxToBatches = mutable.Map.empty[Int, ColumnarBatch]

   private def createEncoder(schema: StructType): ExpressionEncoder[T] = {
     val agnosticEncoder = if (encoder == UnboundRowEncoder) {
@@ -70,12 +71,12 @@
     val reader = new ArrowStreamReader(ipcStreamBytes.newInput(), allocator)
     try {
       val root = reader.getVectorSchemaRoot
-      if (batches.isEmpty) {
-        if (structType == null) {
-          // If the schema is not available yet, fallback to the schema from Arrow.
-          structType = ArrowUtils.fromArrowSchema(root.getSchema)
-        }
-        // TODO: create encoders that directly operate on arrow vectors.
+      if (structType == null) {
+        // If the schema is not available yet, fallback to the schema from Arrow.
+        structType = ArrowUtils.fromArrowSchema(root.getSchema)
+      }
+      // TODO: create encoders that directly operate on arrow vectors.
+      if (boundEncoder == null) {
         boundEncoder = createEncoder(structType).resolveAndBind(structType.toAttributes)
       }
       while (reader.loadNextBatch()) {
@@ -85,7 +86,8 @@
         val vectors = root.getFieldVectors.asScala
           .map(v => new ArrowColumnVector(transferToNewVector(v)))
           .toArray[ColumnVector]
-        batches += new ColumnarBatch(vectors, rowCount)
+        idxToBatches.put(nextBatchIndex, new ColumnarBatch(vectors, rowCount))
+        nextBatchIndex += 1
         numRecords += rowCount
         if (stopOnFirstNonEmptyResponse) {
           return true
@@ -142,24 +144,39 @@
   /**
    * Returns an iterator over the contents of the result.
    */
-  def iterator: java.util.Iterator[T] with AutoCloseable = {
+  def iterator: java.util.Iterator[T] with AutoCloseable =
+    buildIterator(destructive = false)
+
+  /**
+   * Returns a destructive iterator over the contents of the result.
+   */
+  def destructiveIterator: java.util.Iterator[T] with AutoCloseable =
+    buildIterator(destructive = true)
+
+  private def buildIterator(destructive: Boolean): java.util.Iterator[T] with AutoCloseable = {
     new java.util.Iterator[T] with AutoCloseable {
       private[this] var batchIndex: Int = -1
       private[this] var iterator: java.util.Iterator[InternalRow] = Collections.emptyIterator()
       private[this] var deserializer: Deserializer[T] = _
+
       override def hasNext: Boolean = {
         if (iterator.hasNext) {
           return true
         }
+
         val nextBatchIndex = batchIndex + 1
-        val hasNextBatch = if (nextBatchIndex == batches.size) {
+        if (destructive) {
+          idxToBatches.remove(batchIndex).foreach(_.close())
+        }
+
+        val hasNextBatch = if (!idxToBatches.contains(nextBatchIndex)) {
           processResponses(stopOnFirstNonEmptyResponse = true)
         } else {
           true
         }
         if (hasNextBatch) {
           batchIndex = nextBatchIndex
-          iterator = batches(nextBatchIndex).rowIterator()
+          iterator = idxToBatches(nextBatchIndex).rowIterator()
           if (deserializer == null) {
             deserializer = boundEncoder.createDeserializer()
           }
@@ -182,8 +199,8 @@
    * Close this result, freeing any underlying resources.
    */
   override def close(): Unit = {
-    batches.foreach(_.close())
+    idxToBatches.values.foreach(_.close())
   }

-  override def cleaner: AutoCloseable = AutoCloseables(batches.toSeq)
+  override def cleaner: AutoCloseable = AutoCloseables(idxToBatches.values.toSeq)
 }
```

connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala

Lines changed: 50 additions & 2 deletions

```diff
@@ -21,6 +21,7 @@ import java.nio.file.Files
 import java.util.Properties

 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration._
 import scala.util.{Failure, Success}
@@ -30,21 +31,23 @@ import org.apache.commons.io.FileUtils
 import org.apache.commons.io.output.TeeOutputStream
 import org.apache.commons.lang3.{JavaVersion, SystemUtils}
 import org.scalactic.TolerantNumerics
+import org.scalatest.PrivateMethodTester
 import org.scalatest.concurrent.Eventually._

 import org.apache.spark.{SPARK_VERSION, SparkException}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.connect.client.SparkConnectClient
+import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
 import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils.port
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.ThreadUtils

-class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
+class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {

   // Spark Result
   test("spark result schema") {
@@ -890,6 +893,51 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
     assert(message.contains("PARSE_SYNTAX_ERROR"))
   }

+  test("Dataset result destructive iterator") {
+    // Helper method for accessing the private field `idxToBatches` of SparkResult.
+    val _idxToBatches =
+      PrivateMethod[mutable.Map[Int, ColumnarBatch]](Symbol("idxToBatches"))
+
+    def getColumnarBatches(result: SparkResult[_]): Seq[ColumnarBatch] = {
+      val idxToBatches = result invokePrivate _idxToBatches()
+
+      // Sort by key to get stable results.
+      idxToBatches.toSeq.sortBy(_._1).map(_._2)
+    }
+
+    val df = spark
+      .range(0, 10, 1, 10)
+      .filter("id > 5 and id < 9")
+
+    df.withResult { result =>
+      try {
+        // Build and verify the destructive iterator.
+        val iterator = result.destructiveIterator
+        // batches is empty before traversing the result iterator.
+        assert(getColumnarBatches(result).isEmpty)
+        var previousBatch: ColumnarBatch = null
+        val buffer = mutable.Buffer.empty[Long]
+        while (iterator.hasNext) {
+          // There is always exactly 1 batch, since a columnar batch is removed and closed
+          // after its data has been consumed.
+          val batches = getColumnarBatches(result)
+          assert(batches.size === 1)
+          assert(batches.head != previousBatch)
+          previousBatch = batches.head
+
+          buffer.append(iterator.next())
+        }
+        // Batches should be closed and removed after traversing all the records.
+        assert(getColumnarBatches(result).isEmpty)
+
+        val expectedResult = Seq(6L, 7L, 8L)
+        assert(buffer.size === 3 && expectedResult.forall(buffer.contains))
+      } finally {
+        result.close()
+      }
+    }
+  }
+
   test("SparkSession.createDataFrame - large data set") {
     val threshold = 1024 * 1024
     withSQLConf(SQLConf.LOCAL_RELATION_CACHE_THRESHOLD.key -> threshold.toString) {
```
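The test reaches into `SparkResult`'s private `idxToBatches` field through ScalaTest's `PrivateMethodTester`. For readers unfamiliar with that utility, here is a minimal standalone sketch of the pattern, assuming ScalaTest 3.x; `Counter` and `count` are hypothetical names used purely for illustration.

```scala
import org.scalatest.PrivateMethodTester
import org.scalatest.funsuite.AnyFunSuite

// A class with a private member we want to inspect from a test.
class Counter {
  private def count: Int = 42
}

class PrivateAccessExample extends AnyFunSuite with PrivateMethodTester {
  test("read a private member reflectively") {
    // PrivateMethod is parameterized with the member's result type and its name.
    val _count = PrivateMethod[Int](Symbol("count"))
    // invokePrivate calls the private member on the target instance via reflection.
    assert((new Counter) invokePrivate _count() === 42)
  }
}
```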
