Skip to content

🚨 Cover middleware/wsgi.py on mypy #1075

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ files =
uvicorn/protocols/websockets/auto.py,
uvicorn/supervisors/__init__.py,
uvicorn/middleware/debug.py,
uvicorn/middleware/wsgi.py,
tests/middleware/test_wsgi.py,
uvicorn/supervisors/watchgodreload.py,
uvicorn/logging.py,
uvicorn/middleware/asgi2.py,
Expand Down
42 changes: 28 additions & 14 deletions tests/middleware/test_wsgi.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,41 @@
import sys
from typing import List

import httpx
import pytest
from asgiref.typing import HTTPRequestEvent, HTTPScope

from uvicorn._types import Environ, StartResponse
from uvicorn.middleware.wsgi import WSGIMiddleware, build_environ


def hello_world(environ, start_response):
def hello_world(environ: Environ, start_response: StartResponse) -> List[bytes]:
status = "200 OK"
output = b"Hello World!\n"
headers = [
("Content-Type", "text/plain; charset=utf-8"),
("Content-Length", str(len(output))),
]
start_response(status, headers)
start_response(status, headers, None)
return [output]


def echo_body(environ, start_response):
def echo_body(environ: Environ, start_response: StartResponse) -> List[bytes]:
status = "200 OK"
output = environ["wsgi.input"].read()
headers = [
("Content-Type", "text/plain; charset=utf-8"),
("Content-Length", str(len(output))),
]
start_response(status, headers)
start_response(status, headers, None)
return [output]


def raise_exception(environ, start_response):
def raise_exception(environ: Environ, start_response: StartResponse) -> RuntimeError:
raise RuntimeError("Something went wrong")


def return_exc_info(environ, start_response):
def return_exc_info(environ: Environ, start_response: StartResponse) -> List[bytes]:
try:
raise RuntimeError("Something went wrong")
except RuntimeError:
Expand All @@ -42,12 +45,12 @@ def return_exc_info(environ, start_response):
("Content-Type", "text/plain; charset=utf-8"),
("Content-Length", str(len(output))),
]
start_response(status, headers, exc_info=sys.exc_info())
start_response(status, headers, sys.exc_info()) # type: ignore[arg-type]
return [output]


@pytest.mark.asyncio
async def test_wsgi_get():
async def test_wsgi_get() -> None:
app = WSGIMiddleware(hello_world)
async with httpx.AsyncClient(app=app, base_url="http://testserver") as client:
response = await client.get("/")
Expand All @@ -56,7 +59,7 @@ async def test_wsgi_get():


@pytest.mark.asyncio
async def test_wsgi_post():
async def test_wsgi_post() -> None:
app = WSGIMiddleware(echo_body)
async with httpx.AsyncClient(app=app, base_url="http://testserver") as client:
response = await client.post("/", json={"example": 123})
Expand All @@ -65,7 +68,7 @@ async def test_wsgi_post():


@pytest.mark.asyncio
async def test_wsgi_exception():
async def test_wsgi_exception() -> None:
# Note that we're testing the WSGI app directly here.
# The HTTP protocol implementations would catch this error and return 500.
app = WSGIMiddleware(raise_exception)
Expand All @@ -75,7 +78,7 @@ async def test_wsgi_exception():


@pytest.mark.asyncio
async def test_wsgi_exc_info():
async def test_wsgi_exc_info() -> None:
# Note that we're testing the WSGI app directly here.
# The HTTP protocol implementations would catch this error and return 500.
app = WSGIMiddleware(return_exc_info)
Expand All @@ -96,16 +99,27 @@ async def test_wsgi_exc_info():
assert response.text == "Internal Server Error"


def test_build_environ_encoding():
scope = {
def test_build_environ_encoding() -> None:
scope: HTTPScope = {
"asgi": {"version": "3.0", "spec_version": "2.0"},
"scheme": "http",
"raw_path": b"/\xe6\x96\x87",
"type": "http",
"http_version": "1.1",
"method": "GET",
"path": "/文",
"root_path": "/文",
"client": None,
"server": None,
"query_string": b"a=123&b=456",
"headers": [(b"key", b"value1"), (b"key", b"value2")],
"extensions": {},
}
environ = build_environ(scope, b"", b"")
message: HTTPRequestEvent = {
"type": "http.request",
"body": b"",
"more_body": False,
}
environ = build_environ(scope, message, b"")
assert environ["PATH_INFO"] == "/文".encode("utf8").decode("latin-1")
assert environ["HTTP_KEY"] == "value1,value2"
3 changes: 3 additions & 0 deletions uvicorn/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@
StartResponse = typing.Callable[
[str, typing.Iterable[typing.Tuple[str, str]], typing.Optional[ExcInfo]], None
]
WSGIApp = typing.Callable[
[Environ, StartResponse], typing.Union[typing.Iterable[bytes], BaseException]
]
105 changes: 71 additions & 34 deletions uvicorn/middleware/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,23 @@
import concurrent.futures
import io
import sys
from typing import Iterable, List, Optional, Tuple

from asgiref.typing import (
ASGIReceiveCallable,
ASGIReceiveEvent,
ASGISendCallable,
ASGISendEvent,
HTTPRequestEvent,
HTTPResponseBodyEvent,
HTTPResponseStartEvent,
HTTPScope,
)

def build_environ(scope, message, body):
from uvicorn._types import Environ, ExcInfo, StartResponse, WSGIApp


def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: bytes) -> Environ:
"""
Builds a scope and request message into a WSGI environ object.
"""
Expand Down Expand Up @@ -37,52 +51,63 @@ def build_environ(scope, message, body):

# Go through headers and make them into environ entries
for name, value in scope.get("headers", []):
name = name.decode("latin1")
if name == "content-length":
name_str: str = name.decode("latin1")
if name_str == "content-length":
corrected_name = "CONTENT_LENGTH"
elif name == "content-type":
elif name_str == "content-type":
corrected_name = "CONTENT_TYPE"
else:
corrected_name = "HTTP_%s" % name.upper().replace("-", "_")
corrected_name = "HTTP_%s" % name_str.upper().replace("-", "_")
# HTTPbis say only ASCII chars are allowed in headers, but we latin1
# just in case
value = value.decode("latin1")
value_str: str = value.decode("latin1")
if corrected_name in environ:
value = environ[corrected_name] + "," + value
environ[corrected_name] = value
corrected_name_environ = environ[corrected_name]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks a little too much just to satisfy mypy 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just ignore it here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

assert isinstance(corrected_name_environ, str)
value_str = corrected_name_environ + "," + value_str
environ[corrected_name] = value_str
return environ


class WSGIMiddleware:
def __init__(self, app, workers=10):
def __init__(self, app: WSGIApp, workers: int = 10):
self.app = app
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=workers)

async def __call__(self, scope, receive, send):
async def __call__(
self, scope: HTTPScope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
assert scope["type"] == "http"
instance = WSGIResponder(self.app, self.executor, scope)
await instance(receive, send)


class WSGIResponder:
def __init__(self, app, executor, scope):
def __init__(
self,
app: WSGIApp,
executor: concurrent.futures.ThreadPoolExecutor,
scope: HTTPScope,
):
self.app = app
self.executor = executor
self.scope = scope
self.status = None
self.response_headers = None
self.send_event = asyncio.Event()
self.send_queue = []
self.loop = None
self.send_queue: List[Optional[ASGISendEvent]] = []
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
self.response_started = False
self.exc_info = None
self.exc_info: Optional[ExcInfo] = None

async def __call__(self, receive, send):
message = await receive()
async def __call__(
self, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
message: HTTPRequestEvent = await receive() # type: ignore[assignment]
body = message.get("body", b"")
more_body = message.get("more_body", False)
while more_body:
body_message = await receive()
body_message: HTTPRequestEvent = await receive() # type: ignore[assignment]
body += body_message.get("body", b"")
more_body = body_message.get("more_body", False)
environ = build_environ(self.scope, message, body)
Expand All @@ -100,7 +125,7 @@ async def __call__(self, receive, send):
if self.exc_info is not None:
raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])

async def sender(self, send):
async def sender(self, send: ASGISendCallable) -> None:
while True:
if self.send_queue:
message = self.send_queue.pop(0)
Expand All @@ -111,31 +136,43 @@ async def sender(self, send):
await self.send_event.wait()
self.send_event.clear()

def start_response(self, status, response_headers, exc_info=None):
def start_response(
self,
status: str,
response_headers: Iterable[Tuple[str, str]],
exc_info: Optional[ExcInfo] = None,
) -> None:
self.exc_info = exc_info
if not self.response_started:
self.response_started = True
status_code, _ = status.split(" ", 1)
status_code = int(status_code)
status_code_str, _ = status.split(" ", 1)
status_code = int(status_code_str)
headers = [
(name.encode("ascii"), value.encode("ascii"))
for name, value in response_headers
]
self.send_queue.append(
{
"type": "http.response.start",
"status": status_code,
"headers": headers,
}
)
http_response_start_event: HTTPResponseStartEvent = {
"type": "http.response.start",
"status": status_code,
"headers": headers,
}
self.send_queue.append(http_response_start_event)
self.loop.call_soon_threadsafe(self.send_event.set)

def wsgi(self, environ, start_response):
for chunk in self.app(environ, start_response):
self.send_queue.append(
{"type": "http.response.body", "body": chunk, "more_body": True}
)
def wsgi(self, environ: Environ, start_response: StartResponse) -> None:
for chunk in self.app(environ, start_response): # type: ignore
response_body: HTTPResponseBodyEvent = {
"type": "http.response.body",
"body": chunk,
"more_body": True,
}
self.send_queue.append(response_body)
self.loop.call_soon_threadsafe(self.send_event.set)

self.send_queue.append({"type": "http.response.body", "body": b""})
empty_body: HTTPResponseBodyEvent = {
"type": "http.response.body",
"body": b"",
"more_body": False,
}
self.send_queue.append(empty_body)
self.loop.call_soon_threadsafe(self.send_event.set)