Skip to content

Commit c70d345

Browse files
Add support for streaming responses to ASGITransport
1 parent 66a4537 commit c70d345

File tree

2 files changed

+155
-17
lines changed

2 files changed

+155
-17
lines changed

httpx/_transports/asgi.py

Lines changed: 125 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import httpcore
55
import sniffio
66

7-
from .._content_streams import ByteStream
7+
from .._content_streams import AsyncIteratorStream, ByteStream
88
from .._utils import warn_deprecated
99

1010
if typing.TYPE_CHECKING: # pragma: no cover
@@ -25,6 +25,75 @@ def create_event() -> "Event":
2525
return asyncio.Event()
2626

2727

28+
async def create_background_task(async_fn: typing.Callable) -> typing.Callable:
29+
if sniffio.current_async_library() == "trio":
30+
import trio
31+
32+
nursery_manager = trio.open_nursery()
33+
nursery = await nursery_manager.__aenter__()
34+
nursery.start_soon(async_fn)
35+
36+
async def aclose(exc: Exception = None) -> None:
37+
if exc is not None:
38+
await nursery_manager.__aexit__(type(exc), exc, exc.__traceback__)
39+
else:
40+
await nursery_manager.__aexit__(None, None, None)
41+
42+
return aclose
43+
44+
else:
45+
import asyncio
46+
47+
task = asyncio.create_task(async_fn())
48+
49+
async def aclose(exc: Exception = None) -> None:
50+
if not task.done():
51+
task.cancel()
52+
53+
return aclose
54+
55+
56+
def create_channel(
57+
capacity: int,
58+
) -> typing.Tuple[
59+
typing.Callable[[], typing.Awaitable[bytes]],
60+
typing.Callable[[bytes], typing.Awaitable[None]],
61+
]:
62+
if sniffio.current_async_library() == "trio":
63+
import trio
64+
65+
send_channel, receive_channel = trio.open_memory_channel[bytes](capacity)
66+
return receive_channel.receive, send_channel.send
67+
68+
else:
69+
import asyncio
70+
71+
queue: asyncio.Queue[bytes] = asyncio.Queue(capacity)
72+
return queue.get, queue.put
73+
74+
75+
async def run_until_first_complete(*async_fns: typing.Callable) -> None:
76+
if sniffio.current_async_library() == "trio":
77+
import trio
78+
79+
async with trio.open_nursery() as nursery:
80+
81+
async def run(async_fn: typing.Callable) -> None:
82+
await async_fn()
83+
nursery.cancel_scope.cancel()
84+
85+
for async_fn in async_fns:
86+
nursery.start_soon(run, async_fn)
87+
88+
else:
89+
import asyncio
90+
91+
coros = [async_fn() for async_fn in async_fns]
92+
done, pending = await asyncio.wait(coros, return_when=asyncio.FIRST_COMPLETED)
93+
for task in pending:
94+
task.cancel()
95+
96+
2897
class ASGITransport(httpcore.AsyncHTTPTransport):
2998
"""
3099
A custom AsyncTransport that handles sending requests directly to an ASGI app.
@@ -95,18 +164,20 @@ async def request(
95164
}
96165
status_code = None
97166
response_headers = None
98-
body_parts = []
167+
consume_response_body_chunk, produce_response_body_chunk = create_channel(1)
99168
request_complete = False
100-
response_started = False
169+
response_started = create_event()
101170
response_complete = create_event()
171+
app_crashed = create_event()
172+
app_exception: typing.Optional[Exception] = None
102173

103174
headers = [] if headers is None else headers
104175
stream = ByteStream(b"") if stream is None else stream
105176

106177
request_body_chunks = stream.__aiter__()
107178

108179
async def receive() -> dict:
109-
nonlocal request_complete, response_complete
180+
nonlocal request_complete
110181

111182
if request_complete:
112183
await response_complete.wait()
@@ -120,38 +191,76 @@ async def receive() -> dict:
120191
return {"type": "http.request", "body": body, "more_body": True}
121192

122193
async def send(message: dict) -> None:
123-
nonlocal status_code, response_headers, body_parts
124-
nonlocal response_started, response_complete
194+
nonlocal status_code, response_headers
125195

126196
if message["type"] == "http.response.start":
127-
assert not response_started
197+
assert not response_started.is_set()
128198

129199
status_code = message["status"]
130200
response_headers = message.get("headers", [])
131-
response_started = True
201+
response_started.set()
132202

133203
elif message["type"] == "http.response.body":
134204
assert not response_complete.is_set()
135205
body = message.get("body", b"")
136206
more_body = message.get("more_body", False)
137207

138208
if body and method != b"HEAD":
139-
body_parts.append(body)
209+
await produce_response_body_chunk(body)
140210

141211
if not more_body:
142212
response_complete.set()
143213

144-
try:
145-
await self.app(scope, receive, send)
146-
except Exception:
147-
if self.raise_app_exceptions or not response_complete:
148-
raise
214+
async def run_app() -> None:
215+
nonlocal app_exception
216+
try:
217+
await self.app(scope, receive, send)
218+
except Exception as exc:
219+
app_exception = exc
220+
app_crashed.set()
221+
222+
aclose_app = await create_background_task(run_app)
223+
224+
await run_until_first_complete(app_crashed.wait, response_started.wait)
149225

150-
assert response_complete.is_set()
226+
if app_crashed.is_set():
227+
assert app_exception is not None
228+
await aclose_app(app_exception)
229+
if self.raise_app_exceptions or not response_started.is_set():
230+
raise app_exception
231+
232+
assert response_started.is_set()
151233
assert status_code is not None
152234
assert response_headers is not None
153235

154-
stream = ByteStream(b"".join(body_parts))
236+
async def aiter_response_body_chunks() -> typing.AsyncIterator[bytes]:
237+
chunk = b""
238+
239+
async def consume_chunk() -> None:
240+
nonlocal chunk
241+
chunk = await consume_response_body_chunk()
242+
243+
while True:
244+
await run_until_first_complete(
245+
app_crashed.wait, consume_chunk, response_complete.wait
246+
)
247+
248+
if app_crashed.is_set():
249+
assert app_exception is not None
250+
if self.raise_app_exceptions:
251+
raise app_exception
252+
else:
253+
break
254+
255+
yield chunk
256+
257+
if response_complete.is_set():
258+
break
259+
260+
async def aclose() -> None:
261+
await aclose_app(app_exception)
262+
263+
stream = AsyncIteratorStream(aiter_response_body_chunks(), close_func=aclose)
155264

156265
return (b"HTTP/1.1", status_code, b"", response_headers, stream)
157266

tests/test_asgi.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import httpx
44

5+
from .concurrency import sleep
6+
57

68
async def hello_world(scope, receive, send):
79
status = 200
@@ -35,7 +37,8 @@ async def raise_exc_after_response(scope, receive, send):
3537
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]
3638

3739
await send({"type": "http.response.start", "status": status, "headers": headers})
38-
await send({"type": "http.response.body", "body": output})
40+
await send({"type": "http.response.body", "body": output, "more_body": True})
41+
await sleep(0.001) # Let the transport detect that the response has started.
3942
raise ValueError()
4043

4144

@@ -99,3 +102,29 @@ async def read_body(scope, receive, send):
99102
response = await client.post("http://www.example.org/", data=b"example")
100103
assert response.status_code == 200
101104
assert disconnect
105+
106+
107+
@pytest.mark.asyncio
108+
async def test_asgi_streaming():
109+
client = httpx.AsyncClient(app=hello_world)
110+
async with client.stream("GET", "http://www.example.org/") as response:
111+
assert response.status_code == 200
112+
text = "".join([chunk async for chunk in response.aiter_text()])
113+
assert text == "Hello, World!"
114+
115+
116+
@pytest.mark.asyncio
117+
async def test_asgi_streaming_exc():
118+
client = httpx.AsyncClient(app=raise_exc)
119+
with pytest.raises(ValueError):
120+
async with client.stream("GET", "http://www.example.org/"):
121+
pass # pragma: no cover
122+
123+
124+
@pytest.mark.asyncio
125+
async def test_asgi_streaming_exc_after_response():
126+
client = httpx.AsyncClient(app=raise_exc_after_response)
127+
async with client.stream("GET", "http://www.example.org/") as response:
128+
with pytest.raises(ValueError):
129+
async for _ in response.aiter_bytes():
130+
pass # pragma: no cover

0 commit comments

Comments
 (0)