Skip to content
This repository was archived by the owner on Mar 23, 2026. It is now read-only.

Commit 7e1346e

Browse files
authored
StepFunctions, support for callback failure (#8386)
1 parent ca22507 commit 7e1346e

File tree

15 files changed

+622
-11
lines changed

15 files changed

+622
-11
lines changed

localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
StateTaskServiceCallback,
1616
)
1717
from localstack.services.stepfunctions.asl.component.state.state_props import StateProps
18+
from localstack.services.stepfunctions.asl.eval.callback.callback import CallbackOutcomeFailureError
1819
from localstack.services.stepfunctions.asl.eval.environment import Environment
1920
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
2021
from localstack.utils.aws import aws_stack
@@ -24,7 +25,10 @@
2425
class StateTaskServiceAwsSdk(StateTaskServiceCallback):
2526
_API_NAMES: dict[str, str] = {"sfn": "stepfunctions"}
2627
_SFN_TO_BOTO_PARAM_NORMALISERS = {
27-
"stepfunctions": {"send_task_success": {"Output": "output", "TaskToken": "taskToken"}}
28+
"stepfunctions": {
29+
"send_task_success": {"Output": "output", "TaskToken": "taskToken"},
30+
"send_task_failure": {"TaskToken": "taskToken", "Error": "error", "Cause": "cause"},
31+
}
2832
}
2933

3034
_normalised_api_name: str
@@ -70,6 +74,8 @@ def _get_task_failure_event(self, error: str, cause: str) -> FailureEvent:
7074
)
7175

7276
def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
77+
if isinstance(ex, CallbackOutcomeFailureError):
78+
return self._get_callback_outcome_failure_event(ex=ex)
7379
if isinstance(ex, TimeoutError):
7480
return self._get_timed_out_failure_event()
7581

localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_callback.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,29 @@
44
from localstack.aws.api.stepfunctions import (
55
HistoryEventExecutionDataDetails,
66
HistoryEventType,
7+
TaskFailedEventDetails,
78
TaskScheduledEventDetails,
89
TaskStartedEventDetails,
910
TaskSubmittedEventDetails,
1011
TaskSucceededEventDetails,
1112
)
13+
from localstack.services.stepfunctions.asl.component.common.error_name.custom_error_name import (
14+
CustomErrorName,
15+
)
16+
from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import (
17+
FailureEvent,
18+
)
1219
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
1320
ResourceCondition,
1421
)
1522
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service import (
1623
StateTaskService,
1724
)
18-
from localstack.services.stepfunctions.asl.eval.callback.callback import CallbackOutcomeSuccess
25+
from localstack.services.stepfunctions.asl.eval.callback.callback import (
26+
CallbackOutcomeFailure,
27+
CallbackOutcomeFailureError,
28+
CallbackOutcomeSuccess,
29+
)
1930
from localstack.services.stepfunctions.asl.eval.environment import Environment
2031
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
2132
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
@@ -40,12 +51,30 @@ def _wait_for_task_token(self, env: Environment) -> None: # noqa
4051
if isinstance(outcome, CallbackOutcomeSuccess):
4152
outcome_output = json.loads(outcome.output)
4253
env.stack.append(outcome_output)
54+
elif isinstance(outcome, CallbackOutcomeFailure):
55+
raise CallbackOutcomeFailureError(callback_outcome_failure=outcome)
4356
else:
44-
raise NotImplementedError(f"Unsupported Callbackoutcome type '{type(outcome)}'.")
57+
raise NotImplementedError(f"Unsupported CallbackOutcome type '{type(outcome)}'.")
4558

4659
def _is_condition(self):
4760
return self.resource.condition is not None
4861

62+
def _get_callback_outcome_failure_event(self, ex: CallbackOutcomeFailureError) -> FailureEvent:
63+
callback_outcome_failure: CallbackOutcomeFailure = ex.callback_outcome_failure
64+
error: str = callback_outcome_failure.error
65+
return FailureEvent(
66+
error_name=CustomErrorName(error_name=callback_outcome_failure.error),
67+
event_type=HistoryEventType.TaskFailed,
68+
event_details=EventDetails(
69+
taskFailedEventDetails=TaskFailedEventDetails(
70+
resourceType=self._get_sfn_resource_type(),
71+
resource=self._get_sfn_resource(),
72+
error=error,
73+
cause=callback_outcome_failure.cause,
74+
)
75+
),
76+
)
77+
4978
def _eval_execution(self, env: Environment) -> None:
5079
parameters = self._eval_parameters(env=env)
5180
parameters_str = to_json_str(parameters)

localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_lambda.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import (
2222
StateTaskServiceCallback,
2323
)
24+
from localstack.services.stepfunctions.asl.eval.callback.callback import CallbackOutcomeFailureError
2425
from localstack.services.stepfunctions.asl.eval.environment import Environment
2526
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
2627

@@ -59,6 +60,8 @@ def _error_cause_from_client_error(client_error: ClientError) -> tuple[str, str]
5960
return error, cause
6061

6162
def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
63+
if isinstance(ex, CallbackOutcomeFailureError):
64+
return self._get_callback_outcome_failure_event(ex=ex)
6265
if isinstance(ex, TimeoutError):
6366
return self._get_timed_out_failure_event()
6467

localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.state_task_service_callback import (
1313
StateTaskServiceCallback,
1414
)
15+
from localstack.services.stepfunctions.asl.eval.callback.callback import CallbackOutcomeFailureError
1516
from localstack.services.stepfunctions.asl.eval.environment import Environment
1617
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
1718
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
@@ -38,6 +39,8 @@ def _get_supported_parameters(self) -> Optional[set[str]]:
3839
return self._SUPPORTED_API_PARAM_BINDINGS.get(self.resource.api_action.lower())
3940

4041
def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
42+
if isinstance(ex, CallbackOutcomeFailureError):
43+
return self._get_callback_outcome_failure_event(ex=ex)
4144
if isinstance(ex, TimeoutError):
4245
return self._get_timed_out_failure_event()
4346

localstack/services/stepfunctions/asl/eval/callback/callback.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ def __init__(self, callback_consumer_error: CallbackConsumerError):
7676
self.callback_consumer_error = callback_consumer_error
7777

7878

79+
class CallbackOutcomeFailureError(RuntimeError):
80+
callback_outcome_failure: CallbackOutcomeFailure
81+
82+
def __init__(self, callback_outcome_failure: CallbackOutcomeFailure):
83+
self.callback_outcome_failure = callback_outcome_failure
84+
85+
7986
class CallbackPoolManager:
8087
_pool: dict[CallbackId, CallbackEndpoint]
8188

localstack/services/stepfunctions/provider_v2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from localstack.services.stepfunctions.asl.eval.callback.callback import (
4747
CallbackConsumerTimeout,
4848
CallbackNotifyConsumerError,
49+
CallbackOutcomeFailure,
4950
CallbackOutcomeSuccess,
5051
)
5152
from localstack.services.stepfunctions.backend.execution import Execution
@@ -183,7 +184,7 @@ def send_task_failure(
183184
error: SensitiveError = None,
184185
cause: SensitiveCause = None,
185186
) -> SendTaskFailureOutput:
186-
outcome = CallbackOutcomeSuccess(callback_id=task_token)
187+
outcome = CallbackOutcomeFailure(callback_id=task_token, error=error, cause=cause)
187188
store = self.get_store(context)
188189
for exec in store.executions.values():
189190
try:

tests/integration/stepfunctions/conftest.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,24 @@ def _create_state_machine(sqs_queue_url):
148148
)
149149

150150
return _create_state_machine
151+
152+
153+
@pytest.fixture
154+
def sqs_send_task_failure_state_machine(aws_client, create_state_machine, create_iam_role_for_sfn):
155+
def _create_state_machine(sqs_queue_url):
156+
snf_role_arn = create_iam_role_for_sfn()
157+
sm_name: str = f"sqs_send_task_failure_state_machine_{short_uid()}"
158+
template = CallbackTemplates.load_sfn_template(CallbackTemplates.SQS_FAILURE_ON_TASK_TOKEN)
159+
definition = json.dumps(template)
160+
161+
creation_resp = create_state_machine(
162+
name=sm_name, definition=definition, roleArn=snf_role_arn
163+
)
164+
state_machine_arn = creation_resp["stateMachineArn"]
165+
166+
aws_client.stepfunctions.start_execution(
167+
stateMachineArn=state_machine_arn,
168+
input=json.dumps({"QueueUrl": sqs_queue_url, "Iterator": {"Count": 300}}),
169+
)
170+
171+
return _create_state_machine

tests/integration/stepfunctions/templates/callbacks/callback_templates.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@ class CallbackTemplates(TemplateLoader):
1010
SQS_SUCCESS_ON_TASK_TOKEN: Final[str] = os.path.join(
1111
_THIS_FOLDER, "statemachines/sqs_success_on_task_token.json5"
1212
)
13-
13+
SQS_FAILURE_ON_TASK_TOKEN: Final[str] = os.path.join(
14+
_THIS_FOLDER, "statemachines/sqs_failure_on_task_token.json5"
15+
)
1416
SQS_WAIT_FOR_TASK_TOKEN: Final[str] = os.path.join(
1517
_THIS_FOLDER, "statemachines/sqs_wait_for_task_token.json5"
1618
)
17-
1819
SQS_WAIT_FOR_TASK_TOKEN_WITH_TIMEOUT: Final[str] = os.path.join(
1920
_THIS_FOLDER, "statemachines/sqs_wait_for_task_token_with_timeout.json5"
2021
)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
{
2+
"Comment": "sqs_failure_on_task_token",
3+
"StartAt": "Iterate",
4+
"States": {
5+
"Iterate": {
6+
"Type": "Pass",
7+
"Parameters": {
8+
"Count.$": "States.MathAdd($.Iterator.Count, -1)"
9+
},
10+
"ResultPath": "$.Iterator",
11+
"Next": "IterateStep"
12+
},
13+
"IterateStep": {
14+
"Type": "Choice",
15+
"Choices": [
16+
{
17+
"Variable": "$.Iterator.Count",
18+
"NumericLessThanEquals": 0,
19+
"Next": "NoMoreCycles"
20+
}
21+
],
22+
"Default": "WaitAndReceive",
23+
},
24+
"WaitAndReceive": {
25+
"Type": "Wait",
26+
"Seconds": 1,
27+
"Next": "Receive"
28+
},
29+
"Receive": {
30+
"Type": "Task",
31+
"Parameters": {
32+
"QueueUrl.$": "$.QueueUrl"
33+
},
34+
"Resource": "arn:aws:states:::aws-sdk:sqs:receiveMessage",
35+
"ResultPath": "$.SQSOutput",
36+
"Next": "CheckMessages",
37+
},
38+
"CheckMessages": {
39+
"Type": "Choice",
40+
"Choices": [
41+
{
42+
"Variable": "$.SQSOutput.Messages",
43+
"IsPresent": true,
44+
"Next": "SendFailure"
45+
}
46+
],
47+
"Default": "Iterate"
48+
},
49+
"SendFailure": {
50+
"Type": "Map",
51+
"InputPath": "$.SQSOutput.Messages",
52+
"ItemProcessor": {
53+
"ProcessorConfig": {
54+
"Mode": "INLINE"
55+
},
56+
"StartAt": "ParseBody",
57+
"States": {
58+
"ParseBody": {
59+
"Type": "Pass",
60+
"Parameters": {
61+
"Body.$": "States.StringToJson($.Body)"
62+
},
63+
"Next": "Send"
64+
},
65+
"Send": {
66+
"Type": "Task",
67+
"Resource": "arn:aws:states:::aws-sdk:sfn:sendTaskFailure",
68+
"Parameters": {
69+
"Error": "Failure error",
70+
"Cause": "Failure cause",
71+
"TaskToken.$": "$.Body.TaskToken"
72+
},
73+
"End": true
74+
}
75+
},
76+
},
77+
"ResultPath": null,
78+
"Next": "Iterate"
79+
},
80+
"NoMoreCycles": {
81+
"Type": "Pass",
82+
"End": true
83+
}
84+
},
85+
}

tests/integration/stepfunctions/templates/errorhandling/error_handling_templates.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ class ErrorHandlingTemplate(TemplateLoader):
4141
_THIS_FOLDER, "statemachines/task_service_sqs_send_msg_catch.json5"
4242
)
4343

44+
AWS_SERVICE_SQS_SEND_MSG_CATCH_TOKEN_FAILURE: Final[str] = os.path.join(
45+
_THIS_FOLDER, "statemachines/aws_service_sqs_send_msg_catch_token_failure.json5"
46+
)
47+
4448
# Lambda Functions.
4549
LAMBDA_FUNC_RAISE_EXCEPTION: Final[str] = os.path.join(
4650
_THIS_FOLDER, "lambdafunctions/raise_exception.py"

0 commit comments

Comments
 (0)