Skip to content

Commit b703f83

Browse files
committed
Added thread-safe broadcast pickle registry
1 parent 3ac6093 commit b703f83

File tree

3 files changed

+34
-10
lines changed

3 files changed

+34
-10
lines changed

python/pyspark/broadcast.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,29 @@ def __reduce__(self):
139139
return _from_id, (self._jbroadcast.id(),)
140140

141141

142+
class BroadcastPickleRegistry(object):
    """Thread-safe registry of broadcast variables that have been pickled.

    All mutations are serialized through an externally supplied lock, so
    multiple threads may register broadcasts concurrently while another
    thread drains the registry.
    """

    def __init__(self, lock):
        # Lock shared with the owning context; guards every mutation of
        # the underlying set.
        self._lock = lock
        self._registry = set()

    @property
    def lock(self):
        """The lock guarding this registry.

        Exposed so a caller can hold it across several operations and make
        them atomic as a group.
        """
        return self._lock

    def add(self, bcast):
        """Record *bcast* as having been pickled."""
        with self._lock:
            self._registry.add(bcast)

    def get_and_clear(self):
        """Atomically snapshot the registered broadcasts and empty the registry.

        Returns the snapshot as a new set.
        """
        with self._lock:
            snapshot = self._registry.copy()
            self._registry.clear()
            return snapshot
163+
164+
142165
if __name__ == "__main__":
143166
import doctest
144167
(failure_count, test_count) = doctest.testmod()

python/pyspark/context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from pyspark import accumulators
3232
from pyspark.accumulators import Accumulator
33-
from pyspark.broadcast import Broadcast
33+
from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
3434
from pyspark.conf import SparkConf
3535
from pyspark.files import SparkFiles
3636
from pyspark.java_gateway import launch_gateway
@@ -195,7 +195,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
195195
# This allows other code to determine which Broadcast instances have
196196
# been pickled, so it can determine which Java broadcast objects to
197197
# send.
198-
self._pickled_broadcast_vars = set()
198+
self._pickled_broadcast_registry = BroadcastPickleRegistry(self._lock)
199199

200200
SparkFiles._sc = self
201201
root_dir = SparkFiles.getRootDirectory()
@@ -793,7 +793,7 @@ def broadcast(self, value):
793793
object for reading it in distributed functions. The variable will
794794
be sent to each cluster only once.
795795
"""
796-
return Broadcast(self, value, self._pickled_broadcast_vars)
796+
return Broadcast(self, value, self._pickled_broadcast_registry)
797797

798798
def accumulator(self, value, accum_param=None):
799799
"""

python/pyspark/rdd.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2370,13 +2370,14 @@ def toLocalIterator(self):
23702370
def _prepare_for_python_RDD(sc, command):
    """Serialize *command* for shipment to Python workers.

    Returns a 4-tuple of (pickled command bytes, Java broadcast handles,
    the context's environment dict, the context's python includes). The
    serialized command will be compressed by broadcast when it is large.
    """
    ser = CloudPickleSerializer()
    # Hold the registry lock for the whole pickle/collect sequence so that
    # broadcasts registered by another thread's concurrent pickling cannot
    # interleave into the set we drain below.
    with sc._pickled_broadcast_registry.lock:
        pickled_command = ser.dumps(command)
        if len(pickled_command) > (1 << 20):  # 1M
            # The broadcast will have same life cycle as created PythonRDD
            broadcast = sc.broadcast(pickled_command)
            pickled_command = ser.dumps(broadcast)
        pickled_broadcast_vars = sc._pickled_broadcast_registry.get_and_clear()
        broadcast_vars = [bc._jbroadcast for bc in pickled_broadcast_vars]
    return pickled_command, broadcast_vars, sc.environment, sc._python_includes
23812382

23822383

0 commit comments

Comments
 (0)