Skip to content

Commit a6f37b0

Browse files
committed
[SPARK-25456][SQL][TEST] Fix PythonForeachWriterSuite
PythonForeachWriterSuite was failing because RowQueue now needs to have a handle on a SparkEnv with a SerializerManager, so added a mock env with a serializer manager. Also fixed a typo in the `finally` that was hiding the real exception. Tested PythonForeachWriterSuite locally, full tests via jenkins. Closes #22452 from squito/SPARK-25456. Authored-by: Imran Rashid <[email protected]> Signed-off-by: Imran Rashid <[email protected]>
1 parent 123f004 commit a6f37b0

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,20 @@ package org.apache.spark.sql.execution.python
1919

2020
import scala.collection.mutable.ArrayBuffer
2121

22+
import org.mockito.Mockito.when
2223
import org.scalatest.concurrent.Eventually
24+
import org.scalatest.mockito.MockitoSugar
2325
import org.scalatest.time.SpanSugar._
2426

2527
import org.apache.spark._
2628
import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
29+
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
2730
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
2831
import org.apache.spark.sql.execution.python.PythonForeachWriter.UnsafeRowBuffer
2932
import org.apache.spark.sql.types.{DataType, IntegerType}
3033
import org.apache.spark.util.Utils
3134

32-
class PythonForeachWriterSuite extends SparkFunSuite with Eventually {
35+
class PythonForeachWriterSuite extends SparkFunSuite with Eventually with MockitoSugar {
3336

3437
testWithBuffer("UnsafeRowBuffer: iterator blocks when no data is available") { b =>
3538
b.assertIteratorBlocked()
@@ -75,15 +78,20 @@ class PythonForeachWriterSuite extends SparkFunSuite with Eventually {
7578
tester = new BufferTester(memBytes, sleepPerRowReadMs)
7679
f(tester)
7780
} finally {
78-
if (tester == null) tester.close()
81+
if (tester != null) tester.close()
7982
}
8083
}
8184
}
8285

8386

8487
class BufferTester(memBytes: Long, sleepPerRowReadMs: Int) {
8588
private val buffer = {
86-
val mem = new TestMemoryManager(new SparkConf())
89+
val mockEnv = mock[SparkEnv]
90+
val conf = new SparkConf()
91+
val serializerManager = new SerializerManager(new JavaSerializer(conf), conf, None)
92+
when(mockEnv.serializerManager).thenReturn(serializerManager)
93+
SparkEnv.set(mockEnv)
94+
val mem = new TestMemoryManager(conf)
8795
mem.limit(memBytes)
8896
val taskM = new TaskMemoryManager(mem, 0)
8997
new UnsafeRowBuffer(taskM, Utils.createTempDir(), 1)

0 commit comments

Comments
 (0)