@@ -858,6 +858,47 @@ def test_multiple_broadcasts(self):
858858 self .assertEqual (N , size )
859859 self .assertEqual (checksum , csum )
860860
861+ def test_multithread_broadcast_pickle (self ):
862+ import threading
863+
864+ b1 = self .sc .broadcast (list (range (3 )))
865+ b2 = self .sc .broadcast (list (range (3 )))
866+
867+ def f1 (): return b1 .value
868+
869+ def f2 (): return b2 .value
870+
871+ funcs_num_pickled = {f1 : None , f2 : None }
872+
873+ def do_pickle (f , sc ):
874+ command = (f , None , sc .serializer , sc .serializer )
875+ ser = CloudPickleSerializer ()
876+ ser .dumps (command )
877+
878+ def process_vars (sc ):
879+ broadcast_vars = [x for x in sc ._pickled_broadcast_vars ]
880+ num_pickled = len (broadcast_vars )
881+ sc ._pickled_broadcast_vars .clear ()
882+ return num_pickled
883+
884+ def run (f , sc ):
885+ do_pickle (f , sc )
886+ funcs_num_pickled [f ] = process_vars (sc )
887+
888+ # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage
889+ do_pickle (f1 , self .sc )
890+
891+ # run all for f2, should only add/count/clear b2 from worker thread local storage
892+ t = threading .Thread (target = run , args = (f2 , self .sc ))
893+ t .start ()
894+ t .join ()
895+
896+ # count number of vars pickled in main thread, only b1 should be counted and cleared
897+ funcs_num_pickled [f1 ] = process_vars (self .sc )
898+
899+ self .assertEqual (funcs_num_pickled [f1 ], 1 )
900+ self .assertEqual (funcs_num_pickled [f2 ], 1 )
901+
861902 def test_large_closure (self ):
862903 N = 200000
863904 data = [float (i ) for i in xrange (N )]
0 commit comments