Skip to content

Commit 66ab28b

Browse files
baermatbentsku
andauthored
Sns:v2 platform endpoint operations (#13327)
Co-authored-by: Benjamin Simon <[email protected]>
1 parent 37b23d0 commit 66ab28b

File tree

6 files changed

+1437
-26
lines changed

6 files changed

+1437
-26
lines changed

localstack-core/localstack/services/sns/v2/models.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Literal, TypedDict
66

77
from localstack.aws.api.sns import (
8+
Endpoint,
89
MessageAttributeMap,
910
PlatformApplication,
1011
PublishBatchRequestEntry,
@@ -39,6 +40,12 @@ class Topic(TypedDict, total=True):
3940
]
4041

4142

43+
class EndpointAttributeNames(StrEnum):
44+
CUSTOM_USER_DATA = "CustomUserData"
45+
Token = "Token"
46+
ENABLED = "Enabled"
47+
48+
4249
SMS_ATTRIBUTE_NAMES = [
4350
"DeliveryStatusIAMRole",
4451
"DeliveryStatusSuccessSamplingRate",
@@ -143,6 +150,19 @@ def from_batch_entry(cls, entry: PublishBatchRequestEntry, is_fifo=False) -> "Sn
143150
)
144151

145152

153+
@dataclass
154+
class PlatformEndpoint:
155+
platform_application_arn: str
156+
platform_endpoint: Endpoint
157+
158+
159+
@dataclass
160+
class PlatformApplicationDetails:
161+
platform_application: PlatformApplication
162+
# maps all Endpoints of the PlatformApplication, from their Token to their ARN
163+
platform_endpoints: dict[str, str]
164+
165+
146166
class SnsStore(BaseStore):
147167
topics: dict[str, Topic] = LocalAttribute(default=dict)
148168

@@ -156,7 +176,10 @@ class SnsStore(BaseStore):
156176
subscription_tokens: dict[str, str] = LocalAttribute(default=dict)
157177

158178
# maps platform application arns to platform applications
159-
platform_applications: dict[str, PlatformApplication] = LocalAttribute(default=dict)
179+
platform_applications: dict[str, PlatformApplicationDetails] = LocalAttribute(default=dict)
180+
181+
# maps endpoint arns to platform endpoints
182+
platform_endpoints: dict[str, PlatformEndpoint] = LocalAttribute(default=dict)
160183

161184
# topic/subscription independent default values for sending sms messages
162185
sms_attributes: dict[str, str] = LocalAttribute(default=dict)

localstack-core/localstack/services/sns/v2/provider.py

Lines changed: 144 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
from localstack.aws.api.sns import (
1111
AmazonResourceName,
1212
ConfirmSubscriptionResponse,
13+
CreateEndpointResponse,
1314
CreatePlatformApplicationResponse,
1415
CreateTopicResponse,
16+
Endpoint,
17+
GetEndpointAttributesResponse,
1518
GetPlatformApplicationAttributesResponse,
1619
GetSMSAttributesResponse,
1720
GetSubscriptionAttributesResponse,
@@ -60,6 +63,9 @@
6063
SMS_ATTRIBUTE_NAMES,
6164
SMS_DEFAULT_SENDER_REGEX,
6265
SMS_TYPES,
66+
EndpointAttributeNames,
67+
PlatformApplicationDetails,
68+
PlatformEndpoint,
6369
SnsMessage,
6470
SnsMessageType,
6571
SnsStore,
@@ -68,6 +74,7 @@
6874
sns_stores,
6975
)
7076
from localstack.services.sns.v2.utils import (
77+
create_platform_endpoint_arn,
7178
create_subscription_arn,
7279
encode_subscription_token_with_region,
7380
get_next_page_token_from_arn,
@@ -237,10 +244,11 @@ def subscribe(
237244
raise InvalidParameterException("Invalid parameter: SQS endpoint ARN")
238245

239246
elif protocol == "application":
240-
# TODO: This needs to be implemented once applications are ported from moto to the new provider
241-
raise NotImplementedError(
242-
"This functionality needs yet to be ported to the new SNS provider"
243-
)
247+
# TODO: Validate exact behaviour
248+
try:
249+
parse_arn(endpoint)
250+
except InvalidArnException:
251+
raise InvalidParameterException("Invalid parameter: ApplicationEndpoint ARN")
244252

245253
if ".fifo" in endpoint and ".fifo" not in topic_arn:
246254
# TODO: move to sqs protocol block if possible
@@ -591,17 +599,24 @@ def create_platform_application(
591599
account_id=context.account_id,
592600
region_name=context.region,
593601
)
594-
platform_application = PlatformApplication(
595-
PlatformApplicationArn=application_arn, Attributes=_attributes
602+
platform_application_details = PlatformApplicationDetails(
603+
platform_application=PlatformApplication(
604+
PlatformApplicationArn=application_arn,
605+
Attributes=_attributes,
606+
),
607+
platform_endpoints={},
596608
)
597-
store.platform_applications[application_arn] = platform_application
598-
return CreatePlatformApplicationResponse(**platform_application)
609+
store.platform_applications[application_arn] = platform_application_details
610+
611+
return platform_application_details.platform_application
599612

600613
def delete_platform_application(
601614
self, context: RequestContext, platform_application_arn: String, **kwargs
602615
) -> None:
603616
store = self.get_store(context.account_id, context.region)
604617
store.platform_applications.pop(platform_application_arn, None)
618+
# TODO: if the platform had endpoints, should we remove them from the store? There is no way to list
619+
# endpoints without an application, so this is impossible to check the state of AWS here
605620

606621
def list_platform_applications(
607622
self, context: RequestContext, next_token: String | None = None, **kwargs
@@ -615,7 +630,9 @@ def list_platform_applications(
615630
next_token=next_token,
616631
)
617632

618-
response = ListPlatformApplicationsResponse(PlatformApplications=page)
633+
response = ListPlatformApplicationsResponse(
634+
PlatformApplications=[platform_app.platform_application for platform_app in page]
635+
)
619636
if token:
620637
response["NextToken"] = token
621638
return response
@@ -644,15 +661,112 @@ def set_platform_application_attributes(
644661
# Platform Endpoints
645662
#
646663

664+
def create_platform_endpoint(
665+
self,
666+
context: RequestContext,
667+
platform_application_arn: String,
668+
token: String,
669+
custom_user_data: String | None = None,
670+
attributes: MapStringToString | None = None,
671+
**kwargs,
672+
) -> CreateEndpointResponse:
673+
store = self.get_store(context.account_id, context.region)
674+
application = store.platform_applications.get(platform_application_arn)
675+
if not application:
676+
raise NotFoundException("PlatformApplication does not exist")
677+
endpoint_arn = application.platform_endpoints.get(token, {})
678+
attributes = attributes or {}
679+
_validate_endpoint_attributes(attributes, allow_empty=True)
680+
# CustomUserData can be specified both in attributes and as parameter. Attributes take precedence
681+
attributes.setdefault(EndpointAttributeNames.CUSTOM_USER_DATA, custom_user_data)
682+
_attributes = {"Enabled": "true", "Token": token, **attributes}
683+
if endpoint_arn and (
684+
platform_endpoint_details := store.platform_endpoints.get(endpoint_arn)
685+
):
686+
# endpoint for that application with that particular token already exists
687+
if not platform_endpoint_details.platform_endpoint["Attributes"] == _attributes:
688+
raise InvalidParameterException(
689+
f"Invalid parameter: Token Reason: Endpoint {endpoint_arn} already exists with the same Token, but different attributes."
690+
)
691+
else:
692+
return CreateEndpointResponse(EndpointArn=endpoint_arn)
693+
694+
endpoint_arn = create_platform_endpoint_arn(platform_application_arn)
695+
platform_endpoint = PlatformEndpoint(
696+
platform_application_arn=endpoint_arn,
697+
platform_endpoint=Endpoint(
698+
Attributes=_attributes,
699+
EndpointArn=endpoint_arn,
700+
),
701+
)
702+
store.platform_endpoints[endpoint_arn] = platform_endpoint
703+
application.platform_endpoints[token] = endpoint_arn
704+
705+
return CreateEndpointResponse(EndpointArn=endpoint_arn)
706+
707+
def delete_endpoint(self, context: RequestContext, endpoint_arn: String, **kwargs) -> None:
708+
store = self.get_store(context.account_id, context.region)
709+
platform_endpoint_details = store.platform_endpoints.pop(endpoint_arn, None)
710+
if platform_endpoint_details:
711+
platform_application = store.platform_applications.get(
712+
platform_endpoint_details.platform_application_arn
713+
)
714+
if platform_application:
715+
platform_endpoint = platform_endpoint_details.platform_endpoint
716+
platform_application.platform_endpoints.pop(
717+
platform_endpoint["Attributes"]["Token"], None
718+
)
719+
647720
def list_endpoints_by_platform_application(
648721
self,
649722
context: RequestContext,
650723
platform_application_arn: String,
651724
next_token: String | None = None,
652725
**kwargs,
653726
) -> ListEndpointsByPlatformApplicationResponse:
654-
# TODO: stub so cleanup fixture won't fail
655-
return ListEndpointsByPlatformApplicationResponse(Endpoints=[])
727+
store = self.get_store(context.account_id, context.region)
728+
platform_application = store.platform_applications.get(platform_application_arn)
729+
if not platform_application:
730+
raise NotFoundException("PlatformApplication does not exist")
731+
endpoint_arns = platform_application.platform_endpoints.values()
732+
paginated_endpoint_arns = PaginatedList(endpoint_arns)
733+
page, token = paginated_endpoint_arns.get_page(
734+
token_generator=lambda x: get_next_page_token_from_arn(x),
735+
page_size=100,
736+
next_token=next_token,
737+
)
738+
739+
response = ListEndpointsByPlatformApplicationResponse(
740+
Endpoints=[
741+
store.platform_endpoints[endpoint_arn].platform_endpoint
742+
for endpoint_arn in page
743+
if endpoint_arn in store.platform_endpoints
744+
]
745+
)
746+
if token:
747+
response["NextToken"] = token
748+
return response
749+
750+
def get_endpoint_attributes(
751+
self, context: RequestContext, endpoint_arn: String, **kwargs
752+
) -> GetEndpointAttributesResponse:
753+
store = self.get_store(context.account_id, context.region)
754+
platform_endpoint_details = store.platform_endpoints.get(endpoint_arn)
755+
if not platform_endpoint_details:
756+
raise NotFoundException("Endpoint does not exist")
757+
attributes = platform_endpoint_details.platform_endpoint["Attributes"]
758+
return GetEndpointAttributesResponse(Attributes=attributes)
759+
760+
def set_endpoint_attributes(
761+
self, context: RequestContext, endpoint_arn: String, attributes: MapStringToString, **kwargs
762+
) -> None:
763+
store = self.get_store(context.account_id, context.region)
764+
platform_endpoint_details = store.platform_endpoints.get(endpoint_arn)
765+
if not platform_endpoint_details:
766+
raise NotFoundException("Endpoint does not exist")
767+
_validate_endpoint_attributes(attributes)
768+
attributes = attributes or {}
769+
platform_endpoint_details.platform_endpoint["Attributes"].update(attributes)
656770

657771
#
658772
# Sms operations
@@ -736,7 +850,7 @@ def _get_platform_application(
736850
parse_and_validate_platform_application_arn(platform_application_arn)
737851
try:
738852
store = SnsProvider.get_store(context.account_id, context.region)
739-
return store.platform_applications[platform_application_arn]
853+
return store.platform_applications[platform_application_arn].platform_application
740854
except KeyError:
741855
raise NotFoundException("PlatformApplication does not exist")
742856

@@ -821,6 +935,10 @@ def _validate_platform_application_name(name: str) -> None:
821935

822936

823937
def _validate_platform_application_attributes(attributes: dict) -> None:
938+
_check_empty_attributes(attributes)
939+
940+
941+
def _check_empty_attributes(attributes: dict) -> None:
824942
if not attributes:
825943
raise CommonServiceException(
826944
code="ValidationError",
@@ -829,6 +947,20 @@ def _validate_platform_application_attributes(attributes: dict) -> None:
829947
)
830948

831949

950+
def _validate_endpoint_attributes(attributes: dict, allow_empty: bool = False) -> None:
951+
if not allow_empty:
952+
_check_empty_attributes(attributes)
953+
for key in attributes:
954+
if key not in EndpointAttributeNames:
955+
raise InvalidParameterException(
956+
f"Invalid parameter: Attributes Reason: Invalid attribute name: {key}"
957+
)
958+
if len(attributes.get(EndpointAttributeNames.CUSTOM_USER_DATA, "")) > 2048:
959+
raise InvalidParameterException(
960+
"Invalid parameter: Attributes Reason: Invalid value for attribute: CustomUserData: must be at most 2048 bytes long in UTF-8 encoding"
961+
)
962+
963+
832964
def _validate_sms_attributes(attributes: dict) -> None:
833965
for k, v in attributes.items():
834966
if k not in SMS_ATTRIBUTE_NAMES:

localstack-core/localstack/services/sns/v2/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,14 @@ def create_subscription_arn(topic_arn: str) -> str:
103103
return f"{topic_arn}:{uuid4()}"
104104

105105

106+
def create_platform_endpoint_arn(
107+
platform_application_arn: str,
108+
) -> str:
109+
# This is the format of an Endpoint Arn
110+
# arn:aws:sns:us-west-2:1234567890:endpoint/GCM/MyApplication/12345678-abcd-9012-efgh-345678901234
111+
return f"{platform_application_arn.replace('app', 'endpoint', 1)}/{uuid4()}"
112+
113+
106114
def encode_subscription_token_with_region(region: str) -> str:
107115
"""
108116
Create a 64 characters Subscription Token with the region encoded

0 commit comments

Comments
 (0)