Skip to content

feat(event_handler): add support for additional response models in OpenAPI schema #3591

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 16, 2024
79 changes: 64 additions & 15 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
from aws_lambda_powertools.event_handler.openapi.types import (
COMPONENT_REF_PREFIX,
METHODS_WITH_BODY,
OpenAPIResponse,
OpenAPIResponseContentModel,
OpenAPIResponseContentSchema,
validation_error_definition,
validation_error_response_definition,
)
Expand Down Expand Up @@ -273,7 +276,7 @@ def __init__(
cache_control: Optional[str],
summary: Optional[str],
description: Optional[str],
responses: Optional[Dict[int, Dict[str, Any]]],
responses: Optional[Dict[int, OpenAPIResponse]],
response_description: Optional[str],
tags: Optional[List[str]],
operation_id: Optional[str],
Expand Down Expand Up @@ -303,7 +306,7 @@ def __init__(
The OpenAPI summary for this route
description: Optional[str]
The OpenAPI description for this route
responses: Optional[Dict[int, Dict[str, Any]]]
responses: Optional[Dict[int, OpenAPIResponse]]
The OpenAPI responses for this route
response_description: Optional[str]
The OpenAPI response description for this route
Expand Down Expand Up @@ -442,7 +445,7 @@ def dependant(self) -> "Dependant":
if self._dependant is None:
from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant

self._dependant = get_dependant(path=self.openapi_path, call=self.func)
self._dependant = get_dependant(path=self.openapi_path, call=self.func, responses=self.responses)

return self._dependant

Expand Down Expand Up @@ -501,11 +504,54 @@ def _get_openapi_path(

# Add the response to the OpenAPI operation
if self.responses:
# If the user supplied responses, we use them and don't set a default 200 response
for status_code in list(self.responses):
response = self.responses[status_code]

# Case 1: there is not 'content' key
if "content" not in response:
response["content"] = {
"application/json": self._openapi_operation_return(
param=dependant.return_param,
model_name_map=model_name_map,
field_mapping=field_mapping,
),
}

# Case 2: there is a 'content' key
else:
# Need to iterate to transform any 'model' into a 'schema'
for content_type, payload in response["content"].items():
new_payload: OpenAPIResponseContentSchema

# Case 2.1: the 'content' has a model
if "model" in payload:
# Find the model in the dependant's extra models
return_field = next(
filter(
lambda model: model.type_ is cast(OpenAPIResponseContentModel, payload)["model"],
self.dependant.response_extra_models,
),
)
if not return_field:
raise AssertionError("Model declared in custom responses was not found")

new_payload = self._openapi_operation_return(
param=return_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
)

# Case 2.2: the 'content' has a schema
else:
# Do nothing! We already have what we need!
new_payload = payload

response["content"][content_type] = new_payload

operation["responses"] = self.responses
else:
# Set the default 200 response
responses = operation.setdefault("responses", self.responses or {})
responses = operation.setdefault("responses", {})
success_response = responses.setdefault(200, {})
success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION
success_response["content"] = {"application/json": {"schema": {}}}
Expand Down Expand Up @@ -682,7 +728,7 @@ def _openapi_operation_return(
Tuple["ModelField", Literal["validation", "serialization"]],
"JsonSchemaValue",
],
) -> Dict[str, Any]:
) -> OpenAPIResponseContentSchema:
"""
Returns the OpenAPI operation return.
"""
Expand Down Expand Up @@ -832,7 +878,7 @@ def route(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -890,7 +936,7 @@ def get(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -943,7 +989,7 @@ def post(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -997,7 +1043,7 @@ def put(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -1051,7 +1097,7 @@ def delete(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -1104,7 +1150,7 @@ def patch(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -1662,7 +1708,7 @@ def route(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -2110,6 +2156,9 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]:
if route.dependant.return_param:
responses_from_routes.append(route.dependant.return_param)

if route.dependant.response_extra_models:
responses_from_routes.extend(route.dependant.response_extra_models)

flat_models = list(responses_from_routes + request_fields_from_routes + body_fields_from_routes)
return flat_models

Expand All @@ -2132,7 +2181,7 @@ def route(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down Expand Up @@ -2221,7 +2270,7 @@ def route(
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, Dict[str, Any]]] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
Expand Down
34 changes: 32 additions & 2 deletions aws_lambda_powertools/event_handler/openapi/dependant.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
create_response_field,
get_flat_dependant,
)
from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse, OpenAPIResponseContentModel

"""
This turns the opaque function signature into typed, validated models.
Expand Down Expand Up @@ -145,6 +146,7 @@ def get_dependant(
path: str,
call: Callable[..., Any],
name: Optional[str] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
) -> Dependant:
"""
Returns a dependant model for a handler function. A dependant model is a model that contains
Expand All @@ -158,6 +160,8 @@ def get_dependant(
The handler function
name: str, optional
The name of the handler function
responses: List[Dict[int, OpenAPIResponse]], optional
The list of extra responses for the handler function

Returns
-------
Expand Down Expand Up @@ -195,6 +199,34 @@ def get_dependant(
else:
add_param_to_fields(field=param_field, dependant=dependant)

_add_return_annotation(dependant, endpoint_signature)
_add_extra_responses(dependant, responses)

return dependant


def _add_extra_responses(dependant: Dependant, responses: Optional[Dict[int, OpenAPIResponse]]):
# Also add the optional extra responses to the dependant model.
if not responses:
return

for response in responses.values():
for schema in response.get("content", {}).values():
if "model" in schema:
response_field = analyze_param(
param_name="return",
annotation=cast(OpenAPIResponseContentModel, schema)["model"],
value=None,
is_path_param=False,
is_response_param=True,
)
if response_field is None:
raise AssertionError("Response field is None for response model")

dependant.response_extra_models.append(response_field)


def _add_return_annotation(dependant: Dependant, endpoint_signature: inspect.Signature):
# If the return annotation is not empty, add it to the dependant model.
return_annotation = endpoint_signature.return_annotation
if return_annotation is not inspect.Signature.empty:
Expand All @@ -210,8 +242,6 @@ def get_dependant(

dependant.return_param = param_field

return dependant


def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
"""
Expand Down
2 changes: 2 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
cookie_params: Optional[List[ModelField]] = None,
body_params: Optional[List[ModelField]] = None,
return_param: Optional[ModelField] = None,
response_extra_models: Optional[List[ModelField]] = None,
name: Optional[str] = None,
call: Optional[Callable[..., Any]] = None,
request_param_name: Optional[str] = None,
Expand All @@ -64,6 +65,7 @@ def __init__(
self.cookie_params = cookie_params or []
self.body_params = body_params or []
self.return_param = return_param or None
self.response_extra_models = response_extra_models or []
self.request_param_name = request_param_name
self.websocket_param_name = websocket_param_name
self.http_connection_param_name = http_connection_param_name
Expand Down
15 changes: 15 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type, Union

from aws_lambda_powertools.shared.types import NotRequired, TypedDict

if TYPE_CHECKING:
from pydantic import BaseModel # noqa: F401

Expand Down Expand Up @@ -43,3 +45,16 @@
},
},
}


class OpenAPIResponseContentSchema(TypedDict, total=False):
schema: Dict


class OpenAPIResponseContentModel(TypedDict):
model: Any


class OpenAPIResponse(TypedDict):
description: str
content: NotRequired[Dict[str, Union[OpenAPIResponseContentSchema, OpenAPIResponseContentModel]]]
18 changes: 9 additions & 9 deletions docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -955,15 +955,15 @@ Customize your API endpoints by adding metadata to endpoint definitions. This pr

Here's a breakdown of various customizable fields:

| Field Name | Type | Description |
| ---------------------- | --------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `summary` | `str` | A concise overview of the main functionality of the endpoint. This brief introduction is usually displayed in autogenerated API documentation and helps consumers quickly understand what the endpoint does. |
| `description` | `str` | A more detailed explanation of the endpoint, which can include information about the operation's behavior, including side effects, error states, and other operational guidelines. |
| `responses` | `Dict[int, Dict[str, Any]]` | A dictionary that maps each HTTP status code to a Response Object as defined by the [OpenAPI Specification](https://swagger.io/specification/#response-object). This allows you to describe expected responses, including default or error messages, and their corresponding schemas for different status codes. |
| `response_description` | `str` | Provides the default textual description of the response sent by the endpoint when the operation is successful. It is intended to give a human-readable understanding of the result. |
| `tags` | `List[str]` | Tags are a way to categorize and group endpoints within the API documentation. They can help organize the operations by resources or other heuristic. |
| `operation_id` | `str` | A unique identifier for the operation, which can be used for referencing this operation in documentation or code. This ID must be unique across all operations described in the API. |
| `include_in_schema` | `bool` | A boolean value that determines whether or not this operation should be included in the OpenAPI schema. Setting it to `False` can hide the endpoint from generated documentation and schema exports, which might be useful for private or experimental endpoints. |
| Field Name | Type | Description |
| ---------------------- |-----------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `summary` | `str` | A concise overview of the main functionality of the endpoint. This brief introduction is usually displayed in autogenerated API documentation and helps consumers quickly understand what the endpoint does. |
| `description` | `str` | A more detailed explanation of the endpoint, which can include information about the operation's behavior, including side effects, error states, and other operational guidelines. |
| `responses` | `Dict[int, Dict[str, OpenAPIResponse]]` | A dictionary that maps each HTTP status code to a Response Object as defined by the [OpenAPI Specification](https://swagger.io/specification/#response-object). This allows you to describe expected responses, including default or error messages, and their corresponding schemas or models for different status codes. |
| `response_description` | `str` | Provides the default textual description of the response sent by the endpoint when the operation is successful. It is intended to give a human-readable understanding of the result. |
| `tags` | `List[str]` | Tags are a way to categorize and group endpoints within the API documentation. They can help organize the operations by resources or other heuristic. |
| `operation_id` | `str` | A unique identifier for the operation, which can be used for referencing this operation in documentation or code. This ID must be unique across all operations described in the API. |
| `include_in_schema` | `bool` | A boolean value that determines whether or not this operation should be included in the OpenAPI schema. Setting it to `False` can hide the endpoint from generated documentation and schema exports, which might be useful for private or experimental endpoints. |

To implement these customizations, include extra parameters when defining your routes:

Expand Down
Loading