Skip to content

Commit 304595c

Browse files
chore(aap): replace callbacks dict with typed block_callable attribut… (#16784)
APPSEC-61618 ## Summary - Replace the generic `callbacks` dict on `ASM_Environment` with a typed `block_callable` attribute, removing the indirect `get_value`/`set_value` pattern for storing the block request callable - Move `_block_request_callable` from `ddtrace/contrib/internal/flask/patch.py` into `ddtrace/appsec/_contrib/flask/__init__.py` as `_flask_block_request_callable`, keeping Flask-specific blocking logic in the appsec Flask module - Add `_make_block_response()` which returns a `(body, status, headers)` tuple instead of raising via `abort()`, preventing Flask's traced error handlers (`handle_exception`, `handle_http_exception`) from creating extra spans that lack fingerprint tags ## Motivation The `callbacks` dict on `ASM_Environment` was a loosely-typed bag used to store a single block callable. Replacing it with a typed `block_callable: Optional[Callable]` attribute: - Makes the data model explicit and easier to reason about - Removes unused `_CALLBACKS` and `_BLOCK_CALL` constants - Eliminates `get_value`/`set_value` indirection for this specific use case The `_make_block_response()` function fixes system test failures (`test_fingerprinting_network_block`, `test_fingerprinting_header_block`, `test_fingerprinting_endpoint_blocking`, `test_session_blocking`) where fingerprint tags were missing from some spans. The root cause was `abort()` raising an `HTTPException`, which triggered Flask's traced `handle_exception`/`handle_http_exception` handlers, creating extra spans without `_dd.appsec.fp.http.*` tags. Returning a tuple avoids Flask's error handling entirely. ## Changes - **`ddtrace/appsec/_asm_request_context.py`**: Replace `self.callbacks: dict` with `self.block_callable: Optional[Callable]`. Update `set_block_request_callable()` and `block_request()` to use the attribute directly. Remove `_CALLBACKS` and `_BLOCK_CALL` constants. - **`ddtrace/appsec/_contrib/flask/__init__.py`**: Add `_make_block_response()` (returns tuple) and `_flask_block_request_callable()` (uses `abort()`). Use `_make_block_response` in `_on_wrapped_view` for path-parameter blocking. Remove `get_value`/`set_value` imports. - **`ddtrace/contrib/internal/flask/patch.py`**: Remove `_block_request_callable()` and the `block_request_callable` context item (now handled by the appsec Flask module). - **`tests/appsec/appsec/test_asm_request_context.py`**: Update tests to use `env.block_callable` instead of `get_value("callbacks", "block")`. - **`tests/appsec/integrations/flask_tests/test_appsec_flask.py`**: Use `get_triggers()` to extract the actual `block_id` from WAF triggers instead of hardcoding `"default"`. Co-authored-by: christophe.papazian <[email protected]>
1 parent 8e3a91d commit 304595c

File tree

7 files changed

+78
-47
lines changed

7 files changed

+78
-47
lines changed

ddtrace/appsec/_asm_request_context.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,8 @@ class WARNING_TAGS(metaclass=Constant_Class):
6666

6767
_ASM_CONTEXT: Literal["_asm_env"] = "_asm_env"
6868
_WAF_ADDRESSES: Literal["waf_addresses"] = "waf_addresses"
69-
_CALLBACKS: Literal["callbacks"] = "callbacks"
7069
_TELEMETRY: Literal["telemetry"] = "telemetry"
7170
_CONTEXT_CALL: Literal["context"] = "context"
72-
_BLOCK_CALL: Literal["block"] = "block"
7371

7472

7573
GLOBAL_CALLBACKS: dict[str, list[Callable]] = {_CONTEXT_CALL: []}
@@ -114,7 +112,7 @@ def __init__(
114112
self.waf_info: Optional[Callable[[], "DDWaf_info"]] = None
115113
self.waf_addresses: dict[str, Any] = {}
116114
self.waf_callable: Optional[WafCallable] = waf_callable
117-
self.callbacks: dict[str, Any] = {}
115+
self.block_callable: Optional[Callable[[], None]] = None
118116
self.telemetry: Telemetry_result = Telemetry_result()
119117
self.addresses_sent: set[str] = set()
120118
self.waf_triggers: list[dict[str, Any]] = []
@@ -337,7 +335,7 @@ def finalize_asm_env(env: ASM_Environment) -> None:
337335
entry_span._set_tag_str(APPSEC.RC_PRODUCTS, env.rc_products)
338336

339337
# Manually clear reference cycles to simplify the work for the GC
340-
env.callbacks.clear()
338+
env.block_callable = None
341339
env.waf_callable = None
342340
core.discard_local_item(_ASM_CONTEXT)
343341

@@ -435,10 +433,10 @@ def call_waf_callback(
435433
env = get_active_asm_context()
436434
if env is not None and env.waf_callable is not None:
437435
return env.waf_callable(custom_data, crop_trace, rule_type, force_sent)
438-
else:
439-
logger.warning(WARNING_TAGS.CALL_WAF_CALLBACK_NOT_SET, extra=log_extra, stack_info=True)
440-
report_error_on_entry_span("appsec::instrumentation::diagnostic", WARNING_TAGS.CALL_WAF_CALLBACK_NOT_SET)
441-
return None
436+
437+
logger.warning(WARNING_TAGS.CALL_WAF_CALLBACK_NOT_SET, extra=log_extra, stack_info=True)
438+
report_error_on_entry_span("appsec::instrumentation::diagnostic", WARNING_TAGS.CALL_WAF_CALLBACK_NOT_SET)
439+
return None
442440

443441

444442
def call_waf_callback_no_instrumentation() -> None:
@@ -482,23 +480,25 @@ def get_headers_case_sensitive() -> bool:
482480
return get_value(_WAF_ADDRESSES, SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES_CASE, False) # type : ignore
483481

484482

485-
def set_block_request_callable(_callable: Optional[Callable[[], Any]], *_: Any) -> None:
483+
def set_block_request_callable(block_callable: Optional[Callable[[], None]]) -> None:
486484
"""
487485
Sets a callable that could be use to do a best-effort to block the request. If
488486
the callable need any params, like headers, they should be curried with
489487
functools.partial.
490488
"""
491-
if asm_config._asm_enabled and _callable:
492-
set_value(_CALLBACKS, _BLOCK_CALL, _callable)
489+
if asm_config._asm_enabled and block_callable:
490+
env = get_active_asm_context()
491+
if env is not None:
492+
env.block_callable = block_callable
493493

494494

495495
def block_request() -> None:
496496
"""
497497
Calls or returns the stored block request callable, if set.
498498
"""
499-
_callable = get_value(_CALLBACKS, _BLOCK_CALL)
500-
if _callable:
501-
_callable()
499+
env = get_active_asm_context()
500+
if env is not None and env.block_callable is not None:
501+
env.block_callable()
502502
else:
503503
logger.warning(WARNING_TAGS.BLOCK_REQUEST_NOT_CALLABLE, extra=log_extra, stack_info=True)
504504

@@ -515,7 +515,7 @@ def asm_request_context_set(
515515
remote_ip: Optional[str] = None,
516516
headers: Any = None,
517517
headers_case_sensitive: bool = False,
518-
block_request_callable: Optional[Callable] = None,
518+
block_request_callable: Optional[Callable[[], None]] = None,
519519
) -> None:
520520
set_ip(remote_ip)
521521
set_headers(headers)

ddtrace/appsec/_contrib/django/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,4 +194,4 @@ def listen():
194194
core.on("django.after_request_headers.finalize", _set_headers_and_response)
195195

196196
core.on("context.ended.django.traced_get_response", _on_context_ended)
197-
core.on("django.traced_get_response.pre", set_block_request_callable)
197+
core.on("django.traced_get_response.pre", lambda block_callable, *_: set_block_request_callable(block_callable))

ddtrace/appsec/_contrib/flask/__init__.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
import io
22
import json
33

4-
from ddtrace.appsec._asm_request_context import _CALLBACKS
54
from ddtrace.appsec._asm_request_context import _call_waf_first
65
from ddtrace.appsec._asm_request_context import _on_context_ended
76
from ddtrace.appsec._asm_request_context import _set_headers_and_response
7+
from ddtrace.appsec._asm_request_context import block_request
88
from ddtrace.appsec._asm_request_context import call_waf_callback
99
from ddtrace.appsec._asm_request_context import get_blocked
10-
from ddtrace.appsec._asm_request_context import get_value
1110
from ddtrace.appsec._asm_request_context import in_asm_context
1211
from ddtrace.appsec._asm_request_context import is_blocked
1312
from ddtrace.appsec._asm_request_context import set_block_request_callable
14-
from ddtrace.appsec._asm_request_context import set_value
1513
from ddtrace.appsec._asm_request_context import set_waf_address
1614
from ddtrace.appsec._utils import Block_config
1715
from ddtrace.contrib import trace_utils
@@ -122,6 +120,24 @@ def _on_start_response_blocked(ctx, flask_config, response_headers, status):
122120
trace_utils.set_http_meta(ctx["req_span"], flask_config, status_code=status, response_headers=response_headers)
123121

124122

123+
def _make_block_response():
124+
"""Build a blocked response as a tuple (body, status, headers).
125+
126+
Returning a tuple avoids Flask's error handling (handle_exception,
127+
handle_http_exception) which would create extra spans without
128+
fingerprint tags.
129+
"""
130+
from ddtrace.internal.utils import get_blocked as _get_blocked
131+
132+
block_config = _get_blocked()
133+
ctype = block_config.content_type if block_config else "application/json"
134+
block_id = block_config.block_id if block_config else "(default)"
135+
status = block_config.status_code if block_config else 403
136+
if block_config and block_config.type == "none":
137+
return b"", status, {"location": block_config.location}
138+
return http_utils._get_blocked_template(ctype, block_id), status, {"content-type": ctype}
139+
140+
125141
def _on_wrapped_view(kwargs):
126142
callback_block = None
127143
# if Appsec is enabled, we can try to block as we have the path parameters at that point
@@ -131,19 +147,35 @@ def _on_wrapped_view(kwargs):
131147
set_waf_address(REQUEST_PATH_PARAMS, kwargs)
132148
call_waf_callback()
133149
if is_blocked():
134-
callback_block = get_value(_CALLBACKS, "flask_block")
150+
callback_block = _make_block_response
135151
return callback_block
136152

137153

154+
def _flask_block_request_callable(span):
155+
import flask
156+
from werkzeug.exceptions import abort
157+
158+
from ddtrace.internal.utils import get_blocked as _get_blocked
159+
from ddtrace.internal.utils import set_blocked as _set_blocked
160+
161+
if not _get_blocked():
162+
_set_blocked()
163+
core.dispatch("flask.blocked_request_callable", (span,))
164+
block_config = _get_blocked()
165+
ctype = block_config.content_type if block_config else "application/json"
166+
block_id = block_config.block_id if block_config else "(default)"
167+
status = block_config.status_code if block_config else 403
168+
if block_config and block_config.type == "none":
169+
abort(flask.Response(b"", status=status, headers={"location": block_config.location}))
170+
else:
171+
abort(flask.Response(http_utils._get_blocked_template(ctype, block_id), content_type=ctype, status=status))
172+
173+
138174
def _on_pre_tracedrequest(ctx):
139175
import functools
140176

141-
current_span = ctx.span
142-
block_request_callable = ctx.get_item("block_request_callable")
143177
if asm_config._asm_enabled:
144-
from ddtrace.appsec._asm_request_context import block_request
145-
146-
set_block_request_callable(functools.partial(block_request_callable, current_span))
178+
set_block_request_callable(functools.partial(_flask_block_request_callable, ctx.span))
147179
if get_blocked():
148180
block_request()
149181

@@ -152,7 +184,6 @@ def _on_block_decided(callback):
152184
if not asm_config._asm_enabled:
153185
return
154186

155-
set_value(_CALLBACKS, "flask_block", callback)
156187
core.on("flask.block.request.content", callback, "block_requested")
157188

158189

ddtrace/contrib/internal/flask/patch.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import werkzeug
55
from werkzeug.exceptions import BadRequest
66
from werkzeug.exceptions import NotFound
7-
from werkzeug.exceptions import abort
87

98
from ddtrace.contrib import trace_utils
109
from ddtrace.ext import SpanTypes
@@ -16,8 +15,6 @@
1615
from ddtrace.internal.schema import schematize_url_operation
1716
from ddtrace.internal.schema.span_attribute_schema import SpanDirection
1817
from ddtrace.internal.utils import get_blocked
19-
from ddtrace.internal.utils import http as http_utils
20-
from ddtrace.internal.utils import set_blocked
2118

2219

2320
# Not all versions of flask/werkzeug have this mixin
@@ -543,15 +540,6 @@ def _wrap(code_or_exception, f):
543540
return _wrap(*args, **kwargs)
544541

545542

546-
def _block_request_callable(call):
547-
set_blocked()
548-
core.dispatch("flask.blocked_request_callable", (call,))
549-
block_config = get_blocked()
550-
ctype = block_config.content_type if block_config else "application/json"
551-
block_id = block_config.block_id if block_config else "(default)"
552-
abort(flask.Response(http_utils._get_blocked_template(ctype, block_id), content_type=ctype, status=403))
553-
554-
555543
def request_patcher(name):
556544
@with_instance_pin
557545
def _patched_request(pin, wrapped, instance, args, kwargs):
@@ -563,7 +551,6 @@ def _patched_request(pin, wrapped, instance, args, kwargs):
563551
service=trace_utils.int_service(pin, config.flask, pin),
564552
flask_config=config.flask,
565553
flask_request=flask.request,
566-
block_request_callable=_block_request_callable,
567554
ignored_exception_type=NotFound,
568555
tags={COMPONENT: config.flask.integration_name},
569556
) as ctx,

tests/appsec/appsec/test_asm_request_context.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ def test_context_set_and_reset():
2525
assert _asm_request_context.get_ip() == _TEST_IP
2626
assert _asm_request_context.get_headers() == _TEST_HEADERS
2727
assert _asm_request_context.get_headers_case_sensitive()
28-
assert _asm_request_context.get_value("callbacks", "block") is not None
28+
env = _asm_request_context.get_active_asm_context()
29+
assert env is not None and env.block_callable is not None
2930
assert _asm_request_context.get_ip() is None
3031
assert _asm_request_context.get_headers() == {}
31-
assert _asm_request_context.get_value("callbacks", "block") is None
32+
assert _asm_request_context.get_active_asm_context() is None
3233
assert not _asm_request_context.get_headers_case_sensitive()
3334
with asm_context(
3435
ip_addr=_TEST_IP,
@@ -64,8 +65,10 @@ def _callable():
6465

6566
with asm_context(config=config_asm):
6667
_asm_request_context.set_block_request_callable(_callable)
67-
assert _asm_request_context.get_value("callbacks", "block")() == 42
68-
assert not _asm_request_context.get_value("callbacks", "block")
68+
env = _asm_request_context.get_active_asm_context()
69+
assert env is not None and env.block_callable is not None
70+
assert env.block_callable() == 42
71+
assert _asm_request_context.get_active_asm_context() is None
6972

7073

7174
def test_call_block_callable_curried():
@@ -102,11 +105,13 @@ def test_asm_request_context_manager():
102105
assert _asm_request_context.get_ip() == _TEST_IP
103106
assert _asm_request_context.get_headers() == _TEST_HEADERS
104107
assert _asm_request_context.get_headers_case_sensitive()
105-
assert _asm_request_context.get_value("callbacks", "block")() == 42
108+
env = _asm_request_context.get_active_asm_context()
109+
assert env is not None and env.block_callable is not None
110+
assert env.block_callable() == 42
106111

107112
assert _asm_request_context.get_ip() is None
108113
assert _asm_request_context.get_headers() == {}
109-
assert _asm_request_context.get_value("callbacks", "block") is None
114+
assert _asm_request_context.get_active_asm_context() is None
110115
assert not _asm_request_context.get_headers_case_sensitive()
111116

112117

56 KB
Binary file not shown.

tests/appsec/integrations/flask_tests/test_appsec_flask.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ddtrace.appsec._constants import SPAN_DATA_NAMES
44
from ddtrace.appsec._trace_utils import block_request_if_user_blocked
5+
from ddtrace.appsec._utils import get_triggers
56
from ddtrace.contrib.internal.sqlite3.patch import patch
67
from ddtrace.ext import http
78
from ddtrace.internal import constants
@@ -70,8 +71,11 @@ def test_route(user_id):
7071
self._aux_appsec_prepare_tracer()
7172
resp = self.client.get("/checkuser/%s" % _BLOCKED_USER)
7273
assert resp.status_code == 403
73-
assert get_response_body(resp) == _format_template(constants.BLOCKED_RESPONSE_JSON, "default")
7474
root_span = self.pop_spans()[0]
75+
triggers = get_triggers(root_span)
76+
assert triggers is not None
77+
block_id = triggers[0].get("security_response_id", "default")
78+
assert get_response_body(resp) == _format_template(constants.BLOCKED_RESPONSE_JSON, block_id)
7579
assert root_span.get_tag(http.STATUS_CODE) == "403"
7680
assert root_span.get_tag(http.URL) == "http://localhost/checkuser/%s" % _BLOCKED_USER
7781
assert root_span.get_tag(http.METHOD) == "GET"
@@ -82,7 +86,11 @@ def test_route(user_id):
8286

8387
resp = self.client.get("/checkuser/%s" % _BLOCKED_USER, headers={"Accept": "text/html"})
8488
assert resp.status_code == 403
85-
assert get_response_body(resp) == _format_template(constants.BLOCKED_RESPONSE_HTML, "default")
89+
root_span = self.pop_spans()[0]
90+
triggers = get_triggers(root_span)
91+
assert triggers is not None
92+
block_id = triggers[0].get("security_response_id", "default")
93+
assert get_response_body(resp) == _format_template(constants.BLOCKED_RESPONSE_HTML, block_id)
8694

8795
resp = self.client.get("/checkuser/%s" % _ALLOWED_USER, headers={"Accept": "text/html"})
8896
assert resp.status_code == 200

0 commit comments

Comments
 (0)