Skip to content

Commit 0da9865

Browse files
authored
StepFunctions: Improve Handling of Empty SendTaskFailure Calls (#10750)
1 parent 5bfa1de commit 0da9865

File tree

11 files changed

+656
-16
lines changed

11 files changed

+656
-16
lines changed

localstack/services/stepfunctions/asl/component/common/catch/catcher_decl.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,9 @@ def _extract_error_cause(failure_event: FailureEvent) -> dict:
7474
f"Internal Error: invalid event details declaration in FailureEvent: '{failure_event}'."
7575
)
7676
spec_event_details: dict = list(failure_event.event_details.values())[0]
77-
error = spec_event_details["error"]
78-
cause = spec_event_details.get("cause") or ""
77+
# If no cause or error fields are given, AWS binds an empty string; otherwise it attaches the value.
78+
error = spec_event_details.get("error", "")
79+
cause = spec_event_details.get("cause", "")
7980
# Stepfunctions renames these fields to capital in this scenario.
8081
return {
8182
"Error": error,

localstack/services/stepfunctions/asl/component/common/error_name/custom_error_name.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
from typing import Final
1+
from typing import Final, Optional
22

33
from localstack.services.stepfunctions.asl.component.common.error_name.error_name import ErrorName
44

5+
ILLEGAL_CUSTOM_ERROR_PREFIX: Final[str] = "States."
6+
57

68
class CustomErrorName(ErrorName):
79
"""
810
States MAY report errors with other names, which MUST NOT begin with the prefix "States.".
911
"""
1012

11-
_ILLEGAL_PREFIX: Final[str] = "States."
12-
13-
def __init__(self, error_name: str):
14-
if error_name.startswith(CustomErrorName._ILLEGAL_PREFIX):
13+
def __init__(self, error_name: Optional[str]):
14+
if error_name is not None and error_name.startswith(ILLEGAL_CUSTOM_ERROR_PREFIX):
1515
raise ValueError(
1616
f"Custom Error Names MUST NOT begin with the prefix 'States.', got '{error_name}'."
1717
)

localstack/services/stepfunctions/asl/component/common/error_name/error_name.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
from __future__ import annotations
22

33
import abc
4-
from typing import Final
4+
from typing import Final, Optional
55

66
from localstack.services.stepfunctions.asl.component.component import Component
77

88

99
class ErrorName(Component, abc.ABC):
10-
def __init__(self, error_name: str):
11-
self.error_name: Final[str] = error_name
10+
error_name: Final[Optional[str]]
1211

13-
def matches(self, error_name: str) -> bool:
12+
def __init__(self, error_name: Optional[str]):
13+
self.error_name = error_name
14+
15+
def matches(self, error_name: Optional[str]) -> bool:
1416
return self.error_name == error_name
1517

1618
def __eq__(self, other):

localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_callback.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import abc
22
import json
33
import time
4+
from typing import Optional
45

56
from localstack.aws.api.stepfunctions import (
67
HistoryEventExecutionDataDetails,
@@ -120,10 +121,10 @@ def _get_callback_outcome_failure_event(
120121
self, env: Environment, ex: CallbackOutcomeFailureError
121122
) -> FailureEvent:
122123
callback_outcome_failure: CallbackOutcomeFailure = ex.callback_outcome_failure
123-
error: str = callback_outcome_failure.error
124+
error: Optional[str] = callback_outcome_failure.error
124125
return FailureEvent(
125126
env=env,
126-
error_name=CustomErrorName(error_name=callback_outcome_failure.error),
127+
error_name=CustomErrorName(error_name=error),
127128
event_type=HistoryEventType.TaskFailed,
128129
event_details=EventDetails(
129130
taskFailedEventDetails=TaskFailedEventDetails(

localstack/services/stepfunctions/asl/eval/callback/callback.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ def __init__(self, callback_id: CallbackId, output: str):
2626

2727

2828
class CallbackOutcomeFailure(CallbackOutcome):
29-
error: Final[str]
30-
cause: Final[str]
29+
error: Final[Optional[str]]
30+
cause: Final[Optional[str]]
3131

32-
def __init__(self, callback_id: CallbackId, error: str, cause: str):
32+
def __init__(self, callback_id: CallbackId, error: Optional[str], cause: Optional[str]):
3333
super().__init__(callback_id=callback_id)
3434
self.error = error
3535
self.cause = cause

tests/aws/services/stepfunctions/templates/callbacks/callback_templates.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,9 @@ class CallbackTemplates(TemplateLoader):
3737
SQS_HEARTBEAT_SUCCESS_ON_TASK_TOKEN: Final[str] = os.path.join(
3838
_THIS_FOLDER, "statemachines/sqs_hearbeat_success_on_task_token.json5"
3939
)
40+
SQS_PARALLEL_WAIT_FOR_TASK_TOKEN: Final[str] = os.path.join(
41+
_THIS_FOLDER, "statemachines/sqs_parallel_wait_for_task_token.json5"
42+
)
43+
SQS_WAIT_FOR_TASK_TOKEN_CATCH: Final[str] = os.path.join(
44+
_THIS_FOLDER, "statemachines/sqs_wait_for_task_token_catch.json5"
45+
)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
{
2+
"Comment": "SQS_PARALLEL_WAIT_FOR_TASK_TOKEN",
3+
"StartAt": "ParallelJob",
4+
"States": {
5+
"ParallelJob": {
6+
"Type": "Parallel",
7+
"Branches": [
8+
{
9+
"StartAt": "SendMessageWithWait",
10+
"States": {
11+
"SendMessageWithWait": {
12+
"Type": "Task",
13+
"Resource": "arn:aws:states:::sqs:sendMessage.waitForTaskToken",
14+
"Parameters": {
15+
"QueueUrl.$": "$.QueueUrl",
16+
"MessageBody": {
17+
"Context.$": "$",
18+
"TaskToken.$": "$$.Task.Token"
19+
}
20+
},
21+
"End": true
22+
},
23+
}
24+
}
25+
],
26+
"Catch": [
27+
{
28+
"ErrorEquals": [
29+
"States.Runtime"
30+
],
31+
"ResultPath": "$.states_runtime_error",
32+
"Next": "CaughtRuntimeError"
33+
},
34+
{
35+
"ErrorEquals": [
36+
"States.ALL"
37+
],
38+
"ResultPath": "$.states_all_error",
39+
"Next": "CaughtStatesALL"
40+
}
41+
],
42+
"End": true
43+
},
44+
"CaughtRuntimeError": {
45+
"Type": "Pass",
46+
"End": true
47+
},
48+
"CaughtStatesALL": {
49+
"Type": "Pass",
50+
"End": true
51+
},
52+
}
53+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
{
2+
"Comment": "SQS_WAIT_FOR_TASK_TOKEN_CATCH",
3+
"StartAt": "SendMessageWithWait",
4+
"States": {
5+
"SendMessageWithWait": {
6+
"Type": "Task",
7+
"Resource": "arn:aws:states:::sqs:sendMessage.waitForTaskToken",
8+
"Parameters": {
9+
"QueueUrl.$": "$.QueueUrl",
10+
"MessageBody": {
11+
"Context.$": "$",
12+
"TaskToken.$": "$$.Task.Token"
13+
}
14+
},
15+
"Catch": [
16+
{
17+
"ErrorEquals": [
18+
"States.Runtime"
19+
],
20+
"ResultPath": "$.states_runtime_error",
21+
"Next": "CaughtRuntimeError"
22+
},
23+
{
24+
"ErrorEquals": [
25+
"States.ALL"
26+
],
27+
"ResultPath": "$.states_all_error",
28+
"Next": "CaughtStatesALL"
29+
}
30+
],
31+
"End": true
32+
},
33+
"CaughtRuntimeError": {
34+
"Type": "Pass",
35+
"End": true
36+
},
37+
"CaughtStatesALL": {
38+
"Type": "Pass",
39+
"End": true
40+
}
41+
}
42+
}

tests/aws/services/stepfunctions/v2/callback/test_callback.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import json
22
import threading
33

4+
import pytest
45
from localstack_snapshot.snapshots.transformer import JsonpathTransformer, RegexTransformer
56

67
from localstack.services.stepfunctions.asl.eval.count_down_latch import CountDownLatch
8+
from localstack.testing.aws.util import is_aws_cloud
79
from localstack.testing.pytest import markers
810
from localstack.utils.strings import short_uid
911
from localstack.utils.sync import retry
@@ -697,3 +699,76 @@ def test_sqs_wait_for_task_token_no_token_parameter(
697699
definition,
698700
exec_input,
699701
)
702+
703+
@markers.aws.validated
704+
@pytest.mark.parametrize(
705+
"template",
706+
[CT.SQS_PARALLEL_WAIT_FOR_TASK_TOKEN, CT.SQS_WAIT_FOR_TASK_TOKEN_CATCH],
707+
ids=["SQS_PARALLEL_WAIT_FOR_TASK_TOKEN", "SQS_WAIT_FOR_TASK_TOKEN_CATCH"],
708+
)
709+
def test_sqs_failure_in_wait_for_task_tok_no_error_field(
710+
self,
711+
aws_client,
712+
create_iam_role_for_sfn,
713+
create_state_machine,
714+
sqs_create_queue,
715+
sfn_snapshot,
716+
template,
717+
request,
718+
):
719+
if (
720+
not is_aws_cloud()
721+
and request.node.name
722+
== "test_sqs_failure_in_wait_for_task_tok_no_error_field[SQS_PARALLEL_WAIT_FOR_TASK_TOKEN]"
723+
):
724+
# TODO: The conditions in which TaskStateAborted error events are logged requires further investigations.
725+
# These appear to be logged for Task state workers but only within Parallel states. The behaviour with
726+
# other 'Abort' errors should also be investigated.
727+
pytest.skip("Investigate occurrence logic of 'TaskStateAborted' errors")
728+
729+
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sqs_integration())
730+
sfn_snapshot.add_transformer(
731+
JsonpathTransformer(
732+
jsonpath="$..TaskToken",
733+
replacement="task_token",
734+
replace_reference=True,
735+
)
736+
)
737+
738+
queue_name = f"queue-{short_uid()}"
739+
queue_url = sqs_create_queue(QueueName=queue_name)
740+
sfn_snapshot.add_transformer(RegexTransformer(queue_url, "<sqs_queue_url>"))
741+
sfn_snapshot.add_transformer(RegexTransformer(queue_name, "<sqs_queue_name>"))
742+
743+
def _empty_send_task_failure_on_sqs_message():
744+
def _get_message_body():
745+
receive_message_response = aws_client.sqs.receive_message(
746+
QueueUrl=queue_url, MaxNumberOfMessages=1
747+
)
748+
return receive_message_response["Messages"][0]["Body"]
749+
750+
message_body_str = retry(_get_message_body, retries=60, sleep=1)
751+
message_body = json.loads(message_body_str)
752+
task_token = message_body["TaskToken"]
753+
aws_client.stepfunctions.send_task_failure(taskToken=task_token)
754+
755+
thread_send_task_failure = threading.Thread(
756+
target=_empty_send_task_failure_on_sqs_message,
757+
args=(),
758+
name="Thread_empty_send_task_failure_on_sqs_message",
759+
)
760+
thread_send_task_failure.daemon = True
761+
thread_send_task_failure.start()
762+
763+
template = CT.load_sfn_template(template)
764+
definition = json.dumps(template)
765+
766+
exec_input = json.dumps({"QueueUrl": queue_url, "Message": "test_message_txt"})
767+
create_and_record_execution(
768+
aws_client.stepfunctions,
769+
create_iam_role_for_sfn,
770+
create_state_machine,
771+
sfn_snapshot,
772+
definition,
773+
exec_input,
774+
)

0 commit comments

Comments
 (0)