Skip to content

Properly buffer async response bodies #504

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions packages/smithy-http/src/smithy_http/aio/protocols.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from collections.abc import AsyncIterable
from inspect import iscoroutinefunction
from io import BytesIO
from typing import Any

from smithy_core.aio.interfaces import ClientProtocol
from smithy_core.aio.interfaces import AsyncByteStream, ClientProtocol
from smithy_core.aio.interfaces import StreamingBlob as AsyncStreamingBlob
from smithy_core.codecs import Codec
from smithy_core.deserializers import DeserializeableShape
from smithy_core.documents import TypeRegistry
Expand Down Expand Up @@ -109,35 +110,45 @@ async def deserialize_response[
error_registry: TypeRegistry,
context: TypedProperties,
) -> OperationOutput:
body = response.body

# if body is not streaming and is async, we have to buffer it
if not operation.output_stream_member and not is_streaming_blob(body):
if (
read := getattr(body, "read", None)
) is not None and iscoroutinefunction(read):
body = BytesIO(await read())

if not self._is_success(operation, context, response):
raise await self._create_error(
operation=operation,
request=request,
response=response,
response_body=body, # type: ignore
response_body=await self._buffer_async_body(response.body),
error_registry=error_registry,
context=context,
)

# if body is not streaming and is async, we have to buffer it
body: SyncStreamingBlob | None = None
if not operation.output_stream_member and not is_streaming_blob(body):
body = await self._buffer_async_body(response.body)

# TODO(optimization): response binding cache like done in SJ
deserializer = HTTPResponseDeserializer(
payload_codec=self.payload_codec,
http_trait=operation.schema.expect_trait(HTTPTrait),
response=response,
body=body, # type: ignore
body=body,
)

return operation.output.deserialize(deserializer)

async def _buffer_async_body(self, stream: AsyncStreamingBlob) -> SyncStreamingBlob:
match stream:
case AsyncByteStream():
if not iscoroutinefunction(stream.read):
return stream # type: ignore
return await stream.read()
case AsyncIterable():
full = b""
async for chunk in stream:
full += chunk
return full
case _:
return stream

def _is_success(
self,
operation: APIOperation[Any, Any],
Expand Down