Skip to content

Commit f0bae3c

Browse files
author
Michael Brewer
committed
feat: add not_found handler
1 parent 5337dc3 commit f0bae3c

File tree

2 files changed

+45
-11
lines changed

2 files changed

+45
-11
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
1414

1515
from aws_lambda_powertools.event_handler import content_types
16-
from aws_lambda_powertools.event_handler.exceptions import ServiceError
16+
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
1717
from aws_lambda_powertools.shared import constants
1818
from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice
1919
from aws_lambda_powertools.shared.json_encoder import Encoder
@@ -435,7 +435,7 @@ def __init__(
435435
self._proxy_type = proxy_type
436436
self._routes: List[Route] = []
437437
self._route_keys: List[str] = []
438-
self._exception_handlers: Dict[Type, Callable] = {}
438+
self._exception_handlers: Dict[Union[int, Type], Callable] = {}
439439
self._cors = cors
440440
self._cors_enabled: bool = cors is not None
441441
self._cors_methods: Set[str] = {"OPTIONS"}
@@ -597,6 +597,11 @@ def _not_found(self, method: str) -> ResponseBuilder:
597597
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
598598
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None))
599599

600+
# Allow for custom exception handlers
601+
handler = self._exception_handlers.get(404)
602+
if handler:
603+
return ResponseBuilder(handler(NotFoundError()))
604+
600605
return ResponseBuilder(
601606
Response(
602607
status_code=HTTPStatus.NOT_FOUND.value,
@@ -611,9 +616,9 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
611616
try:
612617
return ResponseBuilder(self._to_response(route.func(**args)), route)
613618
except Exception as exc:
614-
response = self._call_exception_handler(exc, route)
615-
if response:
616-
return response
619+
response_builder = self._call_exception_handler(exc, route)
620+
if response_builder:
621+
return response_builder
617622

618623
if self._debug:
619624
# If the user has turned on debug mode,
@@ -624,8 +629,10 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
624629
status_code=500,
625630
content_type=content_types.TEXT_PLAIN,
626631
body="".join(traceback.format_exc()),
627-
)
632+
),
633+
route,
628634
)
635+
629636
raise
630637

631638
def _to_response(self, result: Union[Dict, Response]) -> Response:
@@ -672,17 +679,25 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None
672679

673680
self.route(*route)(func)
674681

675-
def exception_handler(self, exception):
682+
def not_found(self):
683+
return self.exception_handler(404)
684+
685+
def exception_handler(self, exc_class_or_status_code: Union[int, Type[Exception]]):
676686
def register_exception_handler(func: Callable):
677-
self._exception_handlers[exception] = func
687+
self._exception_handlers[exc_class_or_status_code] = func
678688

679689
return register_exception_handler
680690

681-
def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]:
691+
def _lookup_exception_handler(self, exp: Exception) -> Optional[Callable]:
682692
for cls in type(exp).__mro__:
683693
if cls in self._exception_handlers:
684-
handler = self._exception_handlers[cls]
685-
return ResponseBuilder(handler(exp), route)
694+
return self._exception_handlers[cls]
695+
return None
696+
697+
def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]:
698+
handler = self._lookup_exception_handler(exp)
699+
if handler:
700+
return ResponseBuilder(handler(exp), route)
686701

687702
if isinstance(exp, ServiceError):
688703
return ResponseBuilder(

tests/functional/event_handler/test_api_gateway.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,3 +1103,22 @@ def get_lambda() -> Response:
11031103
assert result["statusCode"] == 418
11041104
assert result["headers"]["Content-Type"] == content_types.TEXT_HTML
11051105
assert result["body"] == "Foo!"
1106+
1107+
1108+
def test_exception_handler_not_found():
1109+
# GIVEN a resolver with an exception handler defined for ValueError
1110+
app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)
1111+
1112+
@app.not_found()
1113+
def handle_not_found(exc: NotFoundError):
1114+
assert isinstance(exc, NotFoundError)
1115+
return Response(status_code=404, content_type=content_types.TEXT_PLAIN, body="I am a teapot!")
1116+
1117+
# WHEN calling the event handler
1118+
# AND a ValueError is raised
1119+
result = app(LOAD_GW_EVENT, {})
1120+
1121+
# THEN call the exception_handler
1122+
assert result["statusCode"] == 404
1123+
assert result["headers"]["Content-Type"] == content_types.TEXT_PLAIN
1124+
assert result["body"] == "I am a teapot!"

0 commit comments

Comments
 (0)