3939
4040__author__ = '[email protected] (Matt Toia)' 4141
42+ import warnings
43+
4244from google .protobuf .internal import api_implementation
4345from google .protobuf import descriptor_pool
4446from google .protobuf import message
5355_GENERATED_PROTOCOL_MESSAGE_TYPE = message_impl .GeneratedProtocolMessageType
5456
5557
58+ def GetMessageClass (descriptor ):
59+ """Obtains a proto2 message class based on the passed in descriptor.
60+
61+ Passing a descriptor with a fully qualified name matching a previous
62+ invocation will cause the same class to be returned.
63+
64+ Args:
65+ descriptor: The descriptor to build from.
66+
67+ Returns:
68+ A class describing the passed in descriptor.
69+ """
70+ concrete_class = getattr (descriptor , '_concrete_class' , None )
71+ if concrete_class :
72+ return concrete_class
73+ return _InternalCreateMessageClass (descriptor )
74+
75+
76+ def GetMessageClassesForFiles (files , pool ):
77+ """Gets all the messages from specified files.
78+
79+ This will find and resolve dependencies, failing if the descriptor
80+ pool cannot satisfy them.
81+
82+ Args:
83+ files: The file names to extract messages from.
84+ pool: The descriptor pool to find the files including the dependent
85+ files.
86+
87+ Returns:
88+ A dictionary mapping proto names to the message classes.
89+ """
90+ result = {}
91+ for file_name in files :
92+ file_desc = pool .FindFileByName (file_name )
93+ for desc in file_desc .message_types_by_name .values ():
94+ result [desc .full_name ] = GetMessageClass (desc )
95+
96+ # While the extension FieldDescriptors are created by the descriptor pool,
97+ # the python classes created in the factory need them to be registered
98+ # explicitly, which is done below.
99+ #
100+ # The call to RegisterExtension will specifically check if the
101+ # extension was already registered on the object and either
102+ # ignore the registration if the original was the same, or raise
103+ # an error if they were different.
104+
105+ for extension in file_desc .extensions_by_name .values ():
106+ extended_class = GetMessageClass (extension .containing_type )
107+ extended_class .RegisterExtension (extension )
108+ # Recursively load protos for extension field, in order to be able to
109+ # fully represent the extension. This matches the behavior for regular
110+ # fields too.
111+ if extension .message_type :
112+ GetMessageClass (extension .message_type )
113+ return result
114+
115+
116+ def _InternalCreateMessageClass (descriptor ):
117+ """Builds a proto2 message class based on the passed in descriptor.
118+
119+ Args:
120+ descriptor: The descriptor to build from.
121+
122+ Returns:
123+ A class describing the passed in descriptor.
124+ """
125+ descriptor_name = descriptor .name
126+ result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE (
127+ descriptor_name ,
128+ (message .Message ,),
129+ {
130+ 'DESCRIPTOR' : descriptor ,
131+ # If module not set, it wrongly points to message_factory module.
132+ '__module__' : None ,
133+ })
134+ for field in descriptor .fields :
135+ if field .message_type :
136+ GetMessageClass (field .message_type )
137+ for extension in result_class .DESCRIPTOR .extensions :
138+ extended_class = GetMessageClass (extension .containing_type )
139+ extended_class .RegisterExtension (extension )
140+ if extension .message_type :
141+ GetMessageClass (extension .message_type )
142+ return result_class
143+
144+
145+ # Deprecated. Please use GetMessageClass() or GetMessageClassesForFiles()
146+ # method above instead.
56147class MessageFactory (object ):
57148 """Factory for creating Proto2 messages from descriptors in a pool."""
58149
@@ -72,44 +163,29 @@ def GetPrototype(self, descriptor):
72163 Returns:
73164 A class describing the passed in descriptor.
74165 """
75- concrete_class = getattr ( descriptor , '_concrete_class' , None )
76- if concrete_class :
77- return concrete_class
78- result_class = self . CreatePrototype ( descriptor )
79- return result_class
166+ # TODO(b/258832141): add this warning
167+ # warnings.warn('MessageFactory class is deprecated. Please use '
168+ # 'GetMessageClass() instead of MessageFactory.GetPrototype. '
169+ # 'MessageFactory class will be removed after 2024.' )
170+ return GetMessageClass ( descriptor )
80171
81172 def CreatePrototype (self , descriptor ):
82173 """Builds a proto2 message class based on the passed in descriptor.
83174
84175 Don't call this function directly, it always creates a new class. Call
85- GetPrototype() instead. This method is meant to be overridden in subblasses
86- to perform additional operations on the newly constructed class.
176+ GetMessageClass() instead.
87177
88178 Args:
89179 descriptor: The descriptor to build from.
90180
91181 Returns:
92182 A class describing the passed in descriptor.
93183 """
94- descriptor_name = descriptor .name
95- result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE (
96- descriptor_name ,
97- (message .Message ,),
98- {
99- 'DESCRIPTOR' : descriptor ,
100- # If module not set, it wrongly points to message_factory module.
101- '__module__' : None ,
102- })
103- result_class ._FACTORY = self # pylint: disable=protected-access
104- for field in descriptor .fields :
105- if field .message_type :
106- self .GetPrototype (field .message_type )
107- for extension in result_class .DESCRIPTOR .extensions :
108- extended_class = self .GetPrototype (extension .containing_type )
109- extended_class .RegisterExtension (extension )
110- if extension .message_type :
111- self .GetPrototype (extension .message_type )
112- return result_class
184+ # TODO(b/258832141): add this warning
185+ # warnings.warn('Directly call CreatePrototype is wrong. Please use '
186+ # 'GetMessageClass() method instead. Directly use '
187+ # 'CreatePrototype will raise error after July 2023.')
188+ return _InternalCreateMessageClass (descriptor )
113189
114190 def GetMessages (self , files ):
115191 """Gets all the messages from a specified file.
@@ -125,37 +201,20 @@ def GetMessages(self, files):
125201 any dependent messages as well as any messages defined in the same file as
126202 a specified message.
127203 """
128- result = {}
129- for file_name in files :
130- file_desc = self .pool .FindFileByName (file_name )
131- for desc in file_desc .message_types_by_name .values ():
132- result [desc .full_name ] = self .GetPrototype (desc )
133-
134- # While the extension FieldDescriptors are created by the descriptor pool,
135- # the python classes created in the factory need them to be registered
136- # explicitly, which is done below.
137- #
138- # The call to RegisterExtension will specifically check if the
139- # extension was already registered on the object and either
140- # ignore the registration if the original was the same, or raise
141- # an error if they were different.
142-
143- for extension in file_desc .extensions_by_name .values ():
144- extended_class = self .GetPrototype (extension .containing_type )
145- extended_class .RegisterExtension (extension )
146- if extension .message_type :
147- self .GetPrototype (extension .message_type )
148- return result
149-
150-
151- _FACTORY = MessageFactory ()
204+ # TODO(b/258832141): add this warning
205+ # warnings.warn('MessageFactory class is deprecated. Please use '
206+ # 'GetMessageClassesForFiles() instead of '
207+ # 'MessageFactory.GetMessages(). MessageFactory class '
208+ # 'will be removed after 2024.')
209+ return GetMessageClassesForFiles (files , self .pool )
152210
153211
154- def GetMessages (file_protos ):
212+ def GetMessages (file_protos , pool = None ):
155213 """Builds a dictionary of all the messages available in a set of files.
156214
157215 Args:
158216 file_protos: Iterable of FileDescriptorProto to build messages out of.
217+ pool: The descriptor pool to add the file protos.
159218
160219 Returns:
161220 A dictionary mapping proto names to the message classes. This will include
@@ -164,13 +223,15 @@ def GetMessages(file_protos):
164223 """
165224 # The cpp implementation of the protocol buffer library requires to add the
166225 # message in topological order of the dependency graph.
226+ des_pool = pool or descriptor_pool .DescriptorPool ()
167227 file_by_name = {file_proto .name : file_proto for file_proto in file_protos }
168228 def _AddFile (file_proto ):
169229 for dependency in file_proto .dependency :
170230 if dependency in file_by_name :
171231 # Remove from elements to be visited, in order to cut cycles.
172232 _AddFile (file_by_name .pop (dependency ))
173- _FACTORY . pool .Add (file_proto )
233+ des_pool .Add (file_proto )
174234 while file_by_name :
175235 _AddFile (file_by_name .popitem ()[1 ])
176- return _FACTORY .GetMessages ([file_proto .name for file_proto in file_protos ])
236+ return GetMessageClassesForFiles (
237+ [file_proto .name for file_proto in file_protos ], des_pool )
0 commit comments