Skip to content

Commit dff92bb

Browse files
authored
bpo-38870: Implement round tripping support for typed AST in ast.unparse (GH-17797)
1 parent e966af7 commit dff92bb

File tree

2 files changed

+56
-8
lines changed

2 files changed

+56
-8
lines changed

Lib/ast.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,7 @@ def __init__(self):
648648
self._source = []
649649
self._buffer = []
650650
self._precedences = {}
651+
self._type_ignores = {}
651652
self._indent = 0
652653

653654
def interleave(self, inter, f, seq):
@@ -697,11 +698,15 @@ def buffer(self):
697698
return value
698699

699700
@contextmanager
700-
def block(self):
701+
def block(self, *, extra = None):
701702
"""A context manager for preparing the source for blocks. It adds
702703
the character':', increases the indentation on enter and decreases
703-
the indentation on exit."""
704+
the indentation on exit. If *extra* is given, it will be directly
705+
appended after the colon character.
706+
"""
704707
self.write(":")
708+
if extra:
709+
self.write(extra)
705710
self._indent += 1
706711
yield
707712
self._indent -= 1
@@ -748,6 +753,11 @@ def get_raw_docstring(self, node):
748753
if isinstance(node, Constant) and isinstance(node.value, str):
749754
return node
750755

756+
def get_type_comment(self, node):
757+
comment = self._type_ignores.get(node.lineno) or node.type_comment
758+
if comment is not None:
759+
return f" # type: {comment}"
760+
751761
def traverse(self, node):
752762
if isinstance(node, list):
753763
for item in node:
@@ -770,7 +780,12 @@ def _write_docstring_and_traverse_body(self, node):
770780
self.traverse(node.body)
771781

772782
def visit_Module(self, node):
783+
self._type_ignores = {
784+
ignore.lineno: f"ignore{ignore.tag}"
785+
for ignore in node.type_ignores
786+
}
773787
self._write_docstring_and_traverse_body(node)
788+
self._type_ignores.clear()
774789

775790
def visit_FunctionType(self, node):
776791
with self.delimit("(", ")"):
@@ -811,6 +826,8 @@ def visit_Assign(self, node):
811826
self.traverse(target)
812827
self.write(" = ")
813828
self.traverse(node.value)
829+
if type_comment := self.get_type_comment(node):
830+
self.write(type_comment)
814831

815832
def visit_AugAssign(self, node):
816833
self.fill()
@@ -966,7 +983,7 @@ def _function_helper(self, node, fill_suffix):
966983
if node.returns:
967984
self.write(" -> ")
968985
self.traverse(node.returns)
969-
with self.block():
986+
with self.block(extra=self.get_type_comment(node)):
970987
self._write_docstring_and_traverse_body(node)
971988

972989
def visit_For(self, node):
@@ -980,7 +997,7 @@ def _for_helper(self, fill, node):
980997
self.traverse(node.target)
981998
self.write(" in ")
982999
self.traverse(node.iter)
983-
with self.block():
1000+
with self.block(extra=self.get_type_comment(node)):
9841001
self.traverse(node.body)
9851002
if node.orelse:
9861003
self.fill("else")
@@ -1018,13 +1035,13 @@ def visit_While(self, node):
10181035
def visit_With(self, node):
10191036
self.fill("with ")
10201037
self.interleave(lambda: self.write(", "), self.traverse, node.items)
1021-
with self.block():
1038+
with self.block(extra=self.get_type_comment(node)):
10221039
self.traverse(node.body)
10231040

10241041
def visit_AsyncWith(self, node):
10251042
self.fill("async with ")
10261043
self.interleave(lambda: self.write(", "), self.traverse, node.items)
1027-
with self.block():
1044+
with self.block(extra=self.get_type_comment(node)):
10281045
self.traverse(node.body)
10291046

10301047
def visit_JoinedStr(self, node):

Lib/test/test_unparse.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,12 @@ class Foo: pass
108108
suite1
109109
"""
110110

111-
docstring_prefixes = [
111+
docstring_prefixes = (
112112
"",
113113
"class foo:\n ",
114114
"def foo():\n ",
115115
"async def foo():\n ",
116-
]
116+
)
117117

118118
class ASTTestCase(unittest.TestCase):
119119
def assertASTEqual(self, ast1, ast2):
@@ -340,6 +340,37 @@ def test_function_type(self):
340340
):
341341
self.check_ast_roundtrip(function_type, mode="func_type")
342342

343+
def test_type_comments(self):
344+
for statement in (
345+
"a = 5 # type:",
346+
"a = 5 # type: int",
347+
"a = 5 # type: int and more",
348+
"def x(): # type: () -> None\n\tpass",
349+
"def x(y): # type: (int) -> None and more\n\tpass",
350+
"async def x(): # type: () -> None\n\tpass",
351+
"async def x(y): # type: (int) -> None and more\n\tpass",
352+
"for x in y: # type: int\n\tpass",
353+
"async for x in y: # type: int\n\tpass",
354+
"with x(): # type: int\n\tpass",
355+
"async with x(): # type: int\n\tpass"
356+
):
357+
self.check_ast_roundtrip(statement, type_comments=True)
358+
359+
def test_type_ignore(self):
360+
for statement in (
361+
"a = 5 # type: ignore",
362+
"a = 5 # type: ignore and more",
363+
"def x(): # type: ignore\n\tpass",
364+
"def x(y): # type: ignore and more\n\tpass",
365+
"async def x(): # type: ignore\n\tpass",
366+
"async def x(y): # type: ignore and more\n\tpass",
367+
"for x in y: # type: ignore\n\tpass",
368+
"async for x in y: # type: ignore\n\tpass",
369+
"with x(): # type: ignore\n\tpass",
370+
"async with x(): # type: ignore\n\tpass"
371+
):
372+
self.check_ast_roundtrip(statement, type_comments=True)
373+
343374

344375
class CosmeticTestCase(ASTTestCase):
345376
"""Test if there are cosmetic issues caused by unnecesary additions"""

0 commit comments

Comments
 (0)