|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import base64 |
3 | 4 | import dataclasses |
4 | 5 | import json |
5 | 6 | import logging |
| 7 | +import warnings |
6 | 8 | from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, Union, cast |
7 | 9 | from urllib.parse import parse_qs |
8 | 10 |
|
|
25 | 27 | RequestValidationError, |
26 | 28 | ResponseValidationError, |
27 | 29 | ) |
28 | | -from aws_lambda_powertools.event_handler.openapi.params import Param |
| 30 | +from aws_lambda_powertools.event_handler.openapi.params import Param, UploadFile |
29 | 31 | from aws_lambda_powertools.event_handler.openapi.types import UnionType |
30 | 32 |
|
31 | 33 | if TYPE_CHECKING: |
|
44 | 46 | CONTENT_DISPOSITION_NAME_PARAM = "name=" |
45 | 47 | APPLICATION_JSON_CONTENT_TYPE = "application/json" |
46 | 48 | APPLICATION_FORM_CONTENT_TYPE = "application/x-www-form-urlencoded" |
| 49 | +MULTIPART_FORM_DATA_CONTENT_TYPE = "multipart/form-data" |
47 | 50 |
|
48 | 51 |
|
49 | 52 | class OpenAPIRequestValidationMiddleware(BaseMiddlewareHandler): |
@@ -134,14 +137,18 @@ def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]: |
134 | 137 | elif content_type.startswith(APPLICATION_FORM_CONTENT_TYPE): |
135 | 138 | return self._parse_form_data(app) |
136 | 139 |
|
| 140 | + # Handle multipart/form-data (file uploads) |
| 141 | + elif content_type.startswith(MULTIPART_FORM_DATA_CONTENT_TYPE): |
| 142 | + return self._parse_multipart_data(app, content_type) |
| 143 | + |
137 | 144 | else: |
138 | 145 | raise RequestUnsupportedContentType( |
139 | | - "Only JSON body or Form() are supported", |
| 146 | + "Unsupported content type", |
140 | 147 | errors=[ |
141 | 148 | { |
142 | 149 | "type": "unsupported_content_type", |
143 | 150 | "loc": ("body",), |
144 | | - "msg": "Only JSON body or Form() are supported", |
| 151 | + "msg": f"Unsupported content type: {content_type}", |
145 | 152 | "input": {}, |
146 | 153 | "ctx": {}, |
147 | 154 | }, |
@@ -188,6 +195,49 @@ def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]: |
188 | 195 | ], |
189 | 196 | ) from e |
190 | 197 |
|
| 198 | + def _parse_multipart_data(self, app: EventHandlerInstance, content_type: str) -> dict[str, Any]: |
| 199 | + """Parse multipart/form-data from the request body (file uploads).""" |
| 200 | + try: |
| 201 | + # Extract the boundary from the content-type header |
| 202 | + boundary = _extract_multipart_boundary(content_type) |
| 203 | + if not boundary: |
| 204 | + raise ValueError("Missing boundary in multipart/form-data content-type header") |
| 205 | + |
| 206 | + # Get raw body bytes |
| 207 | + raw_body = app.current_event.body or "" |
| 208 | + if app.current_event.is_base64_encoded: |
| 209 | + body_bytes = base64.b64decode(raw_body) |
| 210 | + else: |
| 211 | + warnings.warn( |
| 212 | + "Received multipart/form-data without base64 encoding. " |
| 213 | + "Binary file uploads may be corrupted. " |
| 214 | + "If using API Gateway REST API (v1), configure Binary Media Types " |
| 215 | + "to include 'multipart/form-data'. " |
| 216 | + "See: https://docs.aws.amazon.com/apigateway/latest/developerguide/" |
| 217 | + "api-gateway-payload-encodings.html", |
| 218 | + stacklevel=2, |
| 219 | + ) |
| 220 | + # Use latin-1 to preserve all byte values (0-255) since the body |
| 221 | + # may contain raw binary data that isn't valid UTF-8 |
| 222 | + body_bytes = raw_body.encode("latin-1") |
| 223 | + |
| 224 | + return _parse_multipart_body(body_bytes, boundary) |
| 225 | + |
| 226 | + except ValueError: |
| 227 | + raise |
| 228 | + except Exception as e: |
| 229 | + raise RequestValidationError( |
| 230 | + [ |
| 231 | + { |
| 232 | + "type": "multipart_invalid", |
| 233 | + "loc": ("body",), |
| 234 | + "msg": "Multipart form data parsing error", |
| 235 | + "input": {}, |
| 236 | + "ctx": {"error": str(e)}, |
| 237 | + }, |
| 238 | + ], |
| 239 | + ) from e |
| 240 | + |
191 | 241 |
|
192 | 242 | class OpenAPIResponseValidationMiddleware(BaseMiddlewareHandler): |
193 | 243 | """ |
@@ -391,7 +441,12 @@ def _request_body_to_args( |
391 | 441 | continue |
392 | 442 |
|
393 | 443 | value = _normalize_field_value(value=value, field_info=field.field_info) |
394 | | - values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors) |
| 444 | + |
| 445 | + # UploadFile objects bypass Pydantic validation — they're already constructed |
| 446 | + if isinstance(value, UploadFile): |
| 447 | + values[field.name] = value |
| 448 | + else: |
| 449 | + values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors) |
395 | 450 |
|
396 | 451 | return values, errors |
397 | 452 |
|
@@ -467,6 +522,10 @@ def _is_or_contains_sequence(annotation: Any) -> bool: |
467 | 522 |
|
468 | 523 | def _normalize_field_value(value: Any, field_info: FieldInfo) -> Any: |
469 | 524 | """Normalize field value, converting lists to single values for non-sequence fields.""" |
| 525 | + # When annotation is bytes but value is UploadFile, extract raw content |
| 526 | + if isinstance(value, UploadFile) and field_info.annotation is bytes: |
| 527 | + return value.content |
| 528 | + |
470 | 529 | if _is_or_contains_sequence(field_info.annotation): |
471 | 530 | return value |
472 | 531 | elif isinstance(value, list) and value: |
@@ -580,3 +639,106 @@ def _get_param_value( |
580 | 639 | value = input_dict.get(field_name) |
581 | 640 |
|
582 | 641 | return value |
| 642 | + |
| 643 | + |
| 644 | +def _extract_multipart_boundary(content_type: str) -> str | None: |
| 645 | + """Extract the boundary string from a multipart/form-data content-type header.""" |
| 646 | + for segment in content_type.split(";"): |
| 647 | + stripped = segment.strip() |
| 648 | + if stripped.startswith("boundary="): |
| 649 | + boundary = stripped[len("boundary=") :] |
| 650 | + # Remove optional quotes around boundary |
| 651 | + if boundary.startswith('"') and boundary.endswith('"'): |
| 652 | + boundary = boundary[1:-1] |
| 653 | + return boundary |
| 654 | + return None |
| 655 | + |
| 656 | + |
| 657 | +def _parse_multipart_body(body: bytes, boundary: str) -> dict[str, Any]: |
| 658 | + """ |
| 659 | + Parse a multipart/form-data body into a dict of field names to values. |
| 660 | +
|
| 661 | + File fields get bytes values; regular form fields get string values. |
| 662 | + Multiple values for the same field name are collected into lists. |
| 663 | + """ |
| 664 | + delimiter = f"--{boundary}".encode() |
| 665 | + end_delimiter = f"--{boundary}--".encode() |
| 666 | + |
| 667 | + result: dict[str, Any] = {} |
| 668 | + |
| 669 | + # Split body by the boundary delimiter |
| 670 | + raw_parts = body.split(delimiter) |
| 671 | + |
| 672 | + for raw_part in raw_parts: |
| 673 | + # Skip the preamble (before first boundary) and epilogue (after closing boundary) |
| 674 | + if not raw_part or raw_part.strip() == b"" or raw_part.strip() == b"--": |
| 675 | + continue |
| 676 | + |
| 677 | + # Remove the end delimiter marker if present |
| 678 | + chunk = raw_part |
| 679 | + if chunk.endswith(end_delimiter): |
| 680 | + chunk = chunk[: -len(end_delimiter)] |
| 681 | + |
| 682 | + # Strip leading \r\n |
| 683 | + if chunk.startswith(b"\r\n"): |
| 684 | + chunk = chunk[2:] |
| 685 | + |
| 686 | + # Strip trailing \r\n |
| 687 | + if chunk.endswith(b"\r\n"): |
| 688 | + chunk = chunk[:-2] |
| 689 | + |
| 690 | + # Split headers from body at the double CRLF |
| 691 | + header_end = chunk.find(b"\r\n\r\n") |
| 692 | + if header_end == -1: |
| 693 | + continue |
| 694 | + |
| 695 | + header_section = chunk[:header_end].decode("utf-8") |
| 696 | + body_section = chunk[header_end + 4 :] |
| 697 | + |
| 698 | + # Parse Content-Disposition to get the field name and optional filename |
| 699 | + field_name = None |
| 700 | + filename = None |
| 701 | + content_type_header = None |
| 702 | + |
| 703 | + for header_line in header_section.split("\r\n"): |
| 704 | + header_lower = header_line.lower() |
| 705 | + if header_lower.startswith("content-disposition:"): |
| 706 | + field_name = _extract_header_param(header_line, "name") |
| 707 | + filename = _extract_header_param(header_line, "filename") |
| 708 | + elif header_lower.startswith("content-type:"): |
| 709 | + content_type_header = header_line.split(":", 1)[1].strip() |
| 710 | + |
| 711 | + if field_name is None: |
| 712 | + continue |
| 713 | + |
| 714 | + # If it has a filename, it's a file upload — wrap as UploadFile |
| 715 | + # Otherwise it's a regular form field — decode to string |
| 716 | + if filename is not None: |
| 717 | + value: Any = UploadFile(content=body_section, filename=filename, content_type=content_type_header) |
| 718 | + else: |
| 719 | + value = body_section.decode("utf-8") |
| 720 | + |
| 721 | + # Collect multiple values for same field name into a list |
| 722 | + if field_name in result: |
| 723 | + existing = result[field_name] |
| 724 | + if isinstance(existing, list): |
| 725 | + existing.append(value) |
| 726 | + else: |
| 727 | + result[field_name] = [existing, value] |
| 728 | + else: |
| 729 | + result[field_name] = value |
| 730 | + |
| 731 | + return result |
| 732 | + |
| 733 | + |
| 734 | +def _extract_header_param(header_line: str, param_name: str) -> str | None: |
| 735 | + """Extract a parameter value from a header line (e.g., name="file" from Content-Disposition).""" |
| 736 | + search = f'{param_name}="' |
| 737 | + idx = header_line.find(search) |
| 738 | + if idx == -1: |
| 739 | + return None |
| 740 | + start = idx + len(search) |
| 741 | + end = header_line.find('"', start) |
| 742 | + if end == -1: |
| 743 | + return None |
| 744 | + return header_line[start:end] |
0 commit comments