Skip to content

Commit c80e7ef

Browse files
anandoleecopybara-github
authored andcommitted
Soft deprecate python MessageFactory
Soft deprecate python MessageFactory. Added new replacement APIs GetMessageClass(descriptor) and GetMessagesFromFiles(files, pool) PiperOrigin-RevId: 501802633
1 parent f95aafd commit c80e7ef

9 files changed

Lines changed: 167 additions & 103 deletions

File tree

python/google/protobuf/internal/decoder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -806,8 +806,7 @@ def DecodeItem(buffer, pos, end, message, field_dict):
806806
if value is None:
807807
message_type = extension.message_type
808808
if not hasattr(message_type, '_concrete_class'):
809-
# pylint: disable=protected-access
810-
message._FACTORY.GetPrototype(message_type)
809+
message_factory.GetMessageClass(message_type)
811810
value = field_dict.setdefault(
812811
extension, message_type._concrete_class())
813812
if value._InternalParse(buffer, message_start,message_end) != message_end:

python/google/protobuf/internal/extension_dict.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,9 @@ def __getitem__(self, extension_handle):
8989
elif extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
9090
message_type = extension_handle.message_type
9191
if not hasattr(message_type, '_concrete_class'):
92-
# pylint: disable=protected-access
93-
self._extended_message._FACTORY.GetPrototype(message_type)
92+
# pylint: disable=g-import-not-at-top
93+
from google.protobuf import message_factory
94+
message_factory.GetMessageClass(message_type)
9495
assert getattr(extension_handle.message_type, '_concrete_class', None), (
9596
'Uninitialized concrete class found for field %r (message type %r)'
9697
% (extension_handle.full_name,

python/google/protobuf/internal/message_factory_test.py

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -92,36 +92,17 @@ def testGetPrototype(self):
9292
pool = descriptor_pool.DescriptorPool(db)
9393
db.Add(self.factory_test1_fd)
9494
db.Add(self.factory_test2_fd)
95-
factory = message_factory.MessageFactory()
96-
cls = factory.GetPrototype(pool.FindMessageTypeByName(
95+
cls = message_factory.GetMessageClass(pool.FindMessageTypeByName(
9796
'google.protobuf.python.internal.Factory2Message'))
9897
self.assertFalse(cls is factory_test2_pb2.Factory2Message)
9998
self._ExerciseDynamicClass(cls)
100-
cls2 = factory.GetPrototype(pool.FindMessageTypeByName(
99+
cls2 = message_factory.GetMessageClass(pool.FindMessageTypeByName(
101100
'google.protobuf.python.internal.Factory2Message'))
102101
self.assertTrue(cls is cls2)
103102

104-
def testCreatePrototypeOverride(self):
105-
class MyMessageFactory(message_factory.MessageFactory):
106-
107-
def CreatePrototype(self, descriptor):
108-
cls = super(MyMessageFactory, self).CreatePrototype(descriptor)
109-
cls.additional_field = 'Some value'
110-
return cls
111-
112-
db = descriptor_database.DescriptorDatabase()
113-
pool = descriptor_pool.DescriptorPool(db)
114-
db.Add(self.factory_test1_fd)
115-
db.Add(self.factory_test2_fd)
116-
factory = MyMessageFactory()
117-
cls = factory.GetPrototype(pool.FindMessageTypeByName(
118-
'google.protobuf.python.internal.Factory2Message'))
119-
self.assertTrue(hasattr(cls, 'additional_field'))
120-
121103
def testGetExistingPrototype(self):
122-
factory = message_factory.MessageFactory()
123104
# Get Existing Prototype should not create a new class.
124-
cls = factory.GetPrototype(
105+
cls = message_factory.GetMessageClass(
125106
descriptor=factory_test2_pb2.Factory2Message.DESCRIPTOR)
126107
msg = factory_test2_pb2.Factory2Message()
127108
self.assertIsInstance(msg, cls)
@@ -181,15 +162,14 @@ def testGetMessages(self):
181162

182163
def testDuplicateExtensionNumber(self):
183164
pool = descriptor_pool.DescriptorPool()
184-
factory = message_factory.MessageFactory(pool=pool)
185165

186166
# Add Container message.
187167
f = descriptor_pb2.FileDescriptorProto(
188168
name='google/protobuf/internal/container.proto',
189169
package='google.protobuf.python.internal')
190170
f.message_type.add(name='Container').extension_range.add(start=1, end=10)
191171
pool.Add(f)
192-
msgs = factory.GetMessages([f.name])
172+
msgs = message_factory.GetMessageClassesForFiles([f.name], pool)
193173
self.assertIn('google.protobuf.python.internal.Container', msgs)
194174

195175
# Extend container.
@@ -205,7 +185,7 @@ def testDuplicateExtensionNumber(self):
205185
type_name='Extension',
206186
extendee='Container')
207187
pool.Add(f)
208-
msgs = factory.GetMessages([f.name])
188+
msgs = message_factory.GetMessageClassesForFiles([f.name], pool)
209189
self.assertIn('google.protobuf.python.internal.Extension', msgs)
210190

211191
# Add Duplicate extending the same field number.
@@ -223,7 +203,7 @@ def testDuplicateExtensionNumber(self):
223203
pool.Add(f)
224204

225205
with self.assertRaises(Exception) as cm:
226-
factory.GetMessages([f.name])
206+
message_factory.GetMessageClassesForFiles([f.name], pool)
227207

228208
self.assertIn(str(cm.exception),
229209
['Extensions '
@@ -281,8 +261,8 @@ def FindFileByName(self, name):
281261
db = SimpleDescriptorDB({f1.name: f1, f2.name: f2, f3.name: f3})
282262

283263
pool = descriptor_pool.DescriptorPool(db)
284-
factory = message_factory.MessageFactory(pool=pool)
285-
msgs = factory.GetMessages([f1.name, f3.name]) # Deliberately not f2.
264+
msgs = message_factory.GetMessageClassesForFiles(
265+
[f1.name, f3.name], pool) # Deliberately not f2.
286266
msg = msgs['google.protobuf.python.internal.Container']
287267
desc = msgs['google.protobuf.python.internal.Extension'].DESCRIPTOR
288268
ext1 = desc.file.extensions_by_name['top_level_extension_field']
@@ -293,8 +273,8 @@ def FindFileByName(self, name):
293273
serialized = m.SerializeToString()
294274

295275
pool = descriptor_pool.DescriptorPool(db)
296-
factory = message_factory.MessageFactory(pool=pool)
297-
msgs = factory.GetMessages([f1.name, f3.name]) # Deliberately not f2.
276+
msgs = message_factory.GetMessageClassesForFiles(
277+
[f1.name, f3.name], pool) # Deliberately not f2.
298278
msg = msgs['google.protobuf.python.internal.Container']
299279
desc = msgs['google.protobuf.python.internal.Extension'].DESCRIPTOR
300280
ext1 = desc.file.extensions_by_name['top_level_extension_field']

python/google/protobuf/json_format.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353

5454
from google.protobuf.internal import type_checkers
5555
from google.protobuf import descriptor
56+
from google.protobuf import message_factory
5657
from google.protobuf import symbol_database
5758

5859

@@ -409,7 +410,7 @@ def _CreateMessageFromTypeUrl(type_url, descriptor_pool):
409410
raise TypeError(
410411
'Can not find message descriptor by type_url: {0}'.format(type_url)
411412
) from e
412-
message_class = db.GetPrototype(message_descriptor)
413+
message_class = message_factory.GetMessageClass(message_descriptor)
413414
return message_class()
414415

415416

python/google/protobuf/message_factory.py

Lines changed: 114 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939

4040
__author__ = '[email protected] (Matt Toia)'
4141

42+
import warnings
43+
4244
from google.protobuf.internal import api_implementation
4345
from google.protobuf import descriptor_pool
4446
from google.protobuf import message
@@ -53,6 +55,95 @@
5355
_GENERATED_PROTOCOL_MESSAGE_TYPE = message_impl.GeneratedProtocolMessageType
5456

5557

58+
def GetMessageClass(descriptor):
59+
"""Obtains a proto2 message class based on the passed in descriptor.
60+
61+
Passing a descriptor with a fully qualified name matching a previous
62+
invocation will cause the same class to be returned.
63+
64+
Args:
65+
descriptor: The descriptor to build from.
66+
67+
Returns:
68+
A class describing the passed in descriptor.
69+
"""
70+
concrete_class = getattr(descriptor, '_concrete_class', None)
71+
if concrete_class:
72+
return concrete_class
73+
return _InternalCreateMessageClass(descriptor)
74+
75+
76+
def GetMessageClassesForFiles(files, pool):
77+
"""Gets all the messages from specified files.
78+
79+
This will find and resolve dependencies, failing if the descriptor
80+
pool cannot satisfy them.
81+
82+
Args:
83+
files: The file names to extract messages from.
84+
pool: The descriptor pool to find the files including the dependent
85+
files.
86+
87+
Returns:
88+
A dictionary mapping proto names to the message classes.
89+
"""
90+
result = {}
91+
for file_name in files:
92+
file_desc = pool.FindFileByName(file_name)
93+
for desc in file_desc.message_types_by_name.values():
94+
result[desc.full_name] = GetMessageClass(desc)
95+
96+
# While the extension FieldDescriptors are created by the descriptor pool,
97+
# the python classes created in the factory need them to be registered
98+
# explicitly, which is done below.
99+
#
100+
# The call to RegisterExtension will specifically check if the
101+
# extension was already registered on the object and either
102+
# ignore the registration if the original was the same, or raise
103+
# an error if they were different.
104+
105+
for extension in file_desc.extensions_by_name.values():
106+
extended_class = GetMessageClass(extension.containing_type)
107+
extended_class.RegisterExtension(extension)
108+
# Recursively load protos for extension field, in order to be able to
109+
# fully represent the extension. This matches the behavior for regular
110+
# fields too.
111+
if extension.message_type:
112+
GetMessageClass(extension.message_type)
113+
return result
114+
115+
116+
def _InternalCreateMessageClass(descriptor):
117+
"""Builds a proto2 message class based on the passed in descriptor.
118+
119+
Args:
120+
descriptor: The descriptor to build from.
121+
122+
Returns:
123+
A class describing the passed in descriptor.
124+
"""
125+
descriptor_name = descriptor.name
126+
result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE(
127+
descriptor_name,
128+
(message.Message,),
129+
{
130+
'DESCRIPTOR': descriptor,
131+
# If module not set, it wrongly points to message_factory module.
132+
'__module__': None,
133+
})
134+
for field in descriptor.fields:
135+
if field.message_type:
136+
GetMessageClass(field.message_type)
137+
for extension in result_class.DESCRIPTOR.extensions:
138+
extended_class = GetMessageClass(extension.containing_type)
139+
extended_class.RegisterExtension(extension)
140+
if extension.message_type:
141+
GetMessageClass(extension.message_type)
142+
return result_class
143+
144+
145+
# Deprecated. Please use GetMessageClass() or GetMessageClassesForFiles()
146+
# method above instead.
56147
class MessageFactory(object):
57148
"""Factory for creating Proto2 messages from descriptors in a pool."""
58149

@@ -72,44 +163,29 @@ def GetPrototype(self, descriptor):
72163
Returns:
73164
A class describing the passed in descriptor.
74165
"""
75-
concrete_class = getattr(descriptor, '_concrete_class', None)
76-
if concrete_class:
77-
return concrete_class
78-
result_class = self.CreatePrototype(descriptor)
79-
return result_class
166+
# TODO(b/258832141): add this warning
167+
# warnings.warn('MessageFactory class is deprecated. Please use '
168+
# 'GetMessageClass() instead of MessageFactory.GetPrototype. '
169+
# 'MessageFactory class will be removed after 2024.')
170+
return GetMessageClass(descriptor)
80171

81172
def CreatePrototype(self, descriptor):
82173
"""Builds a proto2 message class based on the passed in descriptor.
83174
84175
Don't call this function directly, it always creates a new class. Call
85-
GetPrototype() instead. This method is meant to be overridden in subblasses
86-
to perform additional operations on the newly constructed class.
176+
GetMessageClass() instead.
87177
88178
Args:
89179
descriptor: The descriptor to build from.
90180
91181
Returns:
92182
A class describing the passed in descriptor.
93183
"""
94-
descriptor_name = descriptor.name
95-
result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE(
96-
descriptor_name,
97-
(message.Message,),
98-
{
99-
'DESCRIPTOR': descriptor,
100-
# If module not set, it wrongly points to message_factory module.
101-
'__module__': None,
102-
})
103-
result_class._FACTORY = self # pylint: disable=protected-access
104-
for field in descriptor.fields:
105-
if field.message_type:
106-
self.GetPrototype(field.message_type)
107-
for extension in result_class.DESCRIPTOR.extensions:
108-
extended_class = self.GetPrototype(extension.containing_type)
109-
extended_class.RegisterExtension(extension)
110-
if extension.message_type:
111-
self.GetPrototype(extension.message_type)
112-
return result_class
184+
# TODO(b/258832141): add this warning
185+
# warnings.warn('Directly call CreatePrototype is wrong. Please use '
186+
# 'GetMessageClass() method instead. Directly use '
187+
# 'CreatePrototype will raise error after July 2023.')
188+
return _InternalCreateMessageClass(descriptor)
113189

114190
def GetMessages(self, files):
115191
"""Gets all the messages from a specified file.
@@ -125,37 +201,20 @@ def GetMessages(self, files):
125201
any dependent messages as well as any messages defined in the same file as
126202
a specified message.
127203
"""
128-
result = {}
129-
for file_name in files:
130-
file_desc = self.pool.FindFileByName(file_name)
131-
for desc in file_desc.message_types_by_name.values():
132-
result[desc.full_name] = self.GetPrototype(desc)
133-
134-
# While the extension FieldDescriptors are created by the descriptor pool,
135-
# the python classes created in the factory need them to be registered
136-
# explicitly, which is done below.
137-
#
138-
# The call to RegisterExtension will specifically check if the
139-
# extension was already registered on the object and either
140-
# ignore the registration if the original was the same, or raise
141-
# an error if they were different.
142-
143-
for extension in file_desc.extensions_by_name.values():
144-
extended_class = self.GetPrototype(extension.containing_type)
145-
extended_class.RegisterExtension(extension)
146-
if extension.message_type:
147-
self.GetPrototype(extension.message_type)
148-
return result
149-
150-
151-
_FACTORY = MessageFactory()
204+
# TODO(b/258832141): add this warning
205+
# warnings.warn('MessageFactory class is deprecated. Please use '
206+
# 'GetMessageClassesForFiles() instead of '
207+
# 'MessageFactory.GetMessages(). MessageFactory class '
208+
# 'will be removed after 2024.')
209+
return GetMessageClassesForFiles(files, self.pool)
152210

153211

154-
def GetMessages(file_protos):
212+
def GetMessages(file_protos, pool=None):
155213
"""Builds a dictionary of all the messages available in a set of files.
156214
157215
Args:
158216
file_protos: Iterable of FileDescriptorProto to build messages out of.
217+
pool: The descriptor pool to add the file protos.
159218
160219
Returns:
161220
A dictionary mapping proto names to the message classes. This will include
@@ -164,13 +223,15 @@ def GetMessages(file_protos):
164223
"""
165224
# The cpp implementation of the protocol buffer library requires to add the
166225
# message in topological order of the dependency graph.
226+
des_pool = pool or descriptor_pool.DescriptorPool()
167227
file_by_name = {file_proto.name: file_proto for file_proto in file_protos}
168228
def _AddFile(file_proto):
169229
for dependency in file_proto.dependency:
170230
if dependency in file_by_name:
171231
# Remove from elements to be visited, in order to cut cycles.
172232
_AddFile(file_by_name.pop(dependency))
173-
_FACTORY.pool.Add(file_proto)
233+
des_pool.Add(file_proto)
174234
while file_by_name:
175235
_AddFile(file_by_name.popitem()[1])
176-
return _FACTORY.GetMessages([file_proto.name for file_proto in file_protos])
236+
return GetMessageClassesForFiles(
237+
[file_proto.name for file_proto in file_protos], des_pool)

0 commit comments

Comments
 (0)