Skip to content

Commit eaf43b3

Browse files
Add support for async auth flows
1 parent 4161d7a commit eaf43b3

File tree

3 files changed

+75
-4
lines changed

3 files changed

+75
-4
lines changed

httpx/_auth.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ class Auth:
1717
1818
To implement a custom authentication scheme, subclass `Auth` and override
1919
the `.auth_flow()` method.
20+
21+
If the authentication scheme does I/O, such as disk access or network calls, or uses
22+
synchronization primitives such as locks, you should override `.async_auth_flow()`
23+
to provide an async-friendly implementation that will be used by the `AsyncClient`.
24+
Usage of sync I/O within an async codebase would block the event loop, and could
25+
cause performance issues.
2026
"""
2127

2228
requires_request_body = False
@@ -46,6 +52,26 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non
4652
"""
4753
yield request
4854

55+
async def async_auth_flow(
56+
self, request: Request
57+
) -> typing.AsyncGenerator[Request, Response]:
58+
"""
59+
Execute the authentication flow asynchronously.
60+
61+
By default, this defers to `.auth_flow()`. You should override this method
62+
when the authentication scheme does I/O, such as disk access or network calls,
63+
or uses concurrency primitives such as locks.
64+
"""
65+
flow = self.auth_flow(request)
66+
request = next(flow)
67+
68+
while True:
69+
response = yield request
70+
try:
71+
request = flow.send(response)
72+
except StopIteration:
73+
break
74+
4975

5076
class FunctionAuth(Auth):
5177
"""

httpx/_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,15 +1365,15 @@ async def _send_handling_auth(
13651365
if auth.requires_request_body:
13661366
await request.aread()
13671367

1368-
auth_flow = auth.auth_flow(request)
1369-
request = next(auth_flow)
1368+
auth_flow = auth.async_auth_flow(request)
1369+
request = await auth_flow.__anext__()
13701370
while True:
13711371
response = await self._send_single_request(request, timeout)
13721372
if auth.requires_response_body:
13731373
await response.aread()
13741374
try:
1375-
next_request = auth_flow.send(response)
1376-
except StopIteration:
1375+
next_request = await auth_flow.asend(response)
1376+
except StopAsyncIteration:
13771377
return response
13781378
except BaseException as exc:
13791379
await response.aclose()

tests/client/test_auth.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import asyncio
12
import hashlib
23
import os
34
import typing
5+
import threading
46

57
import httpcore
68
import pytest
@@ -184,6 +186,29 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non
184186
yield request
185187

186188

189+
class SyncOrAsyncAuth(Auth):
190+
"""
191+
A mock authentication scheme that uses a different implementation for the
192+
sync and async cases.
193+
"""
194+
195+
def __init__(self):
196+
self._lock = threading.Lock()
197+
self._async_lock = asyncio.Lock()
198+
199+
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
200+
with self._lock:
201+
request.headers["Authorization"] = "sync-auth"
202+
yield request
203+
204+
async def async_auth_flow(
205+
self, request: Request
206+
) -> typing.AsyncGenerator[Request, Response]:
207+
async with self._async_lock:
208+
request.headers["Authorization"] = "async-auth"
209+
yield request
210+
211+
187212
@pytest.mark.asyncio
188213
async def test_basic_auth() -> None:
189214
url = "https://example.org/"
@@ -641,3 +666,23 @@ def test_sync_auth_reads_response_body() -> None:
641666
response = client.get(url, auth=auth)
642667
assert response.status_code == 200
643668
assert response.json() == {"auth": '{"auth": "xyz"}'}
669+
670+
671+
@pytest.mark.asyncio
672+
async def test_sync_async_auth() -> None:
673+
"""
674+
Test that we can use a different auth flow implementation in the async case, to
675+
support cases that require performing I/O or using concurrency primitives (such
676+
as checking a disk-based cache or fetching a token from a remote auth server).
677+
"""
678+
url = "https://example.org/"
679+
auth = SyncOrAsyncAuth()
680+
681+
client = AsyncClient(transport=AsyncMockTransport())
682+
response = await client.get(url, auth=auth)
683+
assert response.status_code == 200
684+
assert response.json() == {"auth": "async-auth"}
685+
686+
response = Client(transport=SyncMockTransport()).get(url, auth=auth)
687+
assert response.status_code == 200
688+
assert response.json() == {"auth": "sync-auth"}

0 commit comments

Comments
 (0)