
Commit 323dc3a

Author: Marcelo Vanzin (committed)
[PYSPARK] Update py4j to version 0.10.7.
(cherry picked from commit cc613b5)
Signed-off-by: Marcelo Vanzin <[email protected]>
1 parent eab10f9 · commit 323dc3a

32 files changed (+418, -116 lines)


LICENSE

Lines changed: 1 addition & 1 deletion
@@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
 (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf)
 (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net)
 (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net)
-(The New BSD License) Py4J (net.sf.py4j:py4j:0.10.6 - http://py4j.sourceforge.net/)
+(The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/)
 (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/)
 (BSD licence) sbt and sbt-launch-lib.bash
 (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE)

bin/pyspark

Lines changed: 3 additions & 3 deletions
@@ -25,14 +25,14 @@ source "${SPARK_HOME}"/bin/load-spark-env.sh
 export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]"
 
 # In Spark 2.0, IPYTHON and IPYTHON_OPTS are removed and pyspark fails to launch if either option
-# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
+# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
 # to use IPython and set PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver
 # (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython
 # and executor Python executables.
 
 # Fail noisily if removed options are set
 if [[ -n "$IPYTHON" || -n "$IPYTHON_OPTS" ]]; then
-  echo "Error in pyspark startup:"
+  echo "Error in pyspark startup:"
   echo "IPYTHON and IPYTHON_OPTS are removed in Spark 2.0+. Remove these from the environment and set PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS instead."
   exit 1
 fi
@@ -57,7 +57,7 @@ export PYSPARK_PYTHON
 
 # Add the PySpark classes to the Python path:
 export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH"
-export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:$PYTHONPATH"
+export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH"
 
 # Load the PySpark shell.py script when ./pyspark is used interactively:
 export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"

bin/pyspark2.cmd

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" (
 )
 
 set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH%
-set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.6-src.zip;%PYTHONPATH%
+set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH%
 
 set OLD_PYTHONSTARTUP=%PYTHONSTARTUP%
 set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py

core/pom.xml

Lines changed: 1 addition & 1 deletion
@@ -344,7 +344,7 @@
     <dependency>
       <groupId>net.sf.py4j</groupId>
       <artifactId>py4j</artifactId>
-      <version>0.10.6</version>
+      <version>0.10.7</version>
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>

core/src/main/scala/org/apache/spark/SecurityManager.scala

Lines changed: 2 additions & 9 deletions
@@ -17,14 +17,12 @@
 
 package org.apache.spark
 
-import java.lang.{Byte => JByte}
 import java.net.{Authenticator, PasswordAuthentication}
 import java.nio.charset.StandardCharsets.UTF_8
-import java.security.{KeyStore, SecureRandom}
+import java.security.KeyStore
 import java.security.cert.X509Certificate
 import javax.net.ssl._
 
-import com.google.common.hash.HashCodes
 import com.google.common.io.Files
 import org.apache.hadoop.io.Text
 import org.apache.hadoop.security.{Credentials, UserGroupInformation}
@@ -542,13 +540,8 @@ private[spark] class SecurityManager(
       return
     }
 
-    val rnd = new SecureRandom()
-    val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE
-    val secretBytes = new Array[Byte](length)
-    rnd.nextBytes(secretBytes)
-
+    secretKey = Utils.createSecret(sparkConf)
     val creds = new Credentials()
-    secretKey = HashCodes.fromBytes(secretBytes).toString()
     creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8))
     UserGroupInformation.getCurrentUser().addCredentials(creds)
   }
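
For context, the removed block drew spark.authenticate.secretBitLength random bits (default 256) from a SecureRandom and encoded them as a string; Utils.createSecret, which is not shown in this diff, now centralizes that logic. A rough, illustrative Python equivalent of this kind of secret generation — the function name and default here are assumptions, not the actual Utils implementation:

    import secrets

    def create_secret(bit_length: int = 256) -> str:
        # Draw bit_length random bits and encode them as a hex string,
        # mirroring the SecureRandom-based code this commit removes.
        return secrets.token_hex(bit_length // 8)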

core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala

Lines changed: 36 additions & 14 deletions
@@ -17,26 +17,39 @@
 
 package org.apache.spark.api.python
 
-import java.io.DataOutputStream
-import java.net.Socket
+import java.io.{DataOutputStream, File, FileOutputStream}
+import java.net.InetAddress
+import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.file.Files
 
 import py4j.GatewayServer
 
+import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.util.Utils
 
 /**
- * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port
- * back to its caller via a callback port specified by the caller.
+ * Process that starts a Py4J GatewayServer on an ephemeral port.
  *
  * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py).
  */
 private[spark] object PythonGatewayServer extends Logging {
   initializeLogIfNecessary(true)
 
-  def main(args: Array[String]): Unit = Utils.tryOrExit {
-    // Start a GatewayServer on an ephemeral port
-    val gatewayServer: GatewayServer = new GatewayServer(null, 0)
+  def main(args: Array[String]): Unit = {
+    val secret = Utils.createSecret(new SparkConf())
+
+    // Start a GatewayServer on an ephemeral port. Make sure the callback client is configured
+    // with the same secret, in case the app needs callbacks from the JVM to the underlying
+    // python processes.
+    val localhost = InetAddress.getLoopbackAddress()
+    val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
+      .authToken(secret)
+      .javaPort(0)
+      .javaAddress(localhost)
+      .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
+      .build()
+
     gatewayServer.start()
     val boundPort: Int = gatewayServer.getListeningPort
     if (boundPort == -1) {
@@ -46,15 +59,24 @@ private[spark] object PythonGatewayServer extends Logging {
       logDebug(s"Started PythonGatewayServer on port $boundPort")
     }
 
-    // Communicate the bound port back to the caller via the caller-specified callback port
-    val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
-    val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
-    logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
-    val callbackSocket = new Socket(callbackHost, callbackPort)
-    val dos = new DataOutputStream(callbackSocket.getOutputStream)
+    // Communicate the connection information back to the python process by writing the
+    // information in the requested file. This needs to match the read side in java_gateway.py.
+    val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH"))
+    val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(),
+      "connection", ".info").toFile()
+
+    val dos = new DataOutputStream(new FileOutputStream(tmpPath))
     dos.writeInt(boundPort)
+
+    val secretBytes = secret.getBytes(UTF_8)
+    dos.writeInt(secretBytes.length)
+    dos.write(secretBytes, 0, secretBytes.length)
     dos.close()
-    callbackSocket.close()
+
+    if (!tmpPath.renameTo(connectionInfoPath)) {
+      logError(s"Unable to write connection information to $connectionInfoPath.")
+      System.exit(1)
+    }
 
     // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
     while (System.in.read() != -1) {
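
The comment above notes that this write side must match the read side in java_gateway.py. For reference, a minimal, hypothetical Python sketch of that read side (the helper name is illustrative, not the actual PySpark function); it assumes the 4-byte big-endian integers that java.io.DataOutputStream writes:

    import struct

    def read_connection_info(conn_info_path):
        # Layout written by PythonGatewayServer: port (int), secret length (int),
        # then that many UTF-8 bytes of the auth secret.
        with open(conn_info_path, "rb") as f:
            port = struct.unpack(">i", f.read(4))[0]
            secret_len = struct.unpack(">i", f.read(4))[0]
            secret = f.read(secret_len).decode("utf-8")
        return port, secret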

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 22 additions & 7 deletions
@@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
+import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util._
 
 
@@ -107,6 +108,12 @@ private[spark] object PythonRDD extends Logging {
   // remember the broadcasts sent to each worker
   private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
 
+  // Authentication helper used when serving iterator data.
+  private lazy val authHelper = {
+    val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
+    new SocketAuthHelper(conf)
+  }
+
   def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = {
     synchronized {
       workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
@@ -129,12 +136,13 @@ private[spark] object PythonRDD extends Logging {
    * (effectively a collect()), but allows you to run on a certain subset of partitions,
    * or to enable local execution.
    *
-   * @return the port number of a local socket which serves the data collected from this job.
+   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, and the secret for authentication.
    */
   def runJob(
       sc: SparkContext,
       rdd: JavaRDD[Array[Byte]],
-      partitions: JArrayList[Int]): Int = {
+      partitions: JArrayList[Int]): Array[Any] = {
     type ByteArray = Array[Byte]
     type UnrolledPartition = Array[ByteArray]
     val allPartitions: Array[UnrolledPartition] =
@@ -147,13 +155,14 @@ private[spark] object PythonRDD extends Logging {
   /**
    * A helper function to collect an RDD as an iterator, then serve it via socket.
    *
-   * @return the port number of a local socket which serves the data collected from this job.
+   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, and the secret for authentication.
    */
-  def collectAndServe[T](rdd: RDD[T]): Int = {
+  def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }
 
-  def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
+  def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
     serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
   }
 
@@ -384,8 +393,11 @@ private[spark] object PythonRDD extends Logging {
    * and send them into this connection.
    *
    * The thread will terminate after all the data are sent or any exceptions happen.
+   *
+   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, and the secret for authentication.
    */
-  def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+  def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
     // Close the socket if no connection in 15 seconds
     serverSocket.setSoTimeout(15000)
@@ -395,11 +407,14 @@ private[spark] object PythonRDD extends Logging {
       override def run() {
         try {
           val sock = serverSocket.accept()
+          authHelper.authClient(sock)
+
           val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
           Utils.tryWithSafeFinally {
             writeIteratorToStream(items, out)
           } {
             out.close()
+            sock.close()
           }
         } catch {
           case NonFatal(e) =>
@@ -410,7 +425,7 @@ private[spark] object PythonRDD extends Logging {
       }
     }.start()
 
-    serverSocket.getLocalPort
+    Array(serverSocket.getLocalPort, authHelper.secret)
   }
 
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
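
Since serveIterator now returns a (port, secret) pair instead of a bare port, the Python caller has to authenticate before reading the serialized data. The sketch below is a hypothetical illustration of that client side only; it assumes each message is a 4-byte big-endian length followed by UTF-8 bytes, while the authoritative wire format is whatever org.apache.spark.security.SocketAuthHelper actually implements.

    import socket
    import struct

    def _recv_exact(sock, n):
        buf = b""
        while len(buf) < n:
            chunk = sock.recv(n - len(buf))
            if not chunk:
                raise EOFError("connection closed during auth handshake")
            buf += chunk
        return buf

    def connect_and_auth(port, secret, timeout=15):
        # Connect to the ephemeral server opened by serveIterator and present
        # the shared secret before any RDD data is read.
        sock = socket.create_connection(("127.0.0.1", port), timeout=timeout)
        payload = secret.encode("utf-8")
        sock.sendall(struct.pack(">i", len(payload)) + payload)
        # Assumed acknowledgement: a length-prefixed UTF-8 status string.
        (reply_len,) = struct.unpack(">i", _recv_exact(sock, 4))
        reply = _recv_exact(sock, reply_len).decode("utf-8")
        if reply != "ok":
            raise RuntimeError("authentication with the JVM socket failed")
        return sock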

core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ private[spark] object PythonUtils {
     val pythonPath = new ArrayBuffer[String]
     for (sparkHome <- sys.env.get("SPARK_HOME")) {
       pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator)
-      pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.6-src.zip").mkString(File.separator)
+      pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator)
     }
     pythonPath ++= SparkContext.jarOfObject(this)
     pythonPath.mkString(File.pathSeparator)

core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala

Lines changed: 13 additions & 8 deletions
@@ -27,6 +27,7 @@ import scala.collection.mutable
 
 import org.apache.spark._
 import org.apache.spark.internal.Logging
+import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util.{RedirectThread, Utils}
 
 private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
@@ -45,6 +46,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
     !System.getProperty("os.name").startsWith("Windows") && useDaemonEnabled
   }
 
+
+  private val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
+
   var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0
@@ -85,6 +89,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
       if (pid < 0) {
         throw new IllegalStateException("Python daemon failed to launch worker with code " + pid)
       }
+
+      authHelper.authToServer(socket)
       daemonWorkers.put(socket, pid)
       socket
     }
@@ -122,25 +128,24 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
       workerEnv.put("PYTHONPATH", pythonPath)
       // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
       workerEnv.put("PYTHONUNBUFFERED", "YES")
+      workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString)
+      workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
       val worker = pb.start()
 
       // Redirect worker stdout and stderr
       redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream)
 
-      // Tell the worker our port
-      val out = new OutputStreamWriter(worker.getOutputStream, StandardCharsets.UTF_8)
-      out.write(serverSocket.getLocalPort + "\n")
-      out.flush()
-
-      // Wait for it to connect to our socket
+      // Wait for it to connect to our socket, and validate the auth secret.
       serverSocket.setSoTimeout(10000)
+
       try {
         val socket = serverSocket.accept()
+        authHelper.authClient(socket)
         simpleWorkers.put(socket, worker)
         return socket
       } catch {
         case e: Exception =>
-          throw new SparkException("Python worker did not connect back in time", e)
+          throw new SparkException("Python worker failed to connect back.", e)
       }
     } finally {
       if (serverSocket != null) {
@@ -163,6 +168,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
       val workerEnv = pb.environment()
       workerEnv.putAll(envVars.asJava)
       workerEnv.put("PYTHONPATH", pythonPath)
+      workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
       // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
       workerEnv.put("PYTHONUNBUFFERED", "YES")
       daemon = pb.start()
@@ -172,7 +178,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
 
       // Redirect daemon stdout and stderr
       redirectStreamsToStderr(in, daemon.getErrorStream)
-
     } catch {
       case e: Exception =>
 
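On the worker side, the factory no longer writes its port to the child's stdin: the simple-worker path exports both PYTHON_WORKER_FACTORY_PORT and PYTHON_WORKER_FACTORY_SECRET, while the daemon path only injects the secret. A hypothetical sketch of the simple-worker startup, with illustrative names (the real logic lives in pyspark's worker/daemon modules, and the secret is then presented using the same kind of handshake sketched after the PythonRDD diff above):

    import os
    import socket

    def connect_back_to_factory(timeout=10):
        # Values injected by PythonWorkerFactory via ProcessBuilder's environment.
        port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
        secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
        sock = socket.create_connection(("127.0.0.1", port), timeout=timeout)
        # The secret is presented over this socket so authClient() on the JVM
        # side can validate the worker before handing it any tasks.
        return sock, secret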
core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala

Lines changed: 10 additions & 2 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.deploy
 
 import java.io.File
-import java.net.URI
+import java.net.{InetAddress, URI}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
@@ -39,6 +39,7 @@ object PythonRunner {
     val pyFiles = args(1)
     val otherArgs = args.slice(2, args.length)
     val sparkConf = new SparkConf()
+    val secret = Utils.createSecret(sparkConf)
     val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON)
       .orElse(sparkConf.get(PYSPARK_PYTHON))
       .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON"))
@@ -51,7 +52,13 @@ object PythonRunner {
 
     // Launch a Py4J gateway server for the process to connect to; this will let it see our
     // Java system properties and such
-    val gatewayServer = new py4j.GatewayServer(null, 0)
+    val localhost = InetAddress.getLoopbackAddress()
+    val gatewayServer = new py4j.GatewayServer.GatewayServerBuilder()
+      .authToken(secret)
+      .javaPort(0)
+      .javaAddress(localhost)
+      .callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
+      .build()
     val thread = new Thread(new Runnable() {
       override def run(): Unit = Utils.logUncaughtExceptions {
         gatewayServer.start()
@@ -82,6 +89,7 @@ object PythonRunner {
     // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
     env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
     env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
+    env.put("PYSPARK_GATEWAY_SECRET", secret)
     // pass conf spark.pyspark.python to python process, the only way to pass info to
     // python process is through environment variable.
     sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
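
On the Python side, the driver launched by PythonRunner can read PYSPARK_GATEWAY_PORT and PYSPARK_GATEWAY_SECRET and pass the secret as the Py4J auth token. The real wiring is in pyspark's java_gateway.py; the sketch below, with an illustrative function name, only shows the py4j 0.10.7 parameters that accept such a token.

    import os
    from py4j.java_gateway import (CallbackServerParameters, GatewayParameters,
                                   JavaGateway)

    def connect_to_jvm_gateway():
        port = int(os.environ["PYSPARK_GATEWAY_PORT"])
        secret = os.environ["PYSPARK_GATEWAY_SECRET"]
        # The same token authenticates both directions: Python -> JVM calls
        # and JVM -> Python callbacks.
        return JavaGateway(
            gateway_parameters=GatewayParameters(
                port=port, auth_token=secret, auto_convert=True),
            callback_server_parameters=CallbackServerParameters(
                port=0, auth_token=secret))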
