Skip to content

Commit 424be64

Browse files
committed
PythonRDDSuite fix
1 parent 17d357b commit 424be64

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,18 @@ import org.apache.hadoop.conf.Configuration
3030
import org.apache.hadoop.io.{LongWritable, Text}
3131
import org.apache.hadoop.mapred.TextInputFormat
3232
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
33+
import org.mockito.Mockito.mock
3334

34-
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
35+
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
3536
import org.apache.spark.api.java.JavaSparkContext
3637
import org.apache.spark.rdd.{HadoopRDD, RDD}
3738
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
3839
import org.apache.spark.util.Utils
3940

4041
class PythonRDDSuite extends SparkFunSuite with LocalSparkContext {
4142

43+
private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*)
44+
4245
var tempDir: File = _
4346

4447
override def beforeAll(): Unit = {
@@ -76,12 +79,22 @@ class PythonRDDSuite extends SparkFunSuite with LocalSparkContext {
7679
}
7780

7881
test("python server error handling") {
79-
val authHelper = new SocketAuthHelper(new SparkConf())
80-
val errorServer = new ExceptionPythonServer(authHelper)
81-
val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
82-
authHelper.authToServer(client)
83-
val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
84-
assert(ex.getCause().getMessage().contains("exception within handleConnection"))
82+
val savedSparkEnv = SparkEnv.get
83+
try {
84+
val conf = new SparkConf()
85+
val env = mock(classOf[SparkEnv])
86+
doReturn(conf).when(env).conf
87+
SparkEnv.set(env)
88+
89+
val authHelper = new SocketAuthHelper(conf)
90+
val errorServer = new ExceptionPythonServer(authHelper)
91+
val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
92+
authHelper.authToServer(client)
93+
val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
94+
assert(ex.getCause().getMessage().contains("exception within handleConnection"))
95+
} finally {
96+
SparkEnv.set(savedSparkEnv)
97+
}
8598
}
8699

87100
class ExceptionPythonServer(authHelper: SocketAuthHelper)

0 commit comments

Comments
 (0)