 from threading import Lock
 from tempfile import NamedTemporaryFile
 
+from pip.commands.install import InstallCommand as pip_InstallCommand
+
 from py4j.java_collections import ListConverter
 
 from pyspark import accumulators
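
Note: `pip.commands.install` is pip's pre-10 internal layout, not a stable public API; pip >= 10 moved the module under `pip._internal`, and recent pips also changed how command objects are constructed. A defensive import sketch under those assumptions:

```python
try:
    # pip < 10 exposes commands at the top level
    from pip.commands.install import InstallCommand as pip_InstallCommand
except ImportError:
    # pip >= 10 relocated the module; instantiation details may still differ
    from pip._internal.commands.install import InstallCommand as pip_InstallCommand
```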
@@ -62,9 +64,9 @@ class SparkContext(object):
     _next_accum_id = 0
     _active_spark_context = None
     _lock = Lock()
-    _python_includes = None  # zip and egg files that need to be added to PYTHONPATH
+    _python_includes = None  # whl, egg, zip and jar files that need to be added to PYTHONPATH
 
-    PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar')
+    PACKAGE_EXTENSIONS = ('.whl', '.egg', '.zip', '.jar')
 
     def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
                  environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
@@ -77,9 +79,9 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         (e.g. mesos://host:port, spark://host:port, local[4]).
         :param appName: A name for your job, to display on the cluster web UI.
         :param sparkHome: Location where Spark is installed on cluster nodes.
-        :param pyFiles: Collection of .zip or .py files to send to the cluster
-               and add to PYTHONPATH. These can be paths on the local file
-               system or HDFS, HTTP, HTTPS, or FTP URLs.
+        :param pyFiles: Collection of .py, .whl, .egg or .zip files to send
+               to the cluster and add to PYTHONPATH. These can be paths on
+               the local file system or HDFS, HTTP, HTTPS, or FTP URLs.
         :param environment: A dictionary of environment variables to set on
               worker nodes.
         :param batchSize: The number of Python objects represented as a single
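
With the widened `pyFiles` handling documented above, a wheel travels the same way as a `.py` or `.zip` dependency. A minimal usage sketch (the wheel path is a hypothetical placeholder):

```python
from pyspark import SparkContext

# Hypothetical local wheel: shipped to the cluster with the job, then
# pip-installed by _install_wheel_files during context initialization.
sc = SparkContext(
    master="local[4]",
    appName="wheel-deps",
    pyFiles=["/path/to/mylib-1.0-py2.py3-none-any.whl"],
)
```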
@@ -178,18 +180,28 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
         sys.path.insert(1, root_dir)
 
         # Deploy any code dependencies specified in the constructor
+        # Wheel files will be installed by pip later.
         self._python_includes = list()
-        for path in (pyFiles or []):
-            self.addPyFile(path)
+        if pyFiles:
+            for path in pyFiles:
+                self.addFile(path)
+            self._include_python_packages(paths=pyFiles)
+        else:
+            pyFiles = []
 
         # Deploy code dependencies set by spark-submit; these will already have been added
-        # with SparkContext.addFile, so we just need to add them to the PYTHONPATH
-        for path in self._conf.get("spark.submit.pyFiles", "").split(","):
-            if path != "":
-                (dirname, filename) = os.path.split(path)
-                if filename[-4:].lower() in self.PACKAGE_EXTENSIONS:
-                    self._python_includes.append(filename)
-                    sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
+        # with SparkContext.addFile, so we just need to include them.
+        # Wheel files will be installed by pip later.
+        # "".split(",") yields [""], so drop empty entries.
+        spark_submit_pyfiles = [
+            path for path in self._conf.get("spark.submit.pyFiles", "").split(",")
+            if path != ""
+        ]
+        if spark_submit_pyfiles:
+            self._include_python_packages(paths=spark_submit_pyfiles)
+
+        # Install all wheel files at once.
+        self._install_wheel_files(paths=pyFiles + spark_submit_pyfiles)
 
         # Create a temporary directory inside spark.local.dir:
         local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
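
For context, `spark.submit.pyFiles` is the comma-separated list populated by `spark-submit --py-files`. A tiny runnable sketch of the parse performed above (the value is a hypothetical example):

```python
# Hypothetical value as set by: spark-submit --py-files deps.zip,mylib-1.0-py2.py3-none-any.whl
conf_value = "deps.zip,mylib-1.0-py2.py3-none-any.whl"

# Same filtering as in _do_init: "".split(",") would yield [""], so drop empties.
spark_submit_pyfiles = [p for p in conf_value.split(",") if p != ""]
print(spark_submit_pyfiles)  # ['deps.zip', 'mylib-1.0-py2.py3-none-any.whl']
```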
@@ -693,23 +701,73 @@ def clearFiles(self):
         Clear the job's list of files added by L{addFile} or L{addPyFile} so
         that they do not get downloaded to any new nodes.
         """
-        # TODO: remove added .py or .zip files from the PYTHONPATH?
+        # TODO: remove added .py, .whl, .egg or .zip files from the PYTHONPATH?
         self._jsc.sc().clearFiles()
 
     def addPyFile(self, path):
         """
-        Add a .py or .zip dependency for all tasks to be executed on this
-        SparkContext in the future. The C{path} passed can be either a local
-        file, a file in HDFS (or other Hadoop-supported filesystems), or an
-        HTTP, HTTPS or FTP URI.
+        Add a .py, .whl, .egg or .zip dependency for all tasks to be
+        executed on this SparkContext in the future. The C{path} passed can
+        be either a local file, a file in HDFS (or other Hadoop-supported
+        filesystems), or an HTTP, HTTPS or FTP URI.
         """
         self.addFile(path)
-        (dirname, filename) = os.path.split(path)  # dirname may be directory or HDFS/S3 prefix
+        self._include_python_packages(paths=(path,))
+        self._install_wheel_files(paths=(path,))
 
-        if filename[-4:].lower() in self.PACKAGE_EXTENSIONS:
-            self._python_includes.append(filename)
-            # for tests in local mode
-            sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
+    def _include_python_packages(self, paths):
+        """
+        Add Python package dependencies. Python packages (except for .whl)
+        are added to PYTHONPATH.
+        """
+        root_dir = SparkFiles.getRootDirectory()
+        for path in paths:
+            basename = os.path.basename(path)
+            extname = os.path.splitext(basename)[1].lower()
+            if extname in self.PACKAGE_EXTENSIONS \
+                    and basename not in self._python_includes:
+                self._python_includes.append(basename)
+                if extname != '.whl':
+                    # Prepend the python package (except for *.whl) to sys.path
+                    sys.path.insert(1, os.path.join(root_dir, basename))
+
+    def _install_wheel_files(
+            self,
+            paths,
+            quiet=True,
+            upgrade=True,
+            no_deps=True,
+            no_index=True,
+    ):
+        """
+        Install all given .whl files in a single pip invocation.
+        """
+        root_dir = SparkFiles.getRootDirectory()
+        # Resolve every wheel to its downloaded location under the SparkFiles
+        # root; a set de-duplicates wheels passed through both code paths.
+        paths = {
+            os.path.join(root_dir, os.path.basename(path))
+            for path in paths
+            if os.path.splitext(path)[1].lower() == '.whl'
+        }
+        if not paths:
+            return
+
+        pip_args = [
+            '--find-links', root_dir,
+            '--target', os.path.join(root_dir, 'site-packages'),
+        ]
+        if quiet:
+            pip_args.append('--quiet')
+        if upgrade:
+            pip_args.append('--upgrade')
+        if no_deps:
+            pip_args.append('--no-deps')
+        if no_index:
+            pip_args.append('--no-index')
+        pip_args.extend(paths)
+
+        pip_InstallCommand().main(args=pip_args)
 
     def setCheckpointDir(self, dirName):
         """
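
Taken together, adding a wheel through `addPyFile` funnels into a single pip run. A minimal sketch of the equivalent call, assuming a pip < 10 layout and hypothetical paths:

```python
import os
from pip.commands.install import InstallCommand

# Hypothetical SparkFiles root directory and downloaded wheel.
root_dir = "/tmp/spark-abc123/userFiles"
wheel = os.path.join(root_dir, "mylib-1.0-py2.py3-none-any.whl")

# Mirrors _install_wheel_files' defaults, i.e. roughly:
#   pip install --find-links <root> --target <root>/site-packages \
#       --quiet --upgrade --no-deps --no-index <wheel>
InstallCommand().main(args=[
    "--find-links", root_dir,
    "--target", os.path.join(root_dir, "site-packages"),
    "--quiet", "--upgrade", "--no-deps", "--no-index",
    wheel,
])
```

Driver-side usage is then simply `sc.addPyFile("/path/to/mylib-1.0-py2.py3-none-any.whl")`.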