1616#
1717
1818from functools import total_ordering
19+ from importlib import import_module
20+ import inspect
1921import itertools
2022import os
23+ from pkgutil import iter_modules
2124import re
22- import glob
25+ import sys
26+ import unittest
2327
2428from sparktestsupport import SPARK_HOME
2529
30+
2631all_modules = []
32+ pyspark_path = os .path .join (SPARK_HOME , "python" )
33+ sys .path .append (pyspark_path )
34+
35+
36+ def _contain_unittests_class (module_name ):
37+ """
38+ Check if the module with specific module_name has classes are derived from unittest.TestCase.
39+ Such as:
40+ pyspark.tests.test_appsubmit, it will return True, because there is SparkSubmitTests which is
41+ included under the module of pyspark.tests.test_appsubmit, inherits from unittest.TestCase.
42+ ``
43+ :param module_name: the complete name of module to be checked.
44+ :return: True if contains unittest classes otherwise False.
45+ An ``ModuleNotFoundError`` will raise if the module is not found
46+ """
47+ _module = import_module (module_name )
48+ for _ , _class in inspect .getmembers (_module , inspect .isclass ):
49+ if issubclass (_class , unittest .TestCase ):
50+ return True
51+ return False
2752
2853
2954def _discover_python_unittests (paths ):
55+ """
56+ Discover the python module which contains unittests under paths.
57+ Such as:
58+ ['pyspark/tests'], it will return the set of module name under the path of pyspark/tests, like
59+ {'pyspark.tests.test_appsubmit', 'pyspark.tests.test_broadcast', ...}
60+ :param paths: paths of module to be discovered.
61+ :return: A set of complete test module name discovered udner the paths
62+ """
3063 if not paths :
31- return set ([] )
32- tests = set ([] )
33- pyspark_path = os . path . join ( SPARK_HOME , "python" )
64+ return set ()
65+ tests = set ()
66+
3467 for path in paths :
35- # Discover the test*.py in every path
36- files = glob . glob ( os . path .join ( pyspark_path , path , "test_*.py" ) )
37- for f in files :
38- # Convert 'pyspark_path/pyspark/tests/test_abc.py' to 'pyspark.tests.test_abc'
39- file2module = os . path . relpath ( f , pyspark_path )[: - 3 ]. replace ( "/" , "." )
40- tests .add (file2module )
68+ real_path = os . path . join ( pyspark_path , path )
69+ _prefix = path .replace ( '/' , '.' )
70+ # iter modules under the specific tests path
71+ for module in iter_modules ([ real_path ], prefix = _prefix + '.' ):
72+ if _contain_unittests_class ( module . name ):
73+ tests .add (module . name )
4174 return tests
4275
4376
@@ -56,7 +89,6 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=
5689 should_run_build_tests = False , python_test_paths = (), python_excluded_tests = ()):
5790 """
5891 Define a new module.
59-
6092 :param name: A short module name, for display in logging and error messages.
6193 :param dependencies: A set of dependencies for this module. This should only include direct
6294 dependencies; transitive dependencies are resolved automatically.
0 commit comments