Skip to content

Commit 22015a5

Browse files
committed
Change the 'test_* name discover' to 'unittest module discover'
1 parent 23f2c87 commit 22015a5

File tree

1 file changed

+43
-11
lines changed

1 file changed

+43
-11
lines changed

dev/sparktestsupport/modules.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,61 @@
1616
#
1717

1818
from functools import total_ordering
19+
from importlib import import_module
20+
import inspect
1921
import itertools
2022
import os
23+
from pkgutil import iter_modules
2124
import re
22-
import glob
25+
import sys
26+
import unittest
2327

2428
from sparktestsupport import SPARK_HOME
2529

30+
2631
all_modules = []
32+
pyspark_path = os.path.join(SPARK_HOME, "python")
33+
sys.path.append(pyspark_path)
34+
35+
36+
def _contain_unittests_class(module_name):
37+
"""
38+
Check whether the module with the given module_name contains classes derived from unittest.TestCase.
39+
Such as:
40+
for pyspark.tests.test_appsubmit it returns True, because SparkSubmitTests, which is
41+
defined in the module pyspark.tests.test_appsubmit, inherits from unittest.TestCase.
42+
``
43+
:param module_name: the complete name of module to be checked.
44+
:return: True if contains unittest classes otherwise False.
45+
A ``ModuleNotFoundError`` is raised if the module is not found.
46+
"""
47+
_module = import_module(module_name)
48+
for _, _class in inspect.getmembers(_module, inspect.isclass):
49+
if issubclass(_class, unittest.TestCase):
50+
return True
51+
return False
2752

2853

2954
def _discover_python_unittests(paths):
55+
"""
56+
Discover the Python modules that contain unittests under the given paths.
57+
Such as:
58+
for ['pyspark/tests'] it returns the set of module names found under pyspark/tests, like
59+
{'pyspark.tests.test_appsubmit', 'pyspark.tests.test_broadcast', ...}
60+
:param paths: paths of module to be discovered.
61+
:return: A set of complete test module names discovered under the paths
62+
"""
3063
if not paths:
31-
return set([])
32-
tests = set([])
33-
pyspark_path = os.path.join(SPARK_HOME, "python")
64+
return set()
65+
tests = set()
66+
3467
for path in paths:
35-
# Discover the test*.py in every path
36-
files = glob.glob(os.path.join(pyspark_path, path, "test_*.py"))
37-
for f in files:
38-
# Convert 'pyspark_path/pyspark/tests/test_abc.py' to 'pyspark.tests.test_abc'
39-
file2module = os.path.relpath(f, pyspark_path)[:-3].replace("/", ".")
40-
tests.add(file2module)
68+
real_path = os.path.join(pyspark_path, path)
69+
_prefix = path.replace('/', '.')
70+
# iter modules under the specific tests path
71+
for module in iter_modules([real_path], prefix=_prefix+'.'):
72+
if _contain_unittests_class(module.name):
73+
tests.add(module.name)
4174
return tests
4275

4376

@@ -56,7 +89,6 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=
5689
should_run_build_tests=False, python_test_paths=(), python_excluded_tests=()):
5790
"""
5891
Define a new module.
59-
6092
:param name: A short module name, for display in logging and error messages.
6193
:param dependencies: A set of dependencies for this module. This should only include direct
6294
dependencies; transitive dependencies are resolved automatically.

0 commit comments

Comments
 (0)