Skip to content

Commit d4642d1

Browse files
committed
address PR comments: rework inheritance with a QueryCompatibleProtocol Mixin, rework the __type hack to do a shallow copy
1 parent 8071bb6 commit d4642d1

File tree

1 file changed

+113
-140
lines changed

1 file changed

+113
-140
lines changed

localstack-core/localstack/aws/protocol/serializer.py

Lines changed: 113 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@
9494

9595
import abc
9696
import base64
97-
import copy
9897
import datetime
9998
import functools
10099
import json
@@ -152,44 +151,6 @@
152151
REQUEST_ID_CHARACTERS = string.digits + string.ascii_uppercase
153152

154153

155-
# Some Serializer cannot serialize an exception if it is not defined in our specs.
156-
# LocalStack defines a way to have user-defined exception by subclassing `CommonServiceException`, and those
157-
# serializers needs to be able to encode those, as well as InternalError
158-
# We are creating a default botocore structure shape to be used in such cases.
159-
160-
161-
class DefaultStringShapeResolver(ShapeResolver):
162-
def get_shape_by_name(
163-
self,
164-
shape_name,
165-
member_traits=None,
166-
):
167-
return StringShape(
168-
shape_name=shape_name,
169-
shape_model={"type": "string"},
170-
)
171-
172-
173-
DEFAULT_ERROR_STRUCTURE_SHAPE = StructureShape(
174-
shape_name="DefaultErrorStructure",
175-
shape_model={
176-
"type": "structure",
177-
"members": {
178-
"message": {"shape": "ErrorMessage"},
179-
"__type": {"shape": "ErrorType"},
180-
},
181-
"error": {"code": "DefaultErrorStructure", "httpStatusCode": 400, "senderFault": True},
182-
"exception": True,
183-
},
184-
shape_resolver=ShapeResolver(
185-
shape_map={
186-
"ErrorMessage": {"type": "string"},
187-
"ErrorType": {"type": "string"},
188-
},
189-
),
190-
)
191-
192-
193154
class ResponseSerializerError(Exception):
194155
"""
195156
Error which is thrown if the request serialization fails.
@@ -657,6 +618,13 @@ def _add_md5_header(self, response: Response):
657618
def _get_error_message(self, error: Exception) -> str | None:
658619
return str(error) if error is not None and str(error) != "None" else None
659620

621+
def _get_error_status_code(
622+
self, error: ServiceException, headers: Headers, service_model: ServiceModel
623+
) -> int:
624+
return error.status_code
625+
626+
627+
class QueryCompatibleProtocolMixin:
660628
def _get_error_status_code(
661629
self, error: ServiceException, headers: dict | Headers | None, service_model: ServiceModel
662630
) -> int:
@@ -669,7 +637,7 @@ def _get_error_status_code(
669637
if not service_model.is_query_compatible:
670638
return error.status_code
671639

672-
if self._is_request_query_compatible(headers):
640+
if headers and headers.get("x-amzn-query-mode") == "true":
673641
return error.status_code
674642

675643
# we only want to override status code 4XX
@@ -678,9 +646,6 @@ def _get_error_status_code(
678646

679647
return error.status_code
680648

681-
def _is_request_query_compatible(self, headers: Headers | dict | None) -> bool:
682-
return headers and headers.get("x-amzn-query-mode") == "true"
683-
684649
def _add_query_compatible_error_header(self, response: Response, error: ServiceException):
685650
"""
686651
Add an `x-amzn-query-error` header for client to translate errors codes from former `query` services
@@ -690,6 +655,23 @@ def _add_query_compatible_error_header(self, response: Response, error: ServiceE
690655
sender_fault = "Sender" if error.sender_fault else "Receiver"
691656
response.headers["x-amzn-query-error"] = f"{error.code};{sender_fault}"
692657

658+
def _get_error_code(
659+
self, is_query_compatible: bool, error: ServiceException, shape: Shape | None = None
660+
):
661+
# if the operation is query compatible, we need to add to use shape name
662+
if is_query_compatible:
663+
if shape:
664+
code = shape.name
665+
else:
666+
# if the shape is not defined, we are using the Exception named to derive the `Code`, like you would
667+
# from the shape. This allows us to have Exception that are valid in multi-protocols by defining its
668+
# code and its name to be different
669+
code = error.__class__.__name__
670+
else:
671+
code = error.code
672+
673+
return code
674+
693675

694676
class BaseXMLResponseSerializer(ResponseSerializer):
695677
"""
@@ -1246,11 +1228,6 @@ def _prepare_additional_traits_in_xml(self, root: ETree.Element | None, request_
12461228
request_id_element = ETree.SubElement(response_metadata, "RequestId")
12471229
request_id_element.text = request_id
12481230

1249-
def _get_error_status_code(
1250-
self, error: ServiceException, headers: Headers, service_model: ServiceModel
1251-
) -> int:
1252-
return error.status_code
1253-
12541231

12551232
class EC2ResponseSerializer(QueryResponseSerializer):
12561233
"""
@@ -1308,7 +1285,7 @@ def _prepare_additional_traits_in_xml(self, root: ETree.Element | None, request_
13081285
request_id_element.text = request_id
13091286

13101287

1311-
class JSONResponseSerializer(ResponseSerializer):
1288+
class JSONResponseSerializer(QueryCompatibleProtocolMixin, ResponseSerializer):
13121289
"""
13131290
The ``JSONResponseSerializer`` is responsible for the serialization of responses from services with the ``json``
13141291
protocol. It implements the JSON response body serialization, which is also used by the
@@ -1340,43 +1317,31 @@ def _serialize_error(
13401317
# if json-1.1, it should only be the name
13411318

13421319
is_query_compatible = operation_model.service_model.is_query_compatible
1343-
# if the operation is query compatible, we need to add to use shape name
1344-
if is_query_compatible:
1345-
if shape:
1346-
code = shape.name
1347-
else:
1348-
# if the shape is not defined, we are using the Exception named to derive the `Code`, like you would
1349-
# from the shape. This allows us to have Exception that are valid in multi-protocols by defining its
1350-
# code and its name to be different
1351-
code = error.__class__.__name__
1352-
else:
1353-
code = error.code
1320+
code = self._get_error_code(is_query_compatible, error, shape)
13541321

13551322
response.headers["X-Amzn-Errortype"] = code
13561323

1357-
if not shape:
1358-
shape = DEFAULT_ERROR_STRUCTURE_SHAPE
1359-
13601324
# the `__type` field is not defined in default botocore error shapes
13611325
body["__type"] = code
13621326

1363-
remaining_params = {}
1364-
# TODO add a possibility to serialize simple non-modelled errors (like S3 NoSuchBucket#BucketName)
1365-
for member in shape.members:
1366-
if hasattr(error, member):
1367-
value = getattr(error, member)
1327+
if shape:
1328+
remaining_params = {}
1329+
# TODO add a possibility to serialize simple non-modelled errors (like S3 NoSuchBucket#BucketName)
1330+
for member in shape.members:
1331+
if hasattr(error, member):
1332+
value = getattr(error, member)
13681333

1369-
# Default error message fields can sometimes have different casing in the specs
1370-
elif member.lower() in ["code", "message"] and hasattr(error, member.lower()):
1371-
value = getattr(error, member.lower())
1334+
# Default error message fields can sometimes have different casing in the specs
1335+
elif member.lower() in ["code", "message"] and hasattr(error, member.lower()):
1336+
value = getattr(error, member.lower())
13721337

1373-
else:
1374-
continue
1338+
else:
1339+
continue
13751340

1376-
if value:
1377-
remaining_params[member] = value
1341+
if value:
1342+
remaining_params[member] = value
13781343

1379-
self._serialize(body, remaining_params, shape, None, mime_type)
1344+
self._serialize(body, remaining_params, shape, None, mime_type)
13801345

13811346
# this is a workaround, some Error Shape do not define a `Message` field, but it is always returned
13821347
# this could be solved at the same time as the `__type` field
@@ -1390,7 +1355,7 @@ def _serialize_error(
13901355
else:
13911356
response.set_json(body)
13921357

1393-
if operation_model.service_model.is_query_compatible:
1358+
if is_query_compatible:
13941359
self._add_query_compatible_error_header(response, error)
13951360

13961361
def _serialize_response(
@@ -1565,6 +1530,11 @@ class BaseCBORResponseSerializer(ResponseSerializer):
15651530
required at the end.
15661531
AWS, for both Kinesis and `smithy-rpc-v2-cbor` services, is using indefinite data structures when returning
15671532
responses.
1533+
1534+
The CBOR serializer cannot serialize an exception if it is not defined in our specs.
1535+
LocalStack defines a way to have user-defined exception by subclassing `CommonServiceException`, so it needs to be
1536+
able to encode those, as well as InternalError
1537+
We are creating a default botocore structure shape (`_DEFAULT_ERROR_STRUCTURE_SHAPE`) to be used in such cases.
15681538
"""
15691539

15701540
SUPPORTED_MIME_TYPES = [APPLICATION_CBOR, APPLICATION_AMZ_CBOR_1_1]
@@ -1582,6 +1552,27 @@ class BaseCBORResponseSerializer(ResponseSerializer):
15821552
BREAK_CODE = b"\xff"
15831553
USE_INDEFINITE_DATA_STRUCTURE = True
15841554

1555+
_ERROR_TYPE_SHAPE = StringShape(shape_name="__type", shape_model={"type": "string"})
1556+
1557+
_DEFAULT_ERROR_STRUCTURE_SHAPE = StructureShape(
1558+
shape_name="DefaultErrorStructure",
1559+
shape_model={
1560+
"type": "structure",
1561+
"members": {
1562+
"message": {"shape": "ErrorMessage"},
1563+
"__type": {"shape": "ErrorType"},
1564+
},
1565+
"error": {"code": "DefaultErrorStructure", "httpStatusCode": 400, "senderFault": True},
1566+
"exception": True,
1567+
},
1568+
shape_resolver=ShapeResolver(
1569+
shape_map={
1570+
"ErrorMessage": {"type": "string"},
1571+
"ErrorType": {"type": "string"},
1572+
},
1573+
),
1574+
)
1575+
15851576
def _serialize_data_item(
15861577
self, serialized: bytearray, value: Any, shape: Shape | None, name: str | None = None
15871578
) -> None:
@@ -1679,7 +1670,12 @@ def _serialize_type_map(
16791670
serialized.extend(closing_bytes)
16801671

16811672
def _serialize_type_structure(
1682-
self, serialized: bytearray, value: dict, shape: Shape | None, name: str | None = None
1673+
self,
1674+
serialized: bytearray,
1675+
value: dict,
1676+
shape: Shape | None,
1677+
name: str | None = None,
1678+
shape_members: dict[str, Shape] | None = None,
16831679
) -> None:
16841680
if name is not None:
16851681
# For nested structures, we need to serialize the key first
@@ -1692,8 +1688,7 @@ def _serialize_type_structure(
16921688
value, self.MAP_MAJOR_TYPE
16931689
)
16941690
serialized.extend(initial_bytes)
1695-
1696-
members = shape.members
1691+
members = shape_members or shape.members
16971692
for member_key, member_value in value.items():
16981693
member_shape = members[member_key]
16991694
if "name" in member_shape.serialization:
@@ -1818,6 +1813,38 @@ def _get_bytes_for_data_structure(
18181813

18191814
return initial_byte, None
18201815

1816+
def _serialize_error_structure(
1817+
self, body: bytearray, shape: Shape | None, error: ServiceException, code: str
1818+
):
1819+
if not shape:
1820+
shape = self._DEFAULT_ERROR_STRUCTURE_SHAPE
1821+
shape_members = shape.members
1822+
else:
1823+
# we need to manually add the `__type` field to the shape members as it is not part of the specs
1824+
# we do a shallow copy of the shape members
1825+
shape_members = shape.members.copy()
1826+
shape_members["__type"] = self._ERROR_TYPE_SHAPE
1827+
1828+
# Error responses in the rpcv2Cbor protocol MUST be serialized identically to standard responses with one
1829+
# additional component to distinguish which error is contained: a body field named __type.
1830+
params = {"__type": code}
1831+
1832+
for member in shape_members:
1833+
if hasattr(error, member):
1834+
value = getattr(error, member)
1835+
1836+
# Default error message fields can sometimes have different casing in the specs
1837+
elif member.lower() in ["code", "message"] and hasattr(error, member.lower()):
1838+
value = getattr(error, member.lower())
1839+
1840+
else:
1841+
continue
1842+
1843+
if value:
1844+
params[member] = value
1845+
1846+
self._serialize_type_structure(body, params, shape, None, shape_members=shape_members)
1847+
18211848

18221849
class CBORResponseSerializer(BaseCBORResponseSerializer):
18231850
"""
@@ -1841,25 +1868,7 @@ def _serialize_error(
18411868
response.content_type = mime_type
18421869
response.headers["X-Amzn-Errortype"] = error.code
18431870

1844-
if shape:
1845-
# FIXME: we need to manually add the `__type` field to the shape as it is not part of the specs
1846-
# think about a better way, this is very hacky
1847-
shape_copy = copy.deepcopy(shape)
1848-
shape_copy.members["__type"] = StringShape(
1849-
shape_name="__type", shape_model={"type": "string"}
1850-
)
1851-
remaining_params = {"__type": error.code}
1852-
1853-
for member_name in shape_copy.members:
1854-
if hasattr(error, member_name):
1855-
remaining_params[member_name] = getattr(error, member_name)
1856-
# Default error message fields can sometimes have different casing in the specs
1857-
elif member_name.lower() in ["code", "message"] and hasattr(
1858-
error, member_name.lower()
1859-
):
1860-
remaining_params[member_name] = getattr(error, member_name.lower())
1861-
1862-
self._serialize_data_item(body, remaining_params, shape_copy, None)
1871+
self._serialize_error_structure(body, shape, error, code=error.code)
18631872

18641873
response.set_response(bytes(body))
18651874

@@ -1935,7 +1944,9 @@ def _serialize_body_params(
19351944
raise NotImplementedError
19361945

19371946

1938-
class RpcV2CBORResponseSerializer(BaseRpcV2ResponseSerializer, BaseCBORResponseSerializer):
1947+
class RpcV2CBORResponseSerializer(
1948+
QueryCompatibleProtocolMixin, BaseRpcV2ResponseSerializer, BaseCBORResponseSerializer
1949+
):
19391950
"""
19401951
The RpcV2CBORResponseSerializer implements the CBOR body serialization part for the RPC v2 protocol, and implements the
19411952
specific exception serialization.
@@ -1975,47 +1986,9 @@ def _serialize_error(
19751986
# Responses for the rpcv2Cbor protocol SHOULD NOT contain the X-Amzn-ErrorType header.
19761987
# Type information is always serialized in the payload. This is different from the `json` protocol
19771988
is_query_compatible = operation_model.service_model.is_query_compatible
1978-
# if the operation is query compatible, we need to add to use shape name
1979-
if is_query_compatible:
1980-
if shape:
1981-
code = shape.name
1982-
else:
1983-
# if the shape is not defined, we are using the Exception named to derive the `Code`, like you would
1984-
# from the shape. This allows us to have Exception that are valid in multi-protocols by defining its
1985-
# code and its name to be different
1986-
code = error.__class__.__name__
1987-
else:
1988-
code = error.code
1989-
1990-
if not shape:
1991-
shape_copy = DEFAULT_ERROR_STRUCTURE_SHAPE
1992-
else:
1993-
# FIXME: we need to manually add the `__type` field to the shape as it is not part of the specs
1994-
# think about a better way, this is very hacky
1995-
shape_copy = copy.deepcopy(shape)
1996-
shape_copy.members["__type"] = StringShape(
1997-
shape_name="__type", shape_model={"type": "string"}
1998-
)
1999-
2000-
# Error responses in the rpcv2Cbor protocol MUST be serialized identically to standard responses with one
2001-
# additional component to distinguish which error is contained: a body field named __type.
2002-
remaining_params = {"__type": code}
2003-
2004-
for member in shape_copy.members:
2005-
if hasattr(error, member):
2006-
value = getattr(error, member)
2007-
2008-
# Default error message fields can sometimes have different casing in the specs
2009-
elif member.lower() in ["code", "message"] and hasattr(error, member.lower()):
2010-
value = getattr(error, member.lower())
2011-
2012-
else:
2013-
continue
2014-
2015-
if value:
2016-
remaining_params[member] = value
1989+
code = self._get_error_code(is_query_compatible, error, shape)
20171990

2018-
self._serialize_data_item(body, remaining_params, shape_copy, None)
1991+
self._serialize_error_structure(body, shape, error, code=code)
20191992

20201993
response.set_response(bytes(body))
20211994

0 commit comments

Comments
 (0)