-
-
Notifications
You must be signed in to change notification settings - Fork 804
🚨 Cover middleware/wsgi.py on mypy #1075
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7a82d01
441c648
b181df3
a80e6bd
3359e85
13ea0fd
f8ff765
aaab4bb
910385a
6f2b069
4d19f04
dc1db63
68e381a
8a6cf03
aa80a78
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,9 +2,23 @@ | |
import concurrent.futures | ||
import io | ||
import sys | ||
from typing import Iterable, List, Optional, Tuple | ||
|
||
from asgiref.typing import ( | ||
ASGIReceiveCallable, | ||
ASGIReceiveEvent, | ||
ASGISendCallable, | ||
ASGISendEvent, | ||
HTTPRequestEvent, | ||
HTTPResponseBodyEvent, | ||
HTTPResponseStartEvent, | ||
HTTPScope, | ||
) | ||
|
||
def build_environ(scope, message, body): | ||
from uvicorn._types import Environ, ExcInfo, StartResponse, WSGIApp | ||
|
||
|
||
def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: bytes) -> Environ: | ||
""" | ||
Builds a scope and request message into a WSGI environ object. | ||
""" | ||
|
@@ -37,52 +51,63 @@ def build_environ(scope, message, body): | |
|
||
# Go through headers and make them into environ entries | ||
for name, value in scope.get("headers", []): | ||
name = name.decode("latin1") | ||
if name == "content-length": | ||
name_str: str = name.decode("latin1") | ||
if name_str == "content-length": | ||
corrected_name = "CONTENT_LENGTH" | ||
elif name == "content-type": | ||
elif name_str == "content-type": | ||
corrected_name = "CONTENT_TYPE" | ||
else: | ||
corrected_name = "HTTP_%s" % name.upper().replace("-", "_") | ||
corrected_name = "HTTP_%s" % name_str.upper().replace("-", "_") | ||
# HTTPbis say only ASCII chars are allowed in headers, but we latin1 | ||
# just in case | ||
value = value.decode("latin1") | ||
value_str: str = value.decode("latin1") | ||
if corrected_name in environ: | ||
value = environ[corrected_name] + "," + value | ||
environ[corrected_name] = value | ||
corrected_name_environ = environ[corrected_name] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks a little too much just to satisfy There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we just ignore it here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lgtm |
||
assert isinstance(corrected_name_environ, str) | ||
value_str = corrected_name_environ + "," + value_str | ||
environ[corrected_name] = value_str | ||
return environ | ||
|
||
|
||
class WSGIMiddleware: | ||
def __init__(self, app, workers=10): | ||
def __init__(self, app: WSGIApp, workers: int = 10): | ||
self.app = app | ||
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=workers) | ||
|
||
async def __call__(self, scope, receive, send): | ||
async def __call__( | ||
self, scope: HTTPScope, receive: ASGIReceiveCallable, send: ASGISendCallable | ||
Kludex marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) -> None: | ||
assert scope["type"] == "http" | ||
instance = WSGIResponder(self.app, self.executor, scope) | ||
await instance(receive, send) | ||
|
||
|
||
class WSGIResponder: | ||
def __init__(self, app, executor, scope): | ||
def __init__( | ||
self, | ||
app: WSGIApp, | ||
executor: concurrent.futures.ThreadPoolExecutor, | ||
Vibhu-Agarwal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
scope: HTTPScope, | ||
): | ||
self.app = app | ||
self.executor = executor | ||
self.scope = scope | ||
self.status = None | ||
self.response_headers = None | ||
self.send_event = asyncio.Event() | ||
self.send_queue = [] | ||
self.loop = None | ||
self.send_queue: List[Optional[ASGISendEvent]] = [] | ||
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() | ||
self.response_started = False | ||
self.exc_info = None | ||
self.exc_info: Optional[ExcInfo] = None | ||
|
||
async def __call__(self, receive, send): | ||
message = await receive() | ||
async def __call__( | ||
self, receive: ASGIReceiveCallable, send: ASGISendCallable | ||
) -> None: | ||
message: HTTPRequestEvent = await receive() # type: ignore[assignment] | ||
body = message.get("body", b"") | ||
more_body = message.get("more_body", False) | ||
while more_body: | ||
body_message = await receive() | ||
body_message: HTTPRequestEvent = await receive() # type: ignore[assignment] | ||
body += body_message.get("body", b"") | ||
more_body = body_message.get("more_body", False) | ||
environ = build_environ(self.scope, message, body) | ||
|
@@ -100,7 +125,7 @@ async def __call__(self, receive, send): | |
if self.exc_info is not None: | ||
raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) | ||
|
||
async def sender(self, send): | ||
async def sender(self, send: ASGISendCallable) -> None: | ||
while True: | ||
if self.send_queue: | ||
message = self.send_queue.pop(0) | ||
|
@@ -111,31 +136,43 @@ async def sender(self, send): | |
await self.send_event.wait() | ||
self.send_event.clear() | ||
|
||
def start_response(self, status, response_headers, exc_info=None): | ||
def start_response( | ||
self, | ||
status: str, | ||
response_headers: Iterable[Tuple[str, str]], | ||
exc_info: Optional[ExcInfo] = None, | ||
) -> None: | ||
self.exc_info = exc_info | ||
if not self.response_started: | ||
self.response_started = True | ||
status_code, _ = status.split(" ", 1) | ||
status_code = int(status_code) | ||
status_code_str, _ = status.split(" ", 1) | ||
status_code = int(status_code_str) | ||
headers = [ | ||
(name.encode("ascii"), value.encode("ascii")) | ||
for name, value in response_headers | ||
] | ||
self.send_queue.append( | ||
{ | ||
"type": "http.response.start", | ||
"status": status_code, | ||
"headers": headers, | ||
} | ||
) | ||
http_response_start_event: HTTPResponseStartEvent = { | ||
"type": "http.response.start", | ||
"status": status_code, | ||
"headers": headers, | ||
} | ||
self.send_queue.append(http_response_start_event) | ||
self.loop.call_soon_threadsafe(self.send_event.set) | ||
|
||
def wsgi(self, environ, start_response): | ||
for chunk in self.app(environ, start_response): | ||
self.send_queue.append( | ||
{"type": "http.response.body", "body": chunk, "more_body": True} | ||
) | ||
def wsgi(self, environ: Environ, start_response: StartResponse) -> None: | ||
for chunk in self.app(environ, start_response): # type: ignore | ||
response_body: HTTPResponseBodyEvent = { | ||
"type": "http.response.body", | ||
"body": chunk, | ||
"more_body": True, | ||
} | ||
self.send_queue.append(response_body) | ||
self.loop.call_soon_threadsafe(self.send_event.set) | ||
|
||
self.send_queue.append({"type": "http.response.body", "body": b""}) | ||
empty_body: HTTPResponseBodyEvent = { | ||
"type": "http.response.body", | ||
"body": b"", | ||
"more_body": False, | ||
} | ||
self.send_queue.append(empty_body) | ||
self.loop.call_soon_threadsafe(self.send_event.set) |
Uh oh!
There was an error while loading. Please reload this page.