Skip to content

Centralize http binding matching and omit empty payloads #503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/smithy-core/src/smithy_core/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,11 @@ def refresh_retry_token_for_retry(
if retry_count >= self.max_attempts:
raise RetryError(
f"Reached maximum number of allowed attempts: {self.max_attempts}"
)
) from error
retry_delay = self.backoff_strategy.compute_next_backoff_delay(retry_count)
return SimpleRetryToken(retry_count=retry_count, retry_delay=retry_delay)
else:
raise RetryError(f"Error is not retryable: {error}")
raise RetryError(f"Error is not retryable: {error}") from error

def record_success(self, *, token: retries_interface.RetryToken) -> None:
"""Not used by this retry strategy."""
172 changes: 172 additions & 0 deletions packages/smithy-http/src/smithy_http/bindings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from enum import Enum

from smithy_core.schemas import Schema
from smithy_core.shapes import ShapeType
from smithy_core.traits import (
ErrorFault,
ErrorTrait,
HostLabelTrait,
HTTPErrorTrait,
HTTPHeaderTrait,
HTTPLabelTrait,
HTTPPayloadTrait,
HTTPPrefixHeadersTrait,
HTTPQueryParamsTrait,
HTTPQueryTrait,
HTTPResponseCodeTrait,
StreamingTrait,
)


class Binding(Enum):
"""HTTP binding locations."""

HEADER = 0
"""Indicates the member is bound to a header."""

QUERY = 1
"""Indicates the member is bound to a query parameter."""

PAYLOAD = 2
"""Indicates the member is bound to the entire HTTP payload."""

BODY = 3
"""Indicates the member is a property in the HTTP payload structure."""

LABEL = 4
"""Indicates the member is bound to a path segment in the URI."""

STATUS = 5
"""Indicates the member is bound to the response status code."""

PREFIX_HEADERS = 6
"""Indicates the member is bound to multiple headers with a shared prefix."""

QUERY_PARAMS = 7
"""Indicates the member is bound to the query string as multiple key-value pairs."""

HOST = 8
"""Indicates the member is bound to a prefix to the host AND to the body."""


@dataclass(init=False)
class _BindingMatcher:
bindings: list[Binding]
"""A list of bindings where the index matches the index of the member schema."""

response_status: int
"""The default response status code."""

has_body: bool
"""Whether the HTTP message has members bound to the body."""

has_payload: bool
"""Whether the HTTP message has a member bound to the entire payload."""

payload_member: Schema | None
"""The member bound to the payload, if one exists."""

event_stream_member: Schema | None
"""The member bound to the event stream, if one exists."""

def __init__(self, struct: Schema, response_status: int) -> None:
self.response_status = response_status
found_body = False
found_payload = False
self.bindings = [Binding.BODY] * len(struct.members)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there we're a reason we're initializing this list here if we're going to swap things out below? Is there a benefit to not just initialize it all at once?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The order of the members in the dictionary isn't the same as the modeled index order, so we can't just append and expect the order to be right. I've got some ideas on how to change schemas so the member order is represented in the dict though. Perhaps I should put that up, it'll reduce the amount of generated code too.

self.payload_member = None
self.event_stream_member = None

for member in struct.members.values():
binding = self._do_match(member)
self.bindings[member.expect_member_index()] = binding
found_body = (
found_body or binding is Binding.BODY or binding is Binding.HOST
)
if binding is Binding.PAYLOAD:
found_payload = True
self.payload_member = member
if (
StreamingTrait.id in member.traits
and member.shape_type is ShapeType.UNION
):
self.event_stream_member = member

self.has_body = found_body
self.has_payload = found_payload

def should_write_body(self, omit_empty_payload: bool) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit; We may want to potentially make this a keyword only arg? Calling should_write_body(True) is kind of confusing when the parameter is indicating the opposite. I see we're aliasing the value to a parameter with the same name though in actual usage.

"""Determines whether a body should be written.

:param omit_empty_payload: Whether a body should be skipped in the case of an
empty payload.
"""
return self.has_body or (not omit_empty_payload and not self.has_payload)

def match(self, member: Schema) -> Binding:
"""Determines which part of the HTTP message the given member is bound to."""
return self.bindings[member.expect_member_index()]

def _do_match(self, member: Schema) -> Binding: ...


@dataclass(init=False)
class RequestBindingMatcher(_BindingMatcher):
"""Matches structure members to HTTP request binding locations."""

def __init__(self, struct: Schema) -> None:
"""Initialize a RequestBindingMatcher.

:param struct: The structure to examine for HTTP bindings.
"""
super().__init__(struct, -1)

def _do_match(self, member: Schema) -> Binding:
if HTTPLabelTrait.id in member.traits:
return Binding.LABEL
if HTTPQueryTrait.id in member.traits:
return Binding.QUERY
if HTTPQueryParamsTrait.id in member.traits:
return Binding.QUERY_PARAMS
if HTTPHeaderTrait.id in member.traits:
return Binding.HEADER
if HTTPPrefixHeadersTrait.id in member.traits:
return Binding.PREFIX_HEADERS
if HTTPPayloadTrait.id in member.traits:
return Binding.PAYLOAD
if HostLabelTrait.id in member.traits:
return Binding.HOST
return Binding.BODY


@dataclass(init=False)
class ResponseBindingMatcher(_BindingMatcher):
"""Matches structure members to HTTP response binding locations."""

def __init__(self, struct: Schema) -> None:
"""Initialize a ResponseBindingMatcher.

:param struct: The structure to examine for HTTP bindings.
"""
super().__init__(struct, self._compute_response(struct))

def _compute_response(self, struct: Schema) -> int:
if (http_error := struct.get_trait(HTTPErrorTrait)) is not None:
return http_error.code
if (error := struct.get_trait(ErrorTrait)) is not None:
return 400 if error.fault is ErrorFault.CLIENT else 500
return -1

def _do_match(self, member: Schema) -> Binding:
if HTTPResponseCodeTrait.id in member.traits:
return Binding.STATUS
if HTTPHeaderTrait.id in member.traits:
return Binding.HEADER
if HTTPPrefixHeadersTrait.id in member.traits:
return Binding.PREFIX_HEADERS
if HTTPPayloadTrait.id in member.traits:
return Binding.PAYLOAD
return Binding.BODY
79 changes: 40 additions & 39 deletions packages/smithy-http/src/smithy_http/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@
from smithy_core.shapes import ShapeType
from smithy_core.traits import (
HTTPHeaderTrait,
HTTPPayloadTrait,
HTTPPrefixHeadersTrait,
HTTPResponseCodeTrait,
HTTPTrait,
TimestampFormatTrait,
)
from smithy_core.types import TimestampFormat
from smithy_core.utils import ensure_utc, strict_parse_bool, strict_parse_float

from .aio.interfaces import HTTPResponse
from .bindings import Binding, ResponseBindingMatcher
from .interfaces import Field, Fields

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,47 +60,49 @@ def __init__(
def read_struct(
self, schema: Schema, consumer: Callable[[Schema, ShapeDeserializer], None]
) -> None:
has_body = False
payload_member: Schema | None = None
binding_matcher = ResponseBindingMatcher(schema)

for member in schema.members.values():
if (trait := member.get_trait(HTTPHeaderTrait)) is not None:
header = self._response.fields.entries.get(trait.key.lower())
if header is not None:
if member.shape_type is ShapeType.LIST:
consumer(member, HTTPHeaderListDeserializer(header))
else:
consumer(member, HTTPHeaderDeserializer(header.as_string()))
elif (trait := member.get_trait(HTTPPrefixHeadersTrait)) is not None:
consumer(
member,
HTTPHeaderMapDeserializer(self._response.fields, trait.prefix),
)
elif HTTPPayloadTrait in member:
has_body = True
payload_member = member
elif HTTPResponseCodeTrait in member:
consumer(member, HTTPResponseCodeDeserializer(self._response.status))
else:
has_body = True

if has_body:
deserializer = self._create_payload_deserializer(payload_member)
if payload_member is not None:
consumer(payload_member, deserializer)
else:
deserializer.read_struct(schema, consumer)

def _create_payload_deserializer(
self, payload_member: Schema | None
) -> ShapeDeserializer:
body = self._body if self._body is not None else self._response.body
if payload_member is not None and payload_member.shape_type in (
ShapeType.BLOB,
ShapeType.STRING,
):
match binding_matcher.match(member):
case Binding.HEADER:
trait = member.expect_trait(HTTPHeaderTrait)
header = self._response.fields.entries.get(trait.key.lower())
if header is not None:
if member.shape_type is ShapeType.LIST:
consumer(member, HTTPHeaderListDeserializer(header))
else:
consumer(member, HTTPHeaderDeserializer(header.as_string()))
case Binding.PREFIX_HEADERS:
trait = member.expect_trait(HTTPPrefixHeadersTrait)
consumer(
member,
HTTPHeaderMapDeserializer(self._response.fields, trait.prefix),
)
case Binding.STATUS:
consumer(
member, HTTPResponseCodeDeserializer(self._response.status)
)
case Binding.PAYLOAD:
assert binding_matcher.payload_member is not None # noqa: S101
deserializer = self._create_payload_deserializer(
binding_matcher.payload_member
)
consumer(binding_matcher.payload_member, deserializer)
case _:
pass

if binding_matcher.has_body:
deserializer = self._create_body_deserializer()
deserializer.read_struct(schema, consumer)

def _create_payload_deserializer(self, payload_member: Schema) -> ShapeDeserializer:
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
body = self._body if self._body is not None else self._response.body
return RawPayloadDeserializer(body)
return self._create_body_deserializer()

def _create_body_deserializer(self):
body = self._body if self._body is not None else self._response.body
if not is_streaming_blob(body):
raise UnsupportedStreamError(
"Unable to read async stream. This stream must be buffered prior "
Expand Down
Loading