@@ -667,6 +667,22 @@ def set_precedence(self, precedence, *nodes):
667667 for node in nodes :
668668 self ._precedences [node ] = precedence
669669
670+ def get_raw_docstring (self , node ):
671+ """If a docstring node is found in the body of the *node* parameter,
672+ return that docstring node, None otherwise.
673+
674+ Logic mirrored from ``_PyAST_GetDocString``."""
675+ if not isinstance (
676+ node , (AsyncFunctionDef , FunctionDef , ClassDef , Module )
677+ ) or len (node .body ) < 1 :
678+ return None
679+ node = node .body [0 ]
680+ if not isinstance (node , Expr ):
681+ return None
682+ node = node .value
683+ if isinstance (node , Constant ) and isinstance (node .value , str ):
684+ return node
685+
670686 def traverse (self , node ):
671687 if isinstance (node , list ):
672688 for item in node :
@@ -681,9 +697,15 @@ def visit(self, node):
681697 self .traverse (node )
682698 return "" .join (self ._source )
683699
700+ def _write_docstring_and_traverse_body (self , node ):
701+ if (docstring := self .get_raw_docstring (node )):
702+ self ._write_docstring (docstring )
703+ self .traverse (node .body [1 :])
704+ else :
705+ self .traverse (node .body )
706+
684707 def visit_Module (self , node ):
685- for subnode in node .body :
686- self .traverse (subnode )
708+ self ._write_docstring_and_traverse_body (node )
687709
688710 def visit_Expr (self , node ):
689711 self .fill ()
@@ -850,15 +872,15 @@ def visit_ClassDef(self, node):
850872 self .traverse (e )
851873
852874 with self .block ():
853- self .traverse (node . body )
875+ self ._write_docstring_and_traverse_body (node )
854876
855877 def visit_FunctionDef (self , node ):
856- self .__FunctionDef_helper (node , "def" )
878+ self ._function_helper (node , "def" )
857879
858880 def visit_AsyncFunctionDef (self , node ):
859- self .__FunctionDef_helper (node , "async def" )
881+ self ._function_helper (node , "async def" )
860882
861- def __FunctionDef_helper (self , node , fill_suffix ):
883+ def _function_helper (self , node , fill_suffix ):
862884 self .write ("\n " )
863885 for deco in node .decorator_list :
864886 self .fill ("@" )
@@ -871,15 +893,15 @@ def __FunctionDef_helper(self, node, fill_suffix):
871893 self .write (" -> " )
872894 self .traverse (node .returns )
873895 with self .block ():
874- self .traverse (node . body )
896+ self ._write_docstring_and_traverse_body (node )
875897
876898 def visit_For (self , node ):
877- self .__For_helper ("for " , node )
899+ self ._for_helper ("for " , node )
878900
879901 def visit_AsyncFor (self , node ):
880- self .__For_helper ("async for " , node )
902+ self ._for_helper ("async for " , node )
881903
882- def __For_helper (self , fill , node ):
904+ def _for_helper (self , fill , node ):
883905 self .fill (fill )
884906 self .traverse (node .target )
885907 self .write (" in " )
@@ -974,6 +996,19 @@ def _fstring_FormattedValue(self, node, write):
974996 def visit_Name (self , node ):
975997 self .write (node .id )
976998
999+ def _write_docstring (self , node ):
1000+ self .fill ()
1001+ if node .kind == "u" :
1002+ self .write ("u" )
1003+
1004+ # Preserve quotes in the docstring by escaping them
1005+ value = node .value .replace ("\\ " , "\\ \\ " )
1006+ value = value .replace ('"""' , '""\" ' )
1007+ if value [- 1 ] == '"' :
1008+ value = value .replace ('"' , '\\ "' , - 1 )
1009+
1010+ self .write (f'"""{ value } """' )
1011+
9771012 def _write_constant (self , value ):
9781013 if isinstance (value , (float , complex )):
9791014 # Substitute overflowing decimal literal for AST infinities.
0 commit comments