Skip to content

Commit 42ed9a1

Browse files
tiuringregfurman
andauthored
[SFN] Add new TestState API capabilities (#13418)
New capabilities have recently been added to TestState API. This commit adds the following support for the new capabilities: - Add mocking support – Mock state outputs and errors without invoking downstream services - Add support for Map (inline and distributed) states - Add support to test specific states within a full state machine definition using the new stateName parameter. - Add support for Catch and Retry fields - Add new inspection data - Rename `mocking` package to l`ocal_mocking`: clearly mark mocking functionality related to Step Functions Local. This helps to distinguish between Local mocks and TestState mocks. Co-authored-by: Greg Furman <[email protected]> Co-authored-by: Greg Furman <[email protected]>
1 parent 5cac820 commit 42ed9a1

File tree

141 files changed

+10828
-1240
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

141 files changed

+10828
-1240
lines changed

localstack-core/localstack/services/stepfunctions/asl/component/common/path/result_path.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self, result_path_src: str | None):
1818
def _eval_body(self, env: Environment) -> None:
1919
state_input = env.states.get_input()
2020

21-
# Discard task output if there is one, and set the output ot be the state's input.
21+
# Discard task output if there is one, and set the output to be the state's input.
2222
if self.result_path_src is None:
2323
env.stack.clear()
2424
env.stack.append(state_input)

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/execute_state.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,6 @@ def _eval_state(self, env: Environment) -> None:
251251
)
252252
error_output = self._construct_error_output_value(failure_event=failure_event)
253253
env.states.set_error_output(error_output)
254-
env.states.set_result(error_output)
255254

256255
if self.retry:
257256
retry_outcome: RetryOutcome = self._handle_retry(

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_map/state_map.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,6 @@ def _eval_state(self, env: Environment) -> None:
311311
failure_event: FailureEvent = self._from_error(env=env, ex=ex)
312312
error_output = self._construct_error_output_value(failure_event=failure_event)
313313
env.states.set_error_output(error_output)
314-
env.states.set_result(error_output)
315314

316315
if self.retry:
317316
retry_outcome: RetryOutcome = self._handle_retry(

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
77
StateCredentials,
88
)
9-
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.mock_eval_utils import (
10-
eval_mocked_response,
9+
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.local_mock_eval_utils import (
10+
eval_local_mocked_response,
1111
)
1212
from localstack.services.stepfunctions.asl.eval.environment import Environment
1313
from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for
1414
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
15-
from localstack.services.stepfunctions.mocking.mock_config import MockedResponse
15+
from localstack.services.stepfunctions.local_mocking.mock_config import LocalMockedResponse
1616
from localstack.utils.collections import select_from_typed_dict
1717
from localstack.utils.strings import to_bytes
1818

@@ -42,9 +42,9 @@ def _from_payload(payload_streaming_body: IO[bytes]) -> Any | str:
4242
return decoded_data
4343

4444

45-
def _mocked_invoke_lambda_function(env: Environment) -> InvocationResponse:
46-
mocked_response: MockedResponse = env.get_current_mocked_response()
47-
eval_mocked_response(env=env, mocked_response=mocked_response)
45+
def _local_mocked_invoke_lambda_function(env: Environment) -> InvocationResponse:
46+
mocked_response: LocalMockedResponse = env.get_current_local_mocked_response()
47+
eval_local_mocked_response(env=env, mocked_response=mocked_response)
4848
invocation_resp: InvocationResponse = env.stack.pop()
4949
return invocation_resp
5050

@@ -68,8 +68,8 @@ def _invoke_lambda_function(
6868
def execute_lambda_function_integration(
6969
env: Environment, parameters: dict, region: str, state_credentials: StateCredentials
7070
) -> None:
71-
if env.is_mocked_mode():
72-
invocation_response: InvocationResponse = _mocked_invoke_lambda_function(env=env)
71+
if env.is_local_mocked_mode():
72+
invocation_response: InvocationResponse = _local_mocked_invoke_lambda_function(env=env)
7373
else:
7474
invocation_response: InvocationResponse = _invoke_lambda_function(
7575
parameters=parameters, region=region, state_credentials=state_credentials

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/mock_eval_utils.py renamed to localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/local_mock_eval_utils.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@
1010
)
1111
from localstack.services.stepfunctions.asl.eval.environment import Environment
1212
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
13-
from localstack.services.stepfunctions.mocking.mock_config import (
14-
MockedResponse,
15-
MockedResponseReturn,
16-
MockedResponseThrow,
13+
from localstack.services.stepfunctions.local_mocking.mock_config import (
14+
LocalMockedResponse,
15+
LocalMockedResponseReturn,
16+
LocalMockedResponseThrow,
1717
)
1818

1919

20-
def _eval_mocked_response_throw(env: Environment, mocked_response: MockedResponseThrow) -> None:
20+
def _eval_mocked_response_throw(
21+
env: Environment, mocked_response: LocalMockedResponseThrow
22+
) -> None:
2123
task_failed_event_details = TaskFailedEventDetails(
2224
error=mocked_response.error, cause=mocked_response.cause
2325
)
@@ -31,15 +33,17 @@ def _eval_mocked_response_throw(env: Environment, mocked_response: MockedRespons
3133
raise FailureEventException(failure_event=failure_event)
3234

3335

34-
def _eval_mocked_response_return(env: Environment, mocked_response: MockedResponseReturn) -> None:
36+
def _eval_mocked_response_return(
37+
env: Environment, mocked_response: LocalMockedResponseReturn
38+
) -> None:
3539
payload_copy = copy.deepcopy(mocked_response.payload)
3640
env.stack.append(payload_copy)
3741

3842

39-
def eval_mocked_response(env: Environment, mocked_response: MockedResponse) -> None:
40-
if isinstance(mocked_response, MockedResponseReturn):
43+
def eval_local_mocked_response(env: Environment, mocked_response: LocalMockedResponse) -> None:
44+
if isinstance(mocked_response, LocalMockedResponseReturn):
4145
_eval_mocked_response_return(env=env, mocked_response=mocked_response)
42-
elif isinstance(mocked_response, MockedResponseThrow):
46+
elif isinstance(mocked_response, LocalMockedResponseThrow):
4347
_eval_mocked_response_throw(env=env, mocked_response=mocked_response)
4448
else:
4549
raise RuntimeError(f"Invalid MockedResponse type '{type(mocked_response)}'")

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
3434
StateCredentials,
3535
)
36-
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.mock_eval_utils import (
37-
eval_mocked_response,
36+
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.local_mock_eval_utils import (
37+
eval_local_mocked_response,
3838
)
3939
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
4040
ResourceRuntimePart,
@@ -47,7 +47,7 @@
4747
from localstack.services.stepfunctions.asl.eval.environment import Environment
4848
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
4949
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
50-
from localstack.services.stepfunctions.mocking.mock_config import MockedResponse
50+
from localstack.services.stepfunctions.local_mocking.mock_config import LocalMockedResponse
5151
from localstack.services.stepfunctions.quotas import is_within_size_quota
5252
from localstack.utils.strings import camel_to_snake_case, snake_to_camel_case, to_bytes, to_str
5353

@@ -356,9 +356,9 @@ def _eval_execution(self, env: Environment) -> None:
356356
normalised_parameters = copy.deepcopy(raw_parameters)
357357
self._normalise_parameters(normalised_parameters)
358358

359-
if env.is_mocked_mode():
360-
mocked_response: MockedResponse = env.get_current_mocked_response()
361-
eval_mocked_response(env=env, mocked_response=mocked_response)
359+
if env.is_local_mocked_mode():
360+
mocked_response: LocalMockedResponse = env.get_current_local_mocked_response()
361+
eval_local_mocked_response(env=env, mocked_response=mocked_response)
362362
else:
363363
self._eval_service_task(
364364
env=env,

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def _after_eval_execution(
346346
)
347347
),
348348
)
349-
if not env.is_mocked_mode():
349+
if not env.is_local_mocked_mode() and not env.is_test_state_mocked_mode():
350350
self._eval_integration_pattern(
351351
env=env,
352352
resource_runtime_part=resource_runtime_part,

localstack-core/localstack/services/stepfunctions/asl/component/state/state_fail/state_fail.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
FailureEventException,
88
)
99
from localstack.services.stepfunctions.asl.component.state.state import CommonStateField
10+
from localstack.services.stepfunctions.asl.component.state.state_continue_with import (
11+
ContinueWithEnd,
12+
)
1013
from localstack.services.stepfunctions.asl.component.state.state_fail.cause_decl import CauseDecl
1114
from localstack.services.stepfunctions.asl.component.state.state_fail.error_decl import ErrorDecl
1215
from localstack.services.stepfunctions.asl.component.state.state_props import StateProps
@@ -27,6 +30,7 @@ def from_state_props(self, state_props: StateProps) -> None:
2730
super().from_state_props(state_props)
2831
self.cause = state_props.get(CauseDecl)
2932
self.error = state_props.get(ErrorDecl)
33+
self.continue_with = ContinueWithEnd()
3034

3135
def _eval_state(self, env: Environment) -> None:
3236
task_failed_event_details = TaskFailedEventDetails()
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import abc
2+
import copy
3+
from typing import Generic, TypeVar
4+
5+
from localstack.services.stepfunctions.asl.component.eval_component import EvalComponent
6+
from localstack.services.stepfunctions.asl.component.state.state import CommonStateField
7+
from localstack.services.stepfunctions.asl.component.state.state_continue_with import (
8+
ContinueWithNext,
9+
)
10+
from localstack.services.stepfunctions.asl.eval.test_state.environment import TestStateEnvironment
11+
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
12+
from localstack.services.stepfunctions.backend.test_state.test_state_mock import (
13+
TestStateResponseReturn,
14+
TestStateResponseThrow,
15+
eval_mocked_response_throw,
16+
)
17+
18+
T = TypeVar("T", bound=CommonStateField)
19+
20+
21+
class MockedBaseState(Generic[T], abc.ABC):
22+
is_single_state: bool
23+
_wrapped: T
24+
25+
def __init__(self, wrapped: T):
26+
super().__init__()
27+
self._wrapped = wrapped
28+
self.apply_patches()
29+
30+
def apply_patches(self):
31+
self._apply_patches()
32+
33+
original_eval_body = self._wrapped._eval_body
34+
self._wrapped._eval_body = self.wrap_with_post_return(
35+
original_eval_body, self.stop_execution
36+
)
37+
38+
@abc.abstractmethod
39+
def _apply_patches(self): ...
40+
41+
@classmethod
42+
def wrap(cls, state: T, is_single_state: bool = False) -> T:
43+
cls.is_single_state = is_single_state
44+
cls._wrapped = state
45+
return cls(state)._wrapped
46+
47+
def __getattr__(self, attr: str):
48+
return getattr(self._wrapped, attr)
49+
50+
@classmethod
51+
def before_mock(self, env: TestStateEnvironment):
52+
return
53+
54+
@classmethod
55+
def do_mock(self, env: TestStateEnvironment):
56+
mocked_response = env.mock.get_next_result()
57+
if not mocked_response:
58+
return
59+
60+
if isinstance(mocked_response, TestStateResponseThrow):
61+
eval_mocked_response_throw(env, mocked_response)
62+
return
63+
64+
if isinstance(mocked_response, TestStateResponseReturn):
65+
result_copy = copy.deepcopy(mocked_response.payload)
66+
env.stack.append(result_copy)
67+
68+
@classmethod
69+
def after_mock(self, env: TestStateEnvironment):
70+
return
71+
72+
@classmethod
73+
def wrap_with_mock(cls, original_method):
74+
def wrapper(env: TestStateEnvironment, *args, **kwargs):
75+
if not env.mock.is_mocked():
76+
original_method(env, *args, **kwargs)
77+
return
78+
79+
cls.before_mock(env)
80+
try:
81+
cls.do_mock(env)
82+
finally:
83+
cls.after_mock(env)
84+
85+
return wrapper
86+
87+
@staticmethod
88+
def wrap_with_post_return(method, post_return_fn):
89+
def wrapper(env: TestStateEnvironment, *args, **kwargs):
90+
try:
91+
method(env, *args, **kwargs)
92+
finally:
93+
post_return_fn(env)
94+
95+
return wrapper
96+
97+
@staticmethod
98+
def _eval_with_inspect(component: EvalComponent, key: str):
99+
if not component:
100+
return
101+
102+
eval_body_fn = component._eval_body
103+
104+
def _update(env: TestStateEnvironment, *args, **kwargs):
105+
# if inspectionData already populated, don't execute again
106+
if key in env.inspection_data:
107+
return
108+
109+
eval_body_fn(env, *args, **kwargs)
110+
result = env.stack[-1]
111+
env.inspection_data[key] = to_json_str(result)
112+
113+
component._eval_body = MockedBaseState.wrap_with_post_return(eval_body_fn, _update)
114+
115+
def stop_execution(self, env: TestStateEnvironment):
116+
if isinstance(self._wrapped.continue_with, ContinueWithNext):
117+
if next_state := self._wrapped.continue_with.next_state:
118+
env.set_choice_selected(next_state.name)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from localstack.services.stepfunctions.asl.component.state.state import CommonStateField
2+
from localstack.services.stepfunctions.asl.component.state.state_choice.state_choice import (
3+
StateChoice,
4+
)
5+
from localstack.services.stepfunctions.asl.component.state.state_continue_with import (
6+
ContinueWithEnd,
7+
)
8+
from localstack.services.stepfunctions.asl.component.state.state_fail.state_fail import StateFail
9+
from localstack.services.stepfunctions.asl.component.state.state_pass.state_pass import StatePass
10+
from localstack.services.stepfunctions.asl.component.state.state_succeed.state_succeed import (
11+
StateSucceed,
12+
)
13+
from localstack.services.stepfunctions.asl.component.test_state.state.base_mock import (
14+
MockedBaseState,
15+
)
16+
from localstack.services.stepfunctions.asl.eval.test_state.environment import TestStateEnvironment
17+
18+
19+
class MockedCommonState(MockedBaseState[CommonStateField]):
20+
def add_inspection_data(self, env: TestStateEnvironment):
21+
state = self._wrapped
22+
23+
if not isinstance(state, StatePass):
24+
if not self.is_single_state:
25+
return
26+
27+
if "afterInputPath" not in env.inspection_data:
28+
env.inspection_data["afterInputPath"] = env.states.get_input()
29+
return
30+
31+
# If not a terminal state, only populate inspection data from pre-processor.
32+
if not isinstance(self._wrapped.continue_with, ContinueWithEnd):
33+
return
34+
35+
if state.result:
36+
# TODO: investigate interactions between these inspectionData field types.
37+
# i.e parity tests shows that if "Result" is defined, 'afterInputPath' and 'afterParameters'
38+
# cannot be present in the inspection data.
39+
env.inspection_data.pop("afterInputPath", None)
40+
env.inspection_data.pop("afterParameters", None)
41+
42+
if "afterResultSelector" not in env.inspection_data:
43+
env.inspection_data["afterResultSelector"] = state.result.result_obj
44+
45+
if "afterResultPath" not in env.inspection_data:
46+
env.inspection_data["afterResultPath"] = env.inspection_data.get(
47+
"afterResultSelector", env.states.get_input()
48+
)
49+
return
50+
51+
if "afterInputPath" not in env.inspection_data:
52+
env.inspection_data["afterInputPath"] = env.states.get_input()
53+
54+
if "afterParameters" not in env.inspection_data:
55+
env.inspection_data["afterParameters"] = env.inspection_data.get(
56+
"afterInputPath", env.states.get_input()
57+
)
58+
59+
if "afterResultSelector" not in env.inspection_data:
60+
env.inspection_data["afterResultSelector"] = env.inspection_data["afterParameters"]
61+
62+
if "afterResultPath" not in env.inspection_data:
63+
env.inspection_data["afterResultPath"] = env.inspection_data.get(
64+
"afterResultSelector", env.states.get_input()
65+
)
66+
67+
def _apply_patches(self):
68+
if not isinstance(self._wrapped, (StatePass, StateFail, StateChoice, StateSucceed)):
69+
raise ValueError("Needs to be a Pass, Fail, Choice, or Succeed state.")
70+
71+
original_eval_body = self.wrap_with_mock(self._wrapped._eval_body)
72+
73+
def mock_eval_execution(env: TestStateEnvironment):
74+
original_eval_body(env)
75+
env.set_choice_selected(env.next_state_name)
76+
77+
mock_eval_execution = self.wrap_with_post_return(
78+
method=mock_eval_execution,
79+
post_return_fn=self.add_inspection_data,
80+
)
81+
82+
self._wrapped._eval_body = mock_eval_execution

0 commit comments

Comments
 (0)