diff --git a/packages/smithy-core/src/smithy_core/retries.py b/packages/smithy-core/src/smithy_core/retries.py index f3c79798..06bf6f98 100644 --- a/packages/smithy-core/src/smithy_core/retries.py +++ b/packages/smithy-core/src/smithy_core/retries.py @@ -234,11 +234,11 @@ def refresh_retry_token_for_retry( if retry_count >= self.max_attempts: raise RetryError( f"Reached maximum number of allowed attempts: {self.max_attempts}" - ) + ) from error retry_delay = self.backoff_strategy.compute_next_backoff_delay(retry_count) return SimpleRetryToken(retry_count=retry_count, retry_delay=retry_delay) else: - raise RetryError(f"Error is not retryable: {error}") + raise RetryError(f"Error is not retryable: {error}") from error def record_success(self, *, token: retries_interface.RetryToken) -> None: """Not used by this retry strategy.""" diff --git a/packages/smithy-http/src/smithy_http/bindings.py b/packages/smithy-http/src/smithy_http/bindings.py new file mode 100644 index 00000000..de0eee09 --- /dev/null +++ b/packages/smithy-http/src/smithy_http/bindings.py @@ -0,0 +1,172 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from enum import Enum + +from smithy_core.schemas import Schema +from smithy_core.shapes import ShapeType +from smithy_core.traits import ( + ErrorFault, + ErrorTrait, + HostLabelTrait, + HTTPErrorTrait, + HTTPHeaderTrait, + HTTPLabelTrait, + HTTPPayloadTrait, + HTTPPrefixHeadersTrait, + HTTPQueryParamsTrait, + HTTPQueryTrait, + HTTPResponseCodeTrait, + StreamingTrait, +) + + +class Binding(Enum): + """HTTP binding locations.""" + + HEADER = 0 + """Indicates the member is bound to a header.""" + + QUERY = 1 + """Indicates the member is bound to a query parameter.""" + + PAYLOAD = 2 + """Indicates the member is bound to the entire HTTP payload.""" + + BODY = 3 + """Indicates the member is a property in the HTTP payload structure.""" + + LABEL = 4 + """Indicates the member is bound to a path segment in the URI.""" + + STATUS = 5 + """Indicates the member is bound to the response status code.""" + + PREFIX_HEADERS = 6 + """Indicates the member is bound to multiple headers with a shared prefix.""" + + QUERY_PARAMS = 7 + """Indicates the member is bound to the query string as multiple key-value pairs.""" + + HOST = 8 + """Indicates the member is bound to a prefix to the host AND to the body.""" + + +@dataclass(init=False) +class _BindingMatcher: + bindings: list[Binding] + """A list of bindings where the index matches the index of the member schema.""" + + response_status: int + """The default response status code.""" + + has_body: bool + """Whether the HTTP message has members bound to the body.""" + + has_payload: bool + """Whether the HTTP message has a member bound to the entire payload.""" + + payload_member: Schema | None + """The member bound to the payload, if one exists.""" + + event_stream_member: Schema | None + """The member bound to the event stream, if one exists.""" + + def __init__(self, struct: Schema, response_status: int) -> None: + self.response_status = response_status + found_body = False + found_payload = False + self.bindings = [Binding.BODY] * len(struct.members) + self.payload_member = None + self.event_stream_member = None + + for member in struct.members.values(): + binding = self._do_match(member) + self.bindings[member.expect_member_index()] = binding + found_body = ( + found_body or binding is Binding.BODY or binding is Binding.HOST + ) + if binding is Binding.PAYLOAD: + found_payload = True + self.payload_member = member + if ( + StreamingTrait.id in member.traits + and member.shape_type is ShapeType.UNION + ): + self.event_stream_member = member + + self.has_body = found_body + self.has_payload = found_payload + + def should_write_body(self, omit_empty_payload: bool) -> bool: + """Determines whether a body should be written. + + :param omit_empty_payload: Whether a body should be skipped in the case of an + empty payload. + """ + return self.has_body or (not omit_empty_payload and not self.has_payload) + + def match(self, member: Schema) -> Binding: + """Determines which part of the HTTP message the given member is bound to.""" + return self.bindings[member.expect_member_index()] + + def _do_match(self, member: Schema) -> Binding: ... + + +@dataclass(init=False) +class RequestBindingMatcher(_BindingMatcher): + """Matches structure members to HTTP request binding locations.""" + + def __init__(self, struct: Schema) -> None: + """Initialize a RequestBindingMatcher. + + :param struct: The structure to examine for HTTP bindings. + """ + super().__init__(struct, -1) + + def _do_match(self, member: Schema) -> Binding: + if HTTPLabelTrait.id in member.traits: + return Binding.LABEL + if HTTPQueryTrait.id in member.traits: + return Binding.QUERY + if HTTPQueryParamsTrait.id in member.traits: + return Binding.QUERY_PARAMS + if HTTPHeaderTrait.id in member.traits: + return Binding.HEADER + if HTTPPrefixHeadersTrait.id in member.traits: + return Binding.PREFIX_HEADERS + if HTTPPayloadTrait.id in member.traits: + return Binding.PAYLOAD + if HostLabelTrait.id in member.traits: + return Binding.HOST + return Binding.BODY + + +@dataclass(init=False) +class ResponseBindingMatcher(_BindingMatcher): + """Matches structure members to HTTP response binding locations.""" + + def __init__(self, struct: Schema) -> None: + """Initialize a ResponseBindingMatcher. + + :param struct: The structure to examine for HTTP bindings. + """ + super().__init__(struct, self._compute_response(struct)) + + def _compute_response(self, struct: Schema) -> int: + if (http_error := struct.get_trait(HTTPErrorTrait)) is not None: + return http_error.code + if (error := struct.get_trait(ErrorTrait)) is not None: + return 400 if error.fault is ErrorFault.CLIENT else 500 + return -1 + + def _do_match(self, member: Schema) -> Binding: + if HTTPResponseCodeTrait.id in member.traits: + return Binding.STATUS + if HTTPHeaderTrait.id in member.traits: + return Binding.HEADER + if HTTPPrefixHeadersTrait.id in member.traits: + return Binding.PREFIX_HEADERS + if HTTPPayloadTrait.id in member.traits: + return Binding.PAYLOAD + return Binding.BODY diff --git a/packages/smithy-http/src/smithy_http/deserializers.py b/packages/smithy-http/src/smithy_http/deserializers.py index f48136db..5a12de2b 100644 --- a/packages/smithy-http/src/smithy_http/deserializers.py +++ b/packages/smithy-http/src/smithy_http/deserializers.py @@ -13,9 +13,7 @@ from smithy_core.shapes import ShapeType from smithy_core.traits import ( HTTPHeaderTrait, - HTTPPayloadTrait, HTTPPrefixHeadersTrait, - HTTPResponseCodeTrait, HTTPTrait, TimestampFormatTrait, ) @@ -23,6 +21,7 @@ from smithy_core.utils import ensure_utc, strict_parse_bool, strict_parse_float from .aio.interfaces import HTTPResponse +from .bindings import Binding, ResponseBindingMatcher from .interfaces import Field, Fields if TYPE_CHECKING: @@ -61,47 +60,49 @@ def __init__( def read_struct( self, schema: Schema, consumer: Callable[[Schema, ShapeDeserializer], None] ) -> None: - has_body = False - payload_member: Schema | None = None + binding_matcher = ResponseBindingMatcher(schema) for member in schema.members.values(): - if (trait := member.get_trait(HTTPHeaderTrait)) is not None: - header = self._response.fields.entries.get(trait.key.lower()) - if header is not None: - if member.shape_type is ShapeType.LIST: - consumer(member, HTTPHeaderListDeserializer(header)) - else: - consumer(member, HTTPHeaderDeserializer(header.as_string())) - elif (trait := member.get_trait(HTTPPrefixHeadersTrait)) is not None: - consumer( - member, - HTTPHeaderMapDeserializer(self._response.fields, trait.prefix), - ) - elif HTTPPayloadTrait in member: - has_body = True - payload_member = member - elif HTTPResponseCodeTrait in member: - consumer(member, HTTPResponseCodeDeserializer(self._response.status)) - else: - has_body = True - - if has_body: - deserializer = self._create_payload_deserializer(payload_member) - if payload_member is not None: - consumer(payload_member, deserializer) - else: - deserializer.read_struct(schema, consumer) - - def _create_payload_deserializer( - self, payload_member: Schema | None - ) -> ShapeDeserializer: - body = self._body if self._body is not None else self._response.body - if payload_member is not None and payload_member.shape_type in ( - ShapeType.BLOB, - ShapeType.STRING, - ): + match binding_matcher.match(member): + case Binding.HEADER: + trait = member.expect_trait(HTTPHeaderTrait) + header = self._response.fields.entries.get(trait.key.lower()) + if header is not None: + if member.shape_type is ShapeType.LIST: + consumer(member, HTTPHeaderListDeserializer(header)) + else: + consumer(member, HTTPHeaderDeserializer(header.as_string())) + case Binding.PREFIX_HEADERS: + trait = member.expect_trait(HTTPPrefixHeadersTrait) + consumer( + member, + HTTPHeaderMapDeserializer(self._response.fields, trait.prefix), + ) + case Binding.STATUS: + consumer( + member, HTTPResponseCodeDeserializer(self._response.status) + ) + case Binding.PAYLOAD: + assert binding_matcher.payload_member is not None # noqa: S101 + deserializer = self._create_payload_deserializer( + binding_matcher.payload_member + ) + consumer(binding_matcher.payload_member, deserializer) + case _: + pass + + if binding_matcher.has_body: + deserializer = self._create_body_deserializer() + deserializer.read_struct(schema, consumer) + + def _create_payload_deserializer(self, payload_member: Schema) -> ShapeDeserializer: + if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING): + body = self._body if self._body is not None else self._response.body return RawPayloadDeserializer(body) + return self._create_body_deserializer() + def _create_body_deserializer(self): + body = self._body if self._body is not None else self._response.body if not is_streaming_blob(body): raise UnsupportedStreamError( "Unable to read async stream. This stream must be buffered prior " diff --git a/packages/smithy-http/src/smithy_http/serializers.py b/packages/smithy-http/src/smithy_http/serializers.py index a29b7667..9d05ebc7 100644 --- a/packages/smithy-http/src/smithy_http/serializers.py +++ b/packages/smithy-http/src/smithy_http/serializers.py @@ -19,18 +19,11 @@ from smithy_core.shapes import ShapeType from smithy_core.traits import ( EndpointTrait, - HostLabelTrait, - HTTPErrorTrait, HTTPHeaderTrait, - HTTPLabelTrait, - HTTPPayloadTrait, HTTPPrefixHeadersTrait, - HTTPQueryParamsTrait, HTTPQueryTrait, - HTTPResponseCodeTrait, HTTPTrait, MediaTypeTrait, - StreamingTrait, TimestampFormatTrait, ) from smithy_core.types import PathPattern, TimestampFormat @@ -40,6 +33,7 @@ from .aio import HTTPRequest as _HTTPRequest from .aio import HTTPResponse as _HTTPResponse from .aio.interfaces import HTTPRequest, HTTPResponse +from .bindings import Binding, RequestBindingMatcher, ResponseBindingMatcher from .utils import join_query_params if TYPE_CHECKING: @@ -61,6 +55,7 @@ def __init__( payload_codec: Codec, http_trait: HTTPTrait, endpoint_trait: EndpointTrait | None = None, + omit_empty_payload: bool = True, ) -> None: """Initialize an HTTPRequestSerializer. @@ -69,10 +64,12 @@ def __init__( :param http_trait: The HTTP trait of the operation being handled. :param endpoint_trait: The optional endpoint trait of the operation being handled. + :param omit_empty_payload: Whether an empty payload should be omitted. """ self._http_trait = http_trait self._endpoint_trait = endpoint_trait self._payload_codec = payload_codec + self._omit_empty_payload = omit_empty_payload self.result: HTTPRequest | None = None @contextmanager @@ -86,7 +83,8 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: content_type = self._payload_codec.media_type - if (payload_member := self._get_payload_member(schema)) is not None: + binding_matcher = RequestBindingMatcher(schema) + if (payload_member := binding_matcher.payload_member) is not None: if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING): content_type = ( "application/octet-stream" @@ -95,7 +93,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: ) payload_serializer = RawPayloadSerializer() binding_serializer = HTTPRequestBindingSerializer( - payload_serializer, self._http_trait.path, host_prefix + payload_serializer, + self._http_trait.path, + host_prefix, + binding_matcher, ) yield binding_serializer payload = payload_serializer.payload @@ -105,17 +106,32 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: payload = BytesIO() payload_serializer = self._payload_codec.create_serializer(payload) binding_serializer = HTTPRequestBindingSerializer( - payload_serializer, self._http_trait.path, host_prefix + payload_serializer, + self._http_trait.path, + host_prefix, + binding_matcher, ) yield binding_serializer else: - if self._get_eventstreaming_member(schema) is not None: + if binding_matcher.event_stream_member is not None: content_type = "application/vnd.amazon.eventstream" payload = BytesIO() payload_serializer = self._payload_codec.create_serializer(payload) - with payload_serializer.begin_struct(schema) as body_serializer: + if binding_matcher.should_write_body(self._omit_empty_payload): + with payload_serializer.begin_struct(schema) as body_serializer: + binding_serializer = HTTPRequestBindingSerializer( + body_serializer, + self._http_trait.path, + host_prefix, + binding_matcher, + ) + yield binding_serializer + else: binding_serializer = HTTPRequestBindingSerializer( - body_serializer, self._http_trait.path, host_prefix + payload_serializer, + self._http_trait.path, + host_prefix, + binding_matcher, ) yield binding_serializer @@ -142,21 +158,6 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: body=payload, ) - def _get_payload_member(self, schema: Schema) -> Schema | None: - for member in schema.members.values(): - if HTTPPayloadTrait in member: - return member - return None - - def _get_eventstreaming_member(self, schema: Schema) -> Schema | None: - for member in schema.members.values(): - if ( - member.get_trait(StreamingTrait) is not None - and member.shape_type is ShapeType.UNION - ): - return member - return None - class HTTPRequestBindingSerializer(InterceptingSerializer): """Delegates HTTP request bindings to binding-location-specific serializers.""" @@ -166,6 +167,7 @@ def __init__( payload_serializer: ShapeSerializer, path_pattern: PathPattern, host_prefix_pattern: str, + binding_matcher: RequestBindingMatcher, ) -> None: """Initialize an HTTPRequestBindingSerializer. @@ -181,18 +183,20 @@ def __init__( self.host_prefix_serializer = HostPrefixSerializer( payload_serializer, host_prefix_pattern ) + self._binding_matcher = binding_matcher def before(self, schema: Schema) -> ShapeSerializer: - if HTTPHeaderTrait in schema or HTTPPrefixHeadersTrait in schema: - return self.header_serializer - if HTTPQueryTrait in schema or HTTPQueryParamsTrait in schema: - return self.query_serializer - if HTTPLabelTrait in schema: - return self.path_serializer - if HostLabelTrait in schema: - return self.host_prefix_serializer - - return self._payload_serializer + match self._binding_matcher.match(schema): + case Binding.HEADER | Binding.PREFIX_HEADERS: + return self.header_serializer + case Binding.QUERY | Binding.QUERY_PARAMS: + return self.query_serializer + case Binding.LABEL: + return self.path_serializer + case Binding.HOST: + return self.host_prefix_serializer + case _: + return self._payload_serializer def after(self, schema: Schema) -> None: pass @@ -205,38 +209,55 @@ def __init__( self, payload_codec: Codec, http_trait: HTTPTrait, + omit_empty_payload: bool = True, ) -> None: """Initialize an HTTPResponseSerializer. :param payload_codec: The codec to use to serialize the HTTP payload, if one is present. :param http_trait: The HTTP trait of the operation being handled. + :param omit_empty_payload: Whether an empty payload should be omitted. """ self._http_trait = http_trait self._payload_codec = payload_codec self.result: HTTPResponse | None = None + self._omit_empty_payload = omit_empty_payload @contextmanager def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: payload: Any binding_serializer: HTTPResponseBindingSerializer - if (payload_member := self._get_payload_member(schema)) is not None: + binding_matcher = ResponseBindingMatcher(schema) + if (payload_member := binding_matcher.payload_member) is not None: if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING): payload_serializer = RawPayloadSerializer() - binding_serializer = HTTPResponseBindingSerializer(payload_serializer) + binding_serializer = HTTPResponseBindingSerializer( + payload_serializer, binding_matcher + ) yield binding_serializer payload = payload_serializer.payload else: payload = BytesIO() payload_serializer = self._payload_codec.create_serializer(payload) - binding_serializer = HTTPResponseBindingSerializer(payload_serializer) + binding_serializer = HTTPResponseBindingSerializer( + payload_serializer, binding_matcher + ) yield binding_serializer else: payload = BytesIO() payload_serializer = self._payload_codec.create_serializer(payload) - with payload_serializer.begin_struct(schema) as body_serializer: - binding_serializer = HTTPResponseBindingSerializer(body_serializer) + if binding_matcher.should_write_body(self._omit_empty_payload): + with payload_serializer.begin_struct(schema) as body_serializer: + binding_serializer = HTTPResponseBindingSerializer( + body_serializer, binding_matcher + ) + yield binding_serializer + else: + binding_serializer = HTTPResponseBindingSerializer( + payload_serializer, + binding_matcher, + ) yield binding_serializer if ( @@ -244,28 +265,28 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]: ) is not None and not iscoroutinefunction(seek): seek(0) - default_code = self._http_trait.code - explicit_code = binding_serializer.response_code_serializer.response_code - if (http_error_trait := schema.get_trait(HTTPErrorTrait)) is not None: - default_code = http_error_trait.code + status = binding_serializer.response_code_serializer.response_code + if status is None: + if binding_matcher.response_status > 0: + status = binding_matcher.response_status + else: + status = self._http_trait.code self.result = _HTTPResponse( fields=tuples_to_fields(binding_serializer.header_serializer.headers), body=payload, - status=explicit_code or default_code, + status=status, ) - def _get_payload_member(self, schema: Schema) -> Schema | None: - for member in schema.members.values(): - if HTTPPayloadTrait in member: - return member - return None - class HTTPResponseBindingSerializer(InterceptingSerializer): """Delegates HTTP response bindings to binding-location-specific serializers.""" - def __init__(self, payload_serializer: ShapeSerializer) -> None: + def __init__( + self, + payload_serializer: ShapeSerializer, + binding_matcher: ResponseBindingMatcher, + ) -> None: """Initialize an HTTPResponseBindingSerializer. :param payload_serializer: The :py:class:`ShapeSerializer` to use to serialize @@ -274,14 +295,16 @@ def __init__(self, payload_serializer: ShapeSerializer) -> None: self._payload_serializer = payload_serializer self.header_serializer = HTTPHeaderSerializer() self.response_code_serializer = HTTPResponseCodeSerializer() + self._binding_matcher = binding_matcher def before(self, schema: Schema) -> ShapeSerializer: - if HTTPHeaderTrait in schema or HTTPPrefixHeadersTrait in schema: - return self.header_serializer - if HTTPResponseCodeTrait in schema: - return self.response_code_serializer - - return self._payload_serializer + match self._binding_matcher.match(schema): + case Binding.HEADER | Binding.PREFIX_HEADERS: + return self.header_serializer + case Binding.STATUS: + return self.response_code_serializer + case _: + return self._payload_serializer def after(self, schema: Schema) -> None: pass diff --git a/packages/smithy-http/tests/unit/test_bindings.py b/packages/smithy-http/tests/unit/test_bindings.py new file mode 100644 index 00000000..5d84b2e0 --- /dev/null +++ b/packages/smithy-http/tests/unit/test_bindings.py @@ -0,0 +1,162 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from smithy_core.prelude import INTEGER, STRING +from smithy_core.schemas import Schema +from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import ( + ErrorTrait, + HostLabelTrait, + HTTPErrorTrait, + HTTPHeaderTrait, + HTTPLabelTrait, + HTTPPayloadTrait, + HTTPPrefixHeadersTrait, + HTTPQueryParamsTrait, + HTTPQueryTrait, + HTTPResponseCodeTrait, + StreamingTrait, +) +from smithy_http.bindings import Binding, RequestBindingMatcher, ResponseBindingMatcher + +PAYLOAD_BINDING = Schema.collection( + id=ShapeID("com.example#Payload"), + members={"payload": {"index": 0, "target": STRING, "traits": [HTTPPayloadTrait()]}}, +) + +EVENT_STREAM_SCHEMA = Schema.collection( + id=ShapeID("com.example#EventStream"), + shape_type=ShapeType.UNION, + members={ + "stream": { + "index": 0, + "target": Schema.collection(id=ShapeID("com.example#Event")), + } + }, + traits=[StreamingTrait()], +) +EVENT_STREAM_BINDING = Schema.collection( + id=ShapeID("com.example#Events"), + members={"stream": {"index": 0, "target": EVENT_STREAM_SCHEMA}}, +) + +STRING_MAP = Schema.collection( + id=ShapeID("com.example#StringMap"), + shape_type=ShapeType.MAP, + members={ + "key": {"index": 0, "target": STRING}, + "value": {"index": 0, "target": STRING}, + }, +) + +GENERAL_BINDINGS = Schema.collection( + id=ShapeID("com.example#BodyBindings"), + members={ + "label": {"index": 0, "target": STRING, "traits": [HTTPLabelTrait()]}, + "query": {"index": 1, "target": STRING, "traits": [HTTPQueryTrait()]}, + "queryParams": { + "index": 2, + "target": STRING_MAP, + "traits": [HTTPQueryParamsTrait()], + }, + "header": {"index": 3, "target": STRING, "traits": [HTTPHeaderTrait()]}, + "prefixHeaders": { + "index": 4, + "target": STRING_MAP, + "traits": [HTTPPrefixHeadersTrait("foo")], + }, + "hostLabel": {"index": 5, "target": STRING, "traits": [HostLabelTrait()]}, + "status": { + "index": 6, + "target": INTEGER, + "traits": [HTTPResponseCodeTrait()], + }, + "body": {"index": 7, "target": STRING}, + }, +) + + +def test_request_payload_matching() -> None: + matcher = RequestBindingMatcher(PAYLOAD_BINDING) + member_schema = PAYLOAD_BINDING.members["payload"] + actual = matcher.match(member_schema) + assert actual == Binding.PAYLOAD + assert matcher.payload_member is member_schema + + +def test_response_payload_matching() -> None: + matcher = ResponseBindingMatcher(PAYLOAD_BINDING) + member_schema = PAYLOAD_BINDING.members["payload"] + actual = matcher.match(member_schema) + assert actual == Binding.PAYLOAD + assert matcher.payload_member is member_schema + + +def test_request_event_stream_matching() -> None: + matcher = RequestBindingMatcher(EVENT_STREAM_BINDING) + member_schema = EVENT_STREAM_BINDING.members["stream"] + assert matcher.event_stream_member is member_schema + + +def test_response_event_stream_matching() -> None: + matcher = ResponseBindingMatcher(EVENT_STREAM_BINDING) + member_schema = EVENT_STREAM_BINDING.members["stream"] + assert matcher.event_stream_member is member_schema + + +def test_response_matches_http_error_trait() -> None: + schema = Schema.collection( + id=ShapeID("com.example#HTTPErrorTrait"), traits=[HTTPErrorTrait(404)] + ) + matcher = ResponseBindingMatcher(schema) + assert matcher.response_status == 404 + + +def test_response_matches_error_trait() -> None: + schema = Schema.collection( + id=ShapeID("com.example#ErrorTrait"), traits=[ErrorTrait("client")] + ) + matcher = ResponseBindingMatcher(schema) + assert matcher.response_status == 400 + + schema = Schema.collection( + id=ShapeID("com.example#ErrorTrait"), traits=[ErrorTrait("server")] + ) + matcher = ResponseBindingMatcher(schema) + assert matcher.response_status == 500 + + +def test_request_matching() -> None: + matcher = RequestBindingMatcher(GENERAL_BINDINGS) + assert matcher.match(GENERAL_BINDINGS.members["label"]) == Binding.LABEL + assert matcher.match(GENERAL_BINDINGS.members["query"]) == Binding.QUERY + + query_params_member = GENERAL_BINDINGS.members["queryParams"] + assert matcher.match(query_params_member) == Binding.QUERY_PARAMS + + assert matcher.match(GENERAL_BINDINGS.members["header"]) == Binding.HEADER + + prefix_member = GENERAL_BINDINGS.members["prefixHeaders"] + assert matcher.match(prefix_member) == Binding.PREFIX_HEADERS + + assert matcher.match(GENERAL_BINDINGS.members["hostLabel"]) == Binding.HOST + assert matcher.match(GENERAL_BINDINGS.members["status"]) == Binding.BODY + assert matcher.match(GENERAL_BINDINGS.members["body"]) == Binding.BODY + + +def test_response_matching() -> None: + matcher = ResponseBindingMatcher(GENERAL_BINDINGS) + assert matcher.match(GENERAL_BINDINGS.members["label"]) == Binding.BODY + assert matcher.match(GENERAL_BINDINGS.members["query"]) == Binding.BODY + + query_params_member = GENERAL_BINDINGS.members["queryParams"] + assert matcher.match(query_params_member) == Binding.BODY + + assert matcher.match(GENERAL_BINDINGS.members["header"]) == Binding.HEADER + + prefix_member = GENERAL_BINDINGS.members["prefixHeaders"] + assert matcher.match(prefix_member) == Binding.PREFIX_HEADERS + + assert matcher.match(GENERAL_BINDINGS.members["hostLabel"]) == Binding.BODY + assert matcher.match(GENERAL_BINDINGS.members["status"]) == Binding.STATUS + assert matcher.match(GENERAL_BINDINGS.members["body"]) == Binding.BODY diff --git a/packages/smithy-http/tests/unit/test_serializers.py b/packages/smithy-http/tests/unit/test_serializers.py index 199193bf..0dd5eea6 100644 --- a/packages/smithy-http/tests/unit/test_serializers.py +++ b/packages/smithy-http/tests/unit/test_serializers.py @@ -1649,6 +1649,20 @@ async def test_serialize_http_request(case: HTTPMessageTestCase) -> None: assert type(actual.body) is type(case.request.body) +async def test_serialize_request_omitting_empty_payload() -> None: + shape = HTTPStringLabel(label="foo/bar") + serializer = HTTPRequestSerializer( + payload_codec=JSONCodec(), + http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/{label}"}), + omit_empty_payload=True, + ) + shape.serialize(serializer) + actual = serializer.result + assert actual is not None + actual_body_value = await AsyncBytesReader(actual.body).read() + assert actual_body_value == b"" + + RESPONSE_SER_CASES: list[HTTPMessageTestCase] = ( header_cases() + empty_prefix_header_ser_cases() + payload_cases() ) @@ -1677,6 +1691,20 @@ async def test_serialize_http_response(case: HTTPMessageTestCase) -> None: assert type(actual.body) is type(case.request.body) +async def test_serialize_response_omitting_empty_payload() -> None: + shape = HTTPHeaders(boolean_member=True) + serializer = HTTPResponseSerializer( + payload_codec=JSONCodec(), + http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/"}), + omit_empty_payload=True, + ) + shape.serialize(serializer) + actual = serializer.result + assert actual is not None + actual_body_value = await AsyncBytesReader(actual.body).read() + assert actual_body_value == b"" + + RESPONSE_DESER_CASES: list[HTTPMessageTestCase] = ( header_cases() + empty_prefix_header_deser_cases() + payload_cases() )