Skip to content

Commit fb7ea03

Browse files
committed
feat: add custom route on route level
1 parent 5905c3f commit fb7ea03

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

fastapi/routing.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,7 @@ def api_route(
628628
include_in_schema: bool = True,
629629
response_class: Type[Response] = Default(JSONResponse),
630630
name: Optional[str] = None,
631+
route_class_override: Optional[Type[APIRoute]] = None,
631632
callbacks: Optional[List[BaseRoute]] = None,
632633
openapi_extra: Optional[Dict[str, Any]] = None,
633634
generate_unique_id_function: Callable[[APIRoute], str] = Default(
@@ -658,6 +659,7 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
658659
include_in_schema=include_in_schema,
659660
response_class=response_class,
660661
name=name,
662+
route_class_override=route_class_override,
661663
callbacks=callbacks,
662664
openapi_extra=openapi_extra,
663665
generate_unique_id_function=generate_unique_id_function,
@@ -822,6 +824,7 @@ def get(
822824
include_in_schema: bool = True,
823825
response_class: Type[Response] = Default(JSONResponse),
824826
name: Optional[str] = None,
827+
route_class_override: Optional[Type[APIRoute]] = None,
825828
callbacks: Optional[List[BaseRoute]] = None,
826829
openapi_extra: Optional[Dict[str, Any]] = None,
827830
generate_unique_id_function: Callable[[APIRoute], str] = Default(
@@ -850,6 +853,7 @@ def get(
850853
include_in_schema=include_in_schema,
851854
response_class=response_class,
852855
name=name,
856+
route_class_override=route_class_override,
853857
callbacks=callbacks,
854858
openapi_extra=openapi_extra,
855859
generate_unique_id_function=generate_unique_id_function,
@@ -878,6 +882,7 @@ def put(
878882
include_in_schema: bool = True,
879883
response_class: Type[Response] = Default(JSONResponse),
880884
name: Optional[str] = None,
885+
route_class_override: Optional[Type[APIRoute]] = None,
881886
callbacks: Optional[List[BaseRoute]] = None,
882887
openapi_extra: Optional[Dict[str, Any]] = None,
883888
generate_unique_id_function: Callable[[APIRoute], str] = Default(
@@ -906,6 +911,7 @@ def put(
906911
include_in_schema=include_in_schema,
907912
response_class=response_class,
908913
name=name,
914+
route_class_override=route_class_override,
909915
callbacks=callbacks,
910916
openapi_extra=openapi_extra,
911917
generate_unique_id_function=generate_unique_id_function,
@@ -934,6 +940,7 @@ def post(
934940
include_in_schema: bool = True,
935941
response_class: Type[Response] = Default(JSONResponse),
936942
name: Optional[str] = None,
943+
route_class_override: Optional[Type[APIRoute]] = None,
937944
callbacks: Optional[List[BaseRoute]] = None,
938945
openapi_extra: Optional[Dict[str, Any]] = None,
939946
generate_unique_id_function: Callable[[APIRoute], str] = Default(
@@ -962,6 +969,7 @@ def post(
962969
include_in_schema=include_in_schema,
963970
response_class=response_class,
964971
name=name,
972+
route_class_override=route_class_override,
965973
callbacks=callbacks,
966974
openapi_extra=openapi_extra,
967975
generate_unique_id_function=generate_unique_id_function,
@@ -990,6 +998,7 @@ def delete(
990998
include_in_schema: bool = True,
991999
response_class: Type[Response] = Default(JSONResponse),
9921000
name: Optional[str] = None,
1001+
route_class_override: Optional[Type[APIRoute]] = None,
9931002
callbacks: Optional[List[BaseRoute]] = None,
9941003
openapi_extra: Optional[Dict[str, Any]] = None,
9951004
generate_unique_id_function: Callable[[APIRoute], str] = Default(
@@ -1018,6 +1027,7 @@ def delete(
10181027
include_in_schema=include_in_schema,
10191028
response_class=response_class,
10201029
name=name,
1030+
route_class_override=route_class_override,
10211031
callbacks=callbacks,
10221032
openapi_extra=openapi_extra,
10231033
generate_unique_id_function=generate_unique_id_function,
@@ -1046,6 +1056,7 @@ def options(
10461056
include_in_schema: bool = True,
10471057
response_class: Type[Response] = Default(JSONResponse),
10481058
name: Optional[str] = None,
1059+
route_class_override: Optional[Type[APIRoute]] = None,
10491060
callbacks: Optional[List[BaseRoute]] = None,
10501061
openapi_extra: Optional[Dict[str, Any]] = None,
10511062
generate_unique_id_function: Callable[[APIRoute], str] = Default(
@@ -1074,6 +1085,7 @@ def options(
10741085
include_in_schema=include_in_schema,
10751086
response_class=response_class,
10761087
name=name,
1088+
route_class_override=route_class_override,
10771089
callbacks=callbacks,
10781090
openapi_extra=openapi_extra,
10791091
generate_unique_id_function=generate_unique_id_function,
@@ -1102,6 +1114,7 @@ def head(
11021114
include_in_schema: bool = True,
11031115
response_class: Type[Response] = Default(JSONResponse),
11041116
name: Optional[str] = None,
1117+
route_class_override: Optional[Type[APIRoute]] = None,
11051118
callbacks: Optional[List[BaseRoute]] = None,
11061119
openapi_extra: Optional[Dict[str, Any]] = None,
11071120
generate_unique_id_function: Callable[[APIRoute], str] = Default(
@@ -1130,6 +1143,7 @@ def head(
11301143
include_in_schema=include_in_schema,
11311144
response_class=response_class,
11321145
name=name,
1146+
route_class_override=route_class_override,
11331147
callbacks=callbacks,
11341148
openapi_extra=openapi_extra,
11351149
generate_unique_id_function=generate_unique_id_function,
@@ -1158,6 +1172,7 @@ def patch(
11581172
include_in_schema: bool = True,
11591173
response_class: Type[Response] = Default(JSONResponse),
11601174
name: Optional[str] = None,
1175+
route_class_override: Optional[Type[APIRoute]] = None,
11611176
callbacks: Optional[List[BaseRoute]] = None,
11621177
openapi_extra: Optional[Dict[str, Any]] = None,
11631178
generate_unique_id_function: Callable[[APIRoute], str] = Default(
@@ -1186,6 +1201,7 @@ def patch(
11861201
include_in_schema=include_in_schema,
11871202
response_class=response_class,
11881203
name=name,
1204+
route_class_override=route_class_override,
11891205
callbacks=callbacks,
11901206
openapi_extra=openapi_extra,
11911207
generate_unique_id_function=generate_unique_id_function,
@@ -1214,6 +1230,7 @@ def trace(
12141230
include_in_schema: bool = True,
12151231
response_class: Type[Response] = Default(JSONResponse),
12161232
name: Optional[str] = None,
1233+
route_class_override: Optional[Type[APIRoute]] = None,
12171234
callbacks: Optional[List[BaseRoute]] = None,
12181235
openapi_extra: Optional[Dict[str, Any]] = None,
12191236
generate_unique_id_function: Callable[[APIRoute], str] = Default(
@@ -1243,6 +1260,7 @@ def trace(
12431260
include_in_schema=include_in_schema,
12441261
response_class=response_class,
12451262
name=name,
1263+
route_class_override=route_class_override,
12461264
callbacks=callbacks,
12471265
openapi_extra=openapi_extra,
12481266
generate_unique_id_function=generate_unique_id_function,
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from typing import Callable
2+
from urllib.request import Request
3+
4+
import pytest
5+
from fastapi import APIRouter, FastAPI, HTTPException, status
6+
from fastapi.openapi.models import Response
7+
from fastapi.routing import APIRoute
8+
from fastapi.testclient import TestClient
9+
10+
app = FastAPI()
11+
router = APIRouter()
12+
13+
14+
class CustomRoute(APIRoute):
15+
def get_route_handler(self) -> Callable:
16+
original_route_handler = super().get_route_handler()
17+
18+
async def custom_route_handler(request: Request) -> Response:
19+
if "test_header" not in request.headers:
20+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
21+
return await original_route_handler(request)
22+
23+
return custom_route_handler
24+
25+
26+
@router.get("/a")
27+
def get_a():
28+
return {"msg": "A"}
29+
30+
31+
@router.get("/b", route_class_override=CustomRoute)
32+
def get_b():
33+
return {"msg": "B"}
34+
35+
36+
app.include_router(router=router, prefix="")
37+
38+
39+
client = TestClient(app)
40+
41+
42+
@pytest.mark.parametrize(
43+
"path,expected_status,headers",
44+
[
45+
("/a", 200, {"test_header": "value"}),
46+
("/a", 200, None),
47+
("/b", 200, {"test_header": "value"}),
48+
("/b", 400, None),
49+
],
50+
ids=[
51+
"/a with test_header header",
52+
"/a without test_header headers",
53+
"/b with test_header headers",
54+
"/b without test_header headers",
55+
],
56+
)
57+
def test_get_path(path, expected_status, headers):
58+
response = client.get(path, headers=headers)
59+
assert response.status_code == expected_status

0 commit comments

Comments
 (0)