@@ -30,15 +30,18 @@ import org.apache.hadoop.conf.Configuration
3030import org .apache .hadoop .io .{LongWritable , Text }
3131import org .apache .hadoop .mapred .TextInputFormat
3232import org .apache .hadoop .mapreduce .lib .input .FileInputFormat
33+ import org .mockito .Mockito .mock
3334
34- import org .apache .spark .{LocalSparkContext , SparkConf , SparkContext , SparkFunSuite }
35+ import org .apache .spark .{LocalSparkContext , SparkConf , SparkContext , SparkEnv , SparkFunSuite }
3536import org .apache .spark .api .java .JavaSparkContext
3637import org .apache .spark .rdd .{HadoopRDD , RDD }
3738import org .apache .spark .security .{SocketAuthHelper , SocketAuthServer }
3839import org .apache .spark .util .Utils
3940
4041class PythonRDDSuite extends SparkFunSuite with LocalSparkContext {
4142
43+ private def doReturn (value : Any ) = org.mockito.Mockito .doReturn(value, Seq .empty: _* )
44+
4245 var tempDir : File = _
4346
4447 override def beforeAll (): Unit = {
@@ -76,12 +79,22 @@ class PythonRDDSuite extends SparkFunSuite with LocalSparkContext {
7679 }
7780
7881 test(" python server error handling" ) {
79- val authHelper = new SocketAuthHelper (new SparkConf ())
80- val errorServer = new ExceptionPythonServer (authHelper)
81- val client = new Socket (InetAddress .getLoopbackAddress(), errorServer.port)
82- authHelper.authToServer(client)
83- val ex = intercept[Exception ] { errorServer.getResult(Duration (1 , " second" )) }
84- assert(ex.getCause().getMessage().contains(" exception within handleConnection" ))
82+ val savedSparkEnv = SparkEnv .get
83+ try {
84+ val conf = new SparkConf ()
85+ val env = mock(classOf [SparkEnv ])
86+ doReturn(conf).when(env).conf
87+ SparkEnv .set(env)
88+
89+ val authHelper = new SocketAuthHelper (conf)
90+ val errorServer = new ExceptionPythonServer (authHelper)
91+ val client = new Socket (InetAddress .getLoopbackAddress(), errorServer.port)
92+ authHelper.authToServer(client)
93+ val ex = intercept[Exception ] { errorServer.getResult(Duration (1 , " second" )) }
94+ assert(ex.getCause().getMessage().contains(" exception within handleConnection" ))
95+ } finally {
96+ SparkEnv .set(savedSparkEnv)
97+ }
8598 }
8699
87100 class ExceptionPythonServer (authHelper : SocketAuthHelper )
0 commit comments