13
13
from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Type , Union
14
14
15
15
from aws_lambda_powertools .event_handler import content_types
16
- from aws_lambda_powertools .event_handler .exceptions import ServiceError
16
+ from aws_lambda_powertools .event_handler .exceptions import NotFoundError , ServiceError
17
17
from aws_lambda_powertools .shared import constants
18
18
from aws_lambda_powertools .shared .functions import resolve_truthy_env_var_choice
19
19
from aws_lambda_powertools .shared .json_encoder import Encoder
@@ -435,7 +435,7 @@ def __init__(
435
435
self ._proxy_type = proxy_type
436
436
self ._routes : List [Route ] = []
437
437
self ._route_keys : List [str ] = []
438
- self ._exception_handlers : Dict [Type , Callable ] = {}
438
+ self ._exception_handlers : Dict [Union [ int , Type ] , Callable ] = {}
439
439
self ._cors = cors
440
440
self ._cors_enabled : bool = cors is not None
441
441
self ._cors_methods : Set [str ] = {"OPTIONS" }
@@ -597,6 +597,11 @@ def _not_found(self, method: str) -> ResponseBuilder:
597
597
headers ["Access-Control-Allow-Methods" ] = "," .join (sorted (self ._cors_methods ))
598
598
return ResponseBuilder (Response (status_code = 204 , content_type = None , headers = headers , body = None ))
599
599
600
+ # Allow for custom exception handlers
601
+ handler = self ._exception_handlers .get (404 )
602
+ if handler :
603
+ return ResponseBuilder (handler (NotFoundError ()))
604
+
600
605
return ResponseBuilder (
601
606
Response (
602
607
status_code = HTTPStatus .NOT_FOUND .value ,
@@ -611,9 +616,9 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
611
616
try :
612
617
return ResponseBuilder (self ._to_response (route .func (** args )), route )
613
618
except Exception as exc :
614
- response = self ._call_exception_handler (exc , route )
615
- if response :
616
- return response
619
+ response_builder = self ._call_exception_handler (exc , route )
620
+ if response_builder :
621
+ return response_builder
617
622
618
623
if self ._debug :
619
624
# If the user has turned on debug mode,
@@ -624,8 +629,10 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
624
629
status_code = 500 ,
625
630
content_type = content_types .TEXT_PLAIN ,
626
631
body = "" .join (traceback .format_exc ()),
627
- )
632
+ ),
633
+ route ,
628
634
)
635
+
629
636
raise
630
637
631
638
def _to_response (self , result : Union [Dict , Response ]) -> Response :
@@ -672,17 +679,25 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None
672
679
673
680
self .route (* route )(func )
674
681
675
- def exception_handler (self , exception ):
682
+ def not_found (self ):
683
+ return self .exception_handler (404 )
684
+
685
+ def exception_handler (self , exc_class_or_status_code : Union [int , Type [Exception ]]):
676
686
def register_exception_handler (func : Callable ):
677
- self ._exception_handlers [exception ] = func
687
+ self ._exception_handlers [exc_class_or_status_code ] = func
678
688
679
689
return register_exception_handler
680
690
681
- def _call_exception_handler (self , exp : Exception , route : Route ) -> Optional [ResponseBuilder ]:
691
+ def _lookup_exception_handler (self , exp : Exception ) -> Optional [Callable ]:
682
692
for cls in type (exp ).__mro__ :
683
693
if cls in self ._exception_handlers :
684
- handler = self ._exception_handlers [cls ]
685
- return ResponseBuilder (handler (exp ), route )
694
+ return self ._exception_handlers [cls ]
695
+ return None
696
+
697
+ def _call_exception_handler (self , exp : Exception , route : Route ) -> Optional [ResponseBuilder ]:
698
+ handler = self ._lookup_exception_handler (exp )
699
+ if handler :
700
+ return ResponseBuilder (handler (exp ), route )
686
701
687
702
if isinstance (exp , ServiceError ):
688
703
return ResponseBuilder (
0 commit comments