-
Notifications
You must be signed in to change notification settings - Fork 429
feat(event_handler): allow multiple CORS origins #2279
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1f59251
be743c3
61a6581
3ca9aaf
0acb944
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -84,6 +84,7 @@ def with_cors(): | |
|
||
cors_config = CORSConfig( | ||
allow_origin="https://wwww.example.com/", | ||
extra_origins=["https://dev.example.com/"], | ||
expose_headers=["x-exposed-response-header"], | ||
allow_headers=["x-custom-request-header"], | ||
max_age=100, | ||
|
@@ -106,6 +107,7 @@ def without_cors(): | |
def __init__( | ||
self, | ||
allow_origin: str = "*", | ||
extra_origins: Optional[List[str]] = None, | ||
allow_headers: Optional[List[str]] = None, | ||
expose_headers: Optional[List[str]] = None, | ||
max_age: Optional[int] = None, | ||
|
@@ -117,6 +119,8 @@ def __init__( | |
allow_origin: str | ||
The value of the `Access-Control-Allow-Origin` to send in the response. Defaults to "*", but should | ||
only be used during development. | ||
extra_origins: Optional[List[str]] | ||
The list of additional allowed origins. | ||
allow_headers: Optional[List[str]] | ||
The list of additional allowed headers. This list is added to list of | ||
built-in allowed headers: `Authorization`, `Content-Type`, `X-Amz-Date`, | ||
|
@@ -128,16 +132,29 @@ def __init__( | |
allow_credentials: bool | ||
A boolean value that sets the value of `Access-Control-Allow-Credentials` | ||
""" | ||
self.allow_origin = allow_origin | ||
self._allowed_origins = [allow_origin] | ||
if extra_origins: | ||
self._allowed_origins.extend(extra_origins) | ||
self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or [])) | ||
self.expose_headers = expose_headers or [] | ||
self.max_age = max_age | ||
self.allow_credentials = allow_credentials | ||
|
||
def to_dict(self) -> Dict[str, str]: | ||
def to_dict(self, origin: Optional[str]) -> Dict[str, str]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. origin: str = ""? Annotation says it's optional but we are not setting a default value. You can drop the Optional (None), and simply set to an empty str There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this mean I would have to compare There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TIL There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've changed my mind again and I believe it's better to keep the Optional. Reason is, the caller will try to fetch an Origin from the headers. The Origin is not always present, so |
||
"""Builds the configured Access-Control http headers""" | ||
|
||
# If there's no Origin, don't add any CORS headers | ||
if not origin: | ||
return {} | ||
|
||
# If the origin doesn't match any of the allowed origins, and we don't allow all origins ("*"), | ||
# don't add any CORS headers | ||
if origin not in self._allowed_origins and "*" not in self._allowed_origins: | ||
return {} | ||
|
||
# The origin matched an allowed origin, so return the CORS headers | ||
headers: Dict[str, str] = { | ||
"Access-Control-Allow-Origin": self.allow_origin, | ||
"Access-Control-Allow-Origin": origin, | ||
"Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)), | ||
} | ||
|
||
|
@@ -207,9 +224,9 @@ def __init__(self, response: Response, route: Optional[Route] = None): | |
self.response = response | ||
self.route = route | ||
|
||
def _add_cors(self, cors: CORSConfig): | ||
def _add_cors(self, event: BaseProxyEvent, cors: CORSConfig): | ||
"""Update headers to include the configured Access-Control headers""" | ||
self.response.headers.update(cors.to_dict()) | ||
self.response.headers.update(cors.to_dict(event.get_header_value("Origin"))) | ||
rubenfonseca marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def _add_cache_control(self, cache_control: str): | ||
"""Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used.""" | ||
|
@@ -230,7 +247,7 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]): | |
if self.route is None: | ||
return | ||
if self.route.cors: | ||
self._add_cors(cors or CORSConfig()) | ||
self._add_cors(event, cors or CORSConfig()) | ||
if self.route.cache_control: | ||
self._add_cache_control(self.route.cache_control) | ||
if self.route.compress and "gzip" in (event.get_header_value("accept-encoding", "") or ""): | ||
|
@@ -644,7 +661,7 @@ def _not_found(self, method: str) -> ResponseBuilder: | |
headers: Dict[str, Union[str, List[str]]] = {} | ||
if self._cors: | ||
logger.debug("CORS is enabled, updating headers.") | ||
headers.update(self._cors.to_dict()) | ||
headers.update(self._cors.to_dict(self.current_event.get_header_value("Origin"))) | ||
rubenfonseca marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if method == "OPTIONS": | ||
logger.debug("Pre-flight request detected. Returning CORS with null response") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import requests | ||
from requests import Response | ||
|
||
from aws_lambda_powertools import Logger, Tracer | ||
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, CORSConfig | ||
from aws_lambda_powertools.logging import correlation_paths | ||
from aws_lambda_powertools.utilities.typing import LambdaContext | ||
|
||
tracer = Tracer() | ||
logger = Logger() | ||
# CORS will match when Origin is https://www.example.com OR https://dev.example.com | ||
cors_config = CORSConfig(allow_origin="https://www.example.com", extra_origins=["https://dev.example.com"], max_age=300) | ||
app = APIGatewayRestResolver(cors=cors_config) | ||
|
||
|
||
@app.get("/todos") | ||
@tracer.capture_method | ||
def get_todos(): | ||
todos: Response = requests.get("https://jsonplaceholder.typicode.com/todos") | ||
todos.raise_for_status() | ||
|
||
# for brevity, we'll limit to the first 10 only | ||
return {"todos": todos.json()[:10]} | ||
|
||
|
||
@app.get("/todos/<todo_id>") | ||
@tracer.capture_method | ||
def get_todo_by_id(todo_id: str): # value come as str | ||
todos: Response = requests.get(f"https://jsonplaceholder.typicode.com/todos/{todo_id}") | ||
todos.raise_for_status() | ||
|
||
return {"todos": todos.json()} | ||
|
||
|
||
@app.get("/healthcheck", cors=False) # optionally removes CORS for a given route | ||
@tracer.capture_method | ||
def am_i_alive(): | ||
return {"am_i_alive": "yes"} | ||
|
||
|
||
# You can continue to use other utilities just as before | ||
@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST) | ||
@tracer.capture_lambda_handler | ||
def lambda_handler(event: dict, context: LambdaContext) -> dict: | ||
return app.resolve(event, context) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
{ | ||
"statusCode": 200, | ||
"multiValueHeaders": { | ||
"Content-Type": ["application/json"], | ||
"Access-Control-Allow-Origin": ["https://www.example.com","https://dev.example.com"], | ||
"Access-Control-Allow-Headers": ["Authorization,Content-Type,X-Amz-Date,X-Amz-Security-Token,X-Api-Key"] | ||
}, | ||
"body": "{\"todos\":[{\"userId\":1,\"id\":1,\"title\":\"delectus aut autem\",\"completed\":false},{\"userId\":1,\"id\":2,\"title\":\"quis ut nam facilis et officia qui\",\"completed\":false},{\"userId\":1,\"id\":3,\"title\":\"fugiat veniam minus\",\"completed\":false},{\"userId\":1,\"id\":4,\"title\":\"et porro tempora\",\"completed\":true},{\"userId\":1,\"id\":5,\"title\":\"laboriosam mollitia et enim quasi adipisci quia provident illum\",\"completed\":false},{\"userId\":1,\"id\":6,\"title\":\"qui ullam ratione quibusdam voluptatem quia omnis\",\"completed\":false},{\"userId\":1,\"id\":7,\"title\":\"illo expedita consequatur quia in\",\"completed\":false},{\"userId\":1,\"id\":8,\"title\":\"quo adipisci enim quam ut ab\",\"completed\":true},{\"userId\":1,\"id\":9,\"title\":\"molestiae perspiciatis ipsa\",\"completed\":false},{\"userId\":1,\"id\":10,\"title\":\"illo est ratione doloremque quia maiores aut\",\"completed\":true}]}", | ||
"isBase64Encoded": false | ||
} |
Uh oh!
There was an error while loading. Please reload this page.