4
4
import httpcore
5
5
import sniffio
6
6
7
- from .._content_streams import ByteStream
7
+ from .._content_streams import AsyncIteratorStream , ByteStream
8
8
from .._utils import warn_deprecated
9
9
10
10
if typing .TYPE_CHECKING : # pragma: no cover
@@ -25,6 +25,75 @@ def create_event() -> "Event":
25
25
return asyncio .Event ()
26
26
27
27
28
+ async def create_background_task (async_fn : typing .Callable ) -> typing .Callable :
29
+ if sniffio .current_async_library () == "trio" :
30
+ import trio
31
+
32
+ nursery_manager = trio .open_nursery ()
33
+ nursery = await nursery_manager .__aenter__ ()
34
+ nursery .start_soon (async_fn )
35
+
36
+ async def aclose (exc : Exception = None ) -> None :
37
+ if exc is not None :
38
+ await nursery_manager .__aexit__ (type (exc ), exc , exc .__traceback__ )
39
+ else :
40
+ await nursery_manager .__aexit__ (None , None , None )
41
+
42
+ return aclose
43
+
44
+ else :
45
+ import asyncio
46
+
47
+ task = asyncio .create_task (async_fn ())
48
+
49
+ async def aclose (exc : Exception = None ) -> None :
50
+ if not task .done ():
51
+ task .cancel ()
52
+
53
+ return aclose
54
+
55
+
56
+ def create_channel (
57
+ capacity : int ,
58
+ ) -> typing .Tuple [
59
+ typing .Callable [[], typing .Awaitable [bytes ]],
60
+ typing .Callable [[bytes ], typing .Awaitable [None ]],
61
+ ]:
62
+ if sniffio .current_async_library () == "trio" :
63
+ import trio
64
+
65
+ send_channel , receive_channel = trio .open_memory_channel [bytes ](capacity )
66
+ return receive_channel .receive , send_channel .send
67
+
68
+ else :
69
+ import asyncio
70
+
71
+ queue : asyncio .Queue [bytes ] = asyncio .Queue (capacity )
72
+ return queue .get , queue .put
73
+
74
+
75
+ async def run_until_first_complete (* async_fns : typing .Callable ) -> None :
76
+ if sniffio .current_async_library () == "trio" :
77
+ import trio
78
+
79
+ async with trio .open_nursery () as nursery :
80
+
81
+ async def run (async_fn : typing .Callable ) -> None :
82
+ await async_fn ()
83
+ nursery .cancel_scope .cancel ()
84
+
85
+ for async_fn in async_fns :
86
+ nursery .start_soon (run , async_fn )
87
+
88
+ else :
89
+ import asyncio
90
+
91
+ coros = [async_fn () for async_fn in async_fns ]
92
+ done , pending = await asyncio .wait (coros , return_when = asyncio .FIRST_COMPLETED )
93
+ for task in pending :
94
+ task .cancel ()
95
+
96
+
28
97
class ASGITransport (httpcore .AsyncHTTPTransport ):
29
98
"""
30
99
A custom AsyncTransport that handles sending requests directly to an ASGI app.
@@ -95,18 +164,20 @@ async def request(
95
164
}
96
165
status_code = None
97
166
response_headers = None
98
- body_parts = []
167
+ consume_response_body_chunk , produce_response_body_chunk = create_channel ( 1 )
99
168
request_complete = False
100
- response_started = False
169
+ response_started = create_event ()
101
170
response_complete = create_event ()
171
+ app_crashed = create_event ()
172
+ app_exception : typing .Optional [Exception ] = None
102
173
103
174
headers = [] if headers is None else headers
104
175
stream = ByteStream (b"" ) if stream is None else stream
105
176
106
177
request_body_chunks = stream .__aiter__ ()
107
178
108
179
async def receive () -> dict :
109
- nonlocal request_complete , response_complete
180
+ nonlocal request_complete
110
181
111
182
if request_complete :
112
183
await response_complete .wait ()
@@ -120,38 +191,76 @@ async def receive() -> dict:
120
191
return {"type" : "http.request" , "body" : body , "more_body" : True }
121
192
122
193
async def send (message : dict ) -> None :
123
- nonlocal status_code , response_headers , body_parts
124
- nonlocal response_started , response_complete
194
+ nonlocal status_code , response_headers
125
195
126
196
if message ["type" ] == "http.response.start" :
127
- assert not response_started
197
+ assert not response_started . is_set ()
128
198
129
199
status_code = message ["status" ]
130
200
response_headers = message .get ("headers" , [])
131
- response_started = True
201
+ response_started . set ()
132
202
133
203
elif message ["type" ] == "http.response.body" :
134
204
assert not response_complete .is_set ()
135
205
body = message .get ("body" , b"" )
136
206
more_body = message .get ("more_body" , False )
137
207
138
208
if body and method != b"HEAD" :
139
- body_parts . append (body )
209
+ await produce_response_body_chunk (body )
140
210
141
211
if not more_body :
142
212
response_complete .set ()
143
213
144
- try :
145
- await self .app (scope , receive , send )
146
- except Exception :
147
- if self .raise_app_exceptions or not response_complete :
148
- raise
214
+ async def run_app () -> None :
215
+ nonlocal app_exception
216
+ try :
217
+ await self .app (scope , receive , send )
218
+ except Exception as exc :
219
+ app_exception = exc
220
+ app_crashed .set ()
221
+
222
+ aclose_app = await create_background_task (run_app )
223
+
224
+ await run_until_first_complete (app_crashed .wait , response_started .wait )
149
225
150
- assert response_complete .is_set ()
226
+ if app_crashed .is_set ():
227
+ assert app_exception is not None
228
+ await aclose_app (app_exception )
229
+ if self .raise_app_exceptions or not response_started .is_set ():
230
+ raise app_exception
231
+
232
+ assert response_started .is_set ()
151
233
assert status_code is not None
152
234
assert response_headers is not None
153
235
154
- stream = ByteStream (b"" .join (body_parts ))
236
+ async def aiter_response_body_chunks () -> typing .AsyncIterator [bytes ]:
237
+ chunk = b""
238
+
239
+ async def consume_chunk () -> None :
240
+ nonlocal chunk
241
+ chunk = await consume_response_body_chunk ()
242
+
243
+ while True :
244
+ await run_until_first_complete (
245
+ app_crashed .wait , consume_chunk , response_complete .wait
246
+ )
247
+
248
+ if app_crashed .is_set ():
249
+ assert app_exception is not None
250
+ if self .raise_app_exceptions :
251
+ raise app_exception
252
+ else :
253
+ break
254
+
255
+ yield chunk
256
+
257
+ if response_complete .is_set ():
258
+ break
259
+
260
+ async def aclose () -> None :
261
+ await aclose_app (app_exception )
262
+
263
+ stream = AsyncIteratorStream (aiter_response_body_chunks (), close_func = aclose )
155
264
156
265
return (b"HTTP/1.1" , status_code , b"" , response_headers , stream )
157
266
0 commit comments