Skip to content

Commit 7d4ce24

Browse files
committed
rewrite: add support for "import module.submodule".
PiperOrigin-RevId: 623629676
1 parent 8808ade commit 7d4ce24

7 files changed

Lines changed: 95 additions & 3 deletions

File tree

pytype/rewrite/abstract/classes_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,23 @@ def test_get_attribute(self):
5858
self.assertEqual(instance.get_attribute('x'), self.ctx.consts[3])
5959

6060

61+
class ModuleTest(test_utils.ContextfulTestBase):
62+
63+
def test_instance_attribute(self):
64+
attr = classes.Module(self.ctx, 'os').get_attribute('name')
65+
self.assertIsInstance(attr, classes.FrozenInstance)
66+
self.assertEqual(attr.cls.name, 'str')
67+
68+
def test_class_attribute(self):
69+
attr = classes.Module(self.ctx, 'os').get_attribute('__name__')
70+
self.assertIsInstance(attr, classes.FrozenInstance)
71+
self.assertEqual(attr.cls.name, 'str')
72+
73+
def test_submodule(self):
74+
attr = classes.Module(self.ctx, 'os').get_attribute('path')
75+
self.assertIsInstance(attr, classes.Module)
76+
self.assertEqual(attr.name, 'os.path')
77+
78+
6179
if __name__ == '__main__':
6280
unittest.main()

pytype/rewrite/convert.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,9 @@ def _pytd_type_to_value(self, typ: pytd.Type) -> abstract.BaseValue:
107107
f'Abstract conversion not yet implemented for {typ}')
108108
else:
109109
raise ValueError(f'Cannot convert {typ} to an abstract value')
110+
111+
def pytd_alias_to_value(self, alias: pytd.Alias) -> abstract.BaseValue:
112+
if isinstance(alias.type, pytd.Module):
113+
return abstract.Module(self._ctx, alias.type.module_name)
114+
raise NotImplementedError(
115+
f'Abstract conversion not yet implemented for {alias}')

pytype/rewrite/convert_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,5 +92,14 @@ class D: ...
9292
self.assertEqual(nested_class.name, 'D')
9393

9494

95+
class PytdAliasToValueTest(ConverterTestBase):
96+
97+
def test_alias(self):
98+
alias = self.build_pytd('import os.path', name='os.path')
99+
module = self.conv.pytd_alias_to_value(alias)
100+
self.assertIsInstance(module, abstract.Module)
101+
self.assertEqual(module.name, 'os.path')
102+
103+
95104
if __name__ == '__main__':
96105
unittest.main()

pytype/rewrite/frame.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -533,8 +533,17 @@ def byte_LOAD_METHOD(self, opcode):
533533

534534
def byte_IMPORT_NAME(self, opcode):
535535
full_name = opcode.argval
536-
unused_level_var, unused_fromlist = self._stack.popn(2)
537-
module = abstract.Module(self._ctx, full_name)
536+
unused_level_var, fromlist = self._stack.popn(2)
537+
# The IMPORT_NAME for an "import a.b.c" will push the module "a".
538+
# However, for "from a.b.c import Foo" it'll push the module "a.b.c". Those
539+
# two cases are distinguished by whether fromlist is None or not.
540+
try:
541+
abstract.get_atomic_constant(fromlist, None)
542+
except ValueError:
543+
module_name = full_name
544+
else:
545+
module_name = full_name.split('.', 1)[0] # "a.b.c" -> "a"
546+
module = abstract.Module(self._ctx, module_name)
538547
return self._stack.push(module.to_variable())
539548

540549
# ---------------------------------------------------------------

pytype/rewrite/load_abstract.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ def _load_pytd_node(self, pytd_node: pytd.Node) -> abstract.BaseValue:
6060
elif isinstance(pytd_node, pytd.Constant):
6161
typ = self._ctx.abstract_converter.pytd_type_to_value(pytd_node.type)
6262
return typ.instantiate()
63+
elif isinstance(pytd_node, pytd.Alias):
64+
return self._ctx.abstract_converter.pytd_alias_to_value(pytd_node)
6365
else:
6466
raise NotImplementedError(f'I do not know how to load {pytd_node}')
6567

pytype/rewrite/load_abstract_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ def test_stdlib(self):
5050
self.assertIsInstance(name, abstract.FrozenInstance)
5151
self.assertEqual(name.cls.name, 'str')
5252

53+
def test_submodule(self):
54+
submodule = self.ctx.abstract_loader.load_value('os', 'path')
55+
self.assertIsInstance(submodule, abstract.Module)
56+
self.assertEqual(submodule.name, 'os.path')
57+
5358

5459
class LoadRawTypeTest(test_utils.ContextfulTestBase):
5560

pytype/rewrite/tests/test_basic.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@
33
from pytype.tests import test_base
44

55

6-
class BasicTest(test_base.BaseTest):
6+
class RewriteTest(test_base.BaseTest):
7+
8+
def setUp(self):
9+
super().setUp()
10+
self.options.tweak(use_rewrite=True)
11+
12+
13+
class BasicTest(RewriteTest):
714
"""Basic functional tests."""
815

916
def setUp(self):
@@ -90,6 +97,10 @@ def __init__(self) -> None: ...
9097
def f(self) -> int: ...
9198
""")
9299

100+
101+
class ImportsTest(RewriteTest):
102+
"""Import tests."""
103+
93104
def test_import(self):
94105
self.Check("""
95106
import os
@@ -102,6 +113,38 @@ def test_builtins(self):
102113
assert_type(__builtins__.int, "Type[int]")
103114
""")
104115

116+
def test_dotted_import(self):
117+
self.Check("""
118+
import os.path
119+
assert_type(os.path, "module")
120+
""")
121+
122+
@test_base.skip('Not yet implemented')
123+
def test_from_import(self):
124+
self.Check("""
125+
from os import name, path
126+
assert_type(name, "str")
127+
assert_type(path, "module")
128+
""")
129+
130+
@test_base.skip('Not yet implemented')
131+
def test_errors(self):
132+
self.CheckWithErrors("""
133+
import nonsense # import-error
134+
import os.nonsense # import-error
135+
from os import nonsense # import-error
136+
""")
137+
138+
def test_aliases(self):
139+
self.Check("""
140+
import os as so
141+
assert_type(so.name, "str")
142+
143+
# Not yet implemented:
144+
# import os.path as path1
145+
# from os import path as path2
146+
""")
147+
105148

106149
if __name__ == '__main__':
107150
test_base.main()

0 commit comments

Comments
 (0)