Skip to content

Commit 467e032

Browse files
🚨 Cover middleware/wsgi.py on mypy (#1075)
* add types to wsgi module * Add middleware/wsgi.py to setup.cfg * Smoothen out a few (mypy-related) bugs I don't know how to solve the remaining ones * Small Fix: use HTTPRequestEvent * Apply suggestions from PR review * Remove all mypy issues Taking ideas from: https://github.com/encode/uvicorn/blob/dd85cdacf154529ea1c3d12b5bda7808673979f2/uvicorn/middleware/wsgi.py * Adjusting w.r.t. to #1067 * Adjusting w.r.t. to #1067 * Applied suggestions from PR review * Fixed remaining test_wsgi.py mypy issues * Trying by removing "[typeddict-item]" beside "type: ignore" * Revert previous commit * Final fix ... yaayyy! Co-authored-by: Jaakko Lappalainen <[email protected]>
1 parent e885bbd commit 467e032

File tree

4 files changed

+104
-48
lines changed

4 files changed

+104
-48
lines changed

‎setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ files =
2323
uvicorn/protocols/websockets/auto.py,
2424
uvicorn/supervisors/__init__.py,
2525
uvicorn/middleware/debug.py,
26+
uvicorn/middleware/wsgi.py,
27+
tests/middleware/test_wsgi.py,
2628
uvicorn/supervisors/watchgodreload.py,
2729
uvicorn/logging.py,
2830
uvicorn/middleware/asgi2.py,

‎tests/middleware/test_wsgi.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,41 @@
11
import sys
2+
from typing import List
23

34
import httpx
45
import pytest
6+
from asgiref.typing import HTTPRequestEvent, HTTPScope
57

8+
from uvicorn._types import Environ, StartResponse
69
from uvicorn.middleware.wsgi import WSGIMiddleware, build_environ
710

811

9-
def hello_world(environ, start_response):
12+
def hello_world(environ: Environ, start_response: StartResponse) -> List[bytes]:
1013
status = "200 OK"
1114
output = b"Hello World!\n"
1215
headers = [
1316
("Content-Type", "text/plain; charset=utf-8"),
1417
("Content-Length", str(len(output))),
1518
]
16-
start_response(status, headers)
19+
start_response(status, headers, None)
1720
return [output]
1821

1922

20-
def echo_body(environ, start_response):
23+
def echo_body(environ: Environ, start_response: StartResponse) -> List[bytes]:
2124
status = "200 OK"
2225
output = environ["wsgi.input"].read()
2326
headers = [
2427
("Content-Type", "text/plain; charset=utf-8"),
2528
("Content-Length", str(len(output))),
2629
]
27-
start_response(status, headers)
30+
start_response(status, headers, None)
2831
return [output]
2932

3033

31-
def raise_exception(environ, start_response):
34+
def raise_exception(environ: Environ, start_response: StartResponse) -> RuntimeError:
3235
raise RuntimeError("Something went wrong")
3336

3437

35-
def return_exc_info(environ, start_response):
38+
def return_exc_info(environ: Environ, start_response: StartResponse) -> List[bytes]:
3639
try:
3740
raise RuntimeError("Something went wrong")
3841
except RuntimeError:
@@ -42,12 +45,12 @@ def return_exc_info(environ, start_response):
4245
("Content-Type", "text/plain; charset=utf-8"),
4346
("Content-Length", str(len(output))),
4447
]
45-
start_response(status, headers, exc_info=sys.exc_info())
48+
start_response(status, headers, sys.exc_info()) # type: ignore[arg-type]
4649
return [output]
4750

4851

4952
@pytest.mark.asyncio
50-
async def test_wsgi_get():
53+
async def test_wsgi_get() -> None:
5154
app = WSGIMiddleware(hello_world)
5255
async with httpx.AsyncClient(app=app, base_url="http://testserver") as client:
5356
response = await client.get("/")
@@ -56,7 +59,7 @@ async def test_wsgi_get():
5659

5760

5861
@pytest.mark.asyncio
59-
async def test_wsgi_post():
62+
async def test_wsgi_post() -> None:
6063
app = WSGIMiddleware(echo_body)
6164
async with httpx.AsyncClient(app=app, base_url="http://testserver") as client:
6265
response = await client.post("/", json={"example": 123})
@@ -65,7 +68,7 @@ async def test_wsgi_post():
6568

6669

6770
@pytest.mark.asyncio
68-
async def test_wsgi_exception():
71+
async def test_wsgi_exception() -> None:
6972
# Note that we're testing the WSGI app directly here.
7073
# The HTTP protocol implementations would catch this error and return 500.
7174
app = WSGIMiddleware(raise_exception)
@@ -75,7 +78,7 @@ async def test_wsgi_exception():
7578

7679

7780
@pytest.mark.asyncio
78-
async def test_wsgi_exc_info():
81+
async def test_wsgi_exc_info() -> None:
7982
# Note that we're testing the WSGI app directly here.
8083
# The HTTP protocol implementations would catch this error and return 500.
8184
app = WSGIMiddleware(return_exc_info)
@@ -96,16 +99,27 @@ async def test_wsgi_exc_info():
9699
assert response.text == "Internal Server Error"
97100

98101

99-
def test_build_environ_encoding():
100-
scope = {
102+
def test_build_environ_encoding() -> None:
103+
scope: HTTPScope = {
104+
"asgi": {"version": "3.0", "spec_version": "2.0"},
105+
"scheme": "http",
106+
"raw_path": b"/\xe6\x96\x87",
101107
"type": "http",
102108
"http_version": "1.1",
103109
"method": "GET",
104110
"path": "/æ–‡",
105111
"root_path": "/æ–‡",
112+
"client": None,
113+
"server": None,
106114
"query_string": b"a=123&b=456",
107115
"headers": [(b"key", b"value1"), (b"key", b"value2")],
116+
"extensions": {},
108117
}
109-
environ = build_environ(scope, b"", b"")
118+
message: HTTPRequestEvent = {
119+
"type": "http.request",
120+
"body": b"",
121+
"more_body": False,
122+
}
123+
environ = build_environ(scope, message, b"")
110124
assert environ["PATH_INFO"] == "/æ–‡".encode("utf8").decode("latin-1")
111125
assert environ["HTTP_KEY"] == "value1,value2"

‎uvicorn/_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@
99
StartResponse = typing.Callable[
1010
[str, typing.Iterable[typing.Tuple[str, str]], typing.Optional[ExcInfo]], None
1111
]
12+
WSGIApp = typing.Callable[
13+
[Environ, StartResponse], typing.Union[typing.Iterable[bytes], BaseException]
14+
]

‎uvicorn/middleware/wsgi.py

Lines changed: 71 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,23 @@
22
import concurrent.futures
33
import io
44
import sys
5+
from typing import Iterable, List, Optional, Tuple
56

7+
from asgiref.typing import (
8+
ASGIReceiveCallable,
9+
ASGIReceiveEvent,
10+
ASGISendCallable,
11+
ASGISendEvent,
12+
HTTPRequestEvent,
13+
HTTPResponseBodyEvent,
14+
HTTPResponseStartEvent,
15+
HTTPScope,
16+
)
617

7-
def build_environ(scope, message, body):
18+
from uvicorn._types import Environ, ExcInfo, StartResponse, WSGIApp
19+
20+
21+
def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: bytes) -> Environ:
822
"""
923
Builds a scope and request message into a WSGI environ object.
1024
"""
@@ -37,52 +51,63 @@ def build_environ(scope, message, body):
3751

3852
# Go through headers and make them into environ entries
3953
for name, value in scope.get("headers", []):
40-
name = name.decode("latin1")
41-
if name == "content-length":
54+
name_str: str = name.decode("latin1")
55+
if name_str == "content-length":
4256
corrected_name = "CONTENT_LENGTH"
43-
elif name == "content-type":
57+
elif name_str == "content-type":
4458
corrected_name = "CONTENT_TYPE"
4559
else:
46-
corrected_name = "HTTP_%s" % name.upper().replace("-", "_")
60+
corrected_name = "HTTP_%s" % name_str.upper().replace("-", "_")
4761
# HTTPbis say only ASCII chars are allowed in headers, but we latin1
4862
# just in case
49-
value = value.decode("latin1")
63+
value_str: str = value.decode("latin1")
5064
if corrected_name in environ:
51-
value = environ[corrected_name] + "," + value
52-
environ[corrected_name] = value
65+
corrected_name_environ = environ[corrected_name]
66+
assert isinstance(corrected_name_environ, str)
67+
value_str = corrected_name_environ + "," + value_str
68+
environ[corrected_name] = value_str
5369
return environ
5470

5571

5672
class WSGIMiddleware:
57-
def __init__(self, app, workers=10):
73+
def __init__(self, app: WSGIApp, workers: int = 10):
5874
self.app = app
5975
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=workers)
6076

61-
async def __call__(self, scope, receive, send):
77+
async def __call__(
78+
self, scope: HTTPScope, receive: ASGIReceiveCallable, send: ASGISendCallable
79+
) -> None:
6280
assert scope["type"] == "http"
6381
instance = WSGIResponder(self.app, self.executor, scope)
6482
await instance(receive, send)
6583

6684

6785
class WSGIResponder:
68-
def __init__(self, app, executor, scope):
86+
def __init__(
87+
self,
88+
app: WSGIApp,
89+
executor: concurrent.futures.ThreadPoolExecutor,
90+
scope: HTTPScope,
91+
):
6992
self.app = app
7093
self.executor = executor
7194
self.scope = scope
7295
self.status = None
7396
self.response_headers = None
7497
self.send_event = asyncio.Event()
75-
self.send_queue = []
76-
self.loop = None
98+
self.send_queue: List[Optional[ASGISendEvent]] = []
99+
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
77100
self.response_started = False
78-
self.exc_info = None
101+
self.exc_info: Optional[ExcInfo] = None
79102

80-
async def __call__(self, receive, send):
81-
message = await receive()
103+
async def __call__(
104+
self, receive: ASGIReceiveCallable, send: ASGISendCallable
105+
) -> None:
106+
message: HTTPRequestEvent = await receive() # type: ignore[assignment]
82107
body = message.get("body", b"")
83108
more_body = message.get("more_body", False)
84109
while more_body:
85-
body_message = await receive()
110+
body_message: HTTPRequestEvent = await receive() # type: ignore[assignment]
86111
body += body_message.get("body", b"")
87112
more_body = body_message.get("more_body", False)
88113
environ = build_environ(self.scope, message, body)
@@ -100,7 +125,7 @@ async def __call__(self, receive, send):
100125
if self.exc_info is not None:
101126
raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
102127

103-
async def sender(self, send):
128+
async def sender(self, send: ASGISendCallable) -> None:
104129
while True:
105130
if self.send_queue:
106131
message = self.send_queue.pop(0)
@@ -111,31 +136,43 @@ async def sender(self, send):
111136
await self.send_event.wait()
112137
self.send_event.clear()
113138

114-
def start_response(self, status, response_headers, exc_info=None):
139+
def start_response(
140+
self,
141+
status: str,
142+
response_headers: Iterable[Tuple[str, str]],
143+
exc_info: Optional[ExcInfo] = None,
144+
) -> None:
115145
self.exc_info = exc_info
116146
if not self.response_started:
117147
self.response_started = True
118-
status_code, _ = status.split(" ", 1)
119-
status_code = int(status_code)
148+
status_code_str, _ = status.split(" ", 1)
149+
status_code = int(status_code_str)
120150
headers = [
121151
(name.encode("ascii"), value.encode("ascii"))
122152
for name, value in response_headers
123153
]
124-
self.send_queue.append(
125-
{
126-
"type": "http.response.start",
127-
"status": status_code,
128-
"headers": headers,
129-
}
130-
)
154+
http_response_start_event: HTTPResponseStartEvent = {
155+
"type": "http.response.start",
156+
"status": status_code,
157+
"headers": headers,
158+
}
159+
self.send_queue.append(http_response_start_event)
131160
self.loop.call_soon_threadsafe(self.send_event.set)
132161

133-
def wsgi(self, environ, start_response):
134-
for chunk in self.app(environ, start_response):
135-
self.send_queue.append(
136-
{"type": "http.response.body", "body": chunk, "more_body": True}
137-
)
162+
def wsgi(self, environ: Environ, start_response: StartResponse) -> None:
163+
for chunk in self.app(environ, start_response): # type: ignore
164+
response_body: HTTPResponseBodyEvent = {
165+
"type": "http.response.body",
166+
"body": chunk,
167+
"more_body": True,
168+
}
169+
self.send_queue.append(response_body)
138170
self.loop.call_soon_threadsafe(self.send_event.set)
139171

140-
self.send_queue.append({"type": "http.response.body", "body": b""})
172+
empty_body: HTTPResponseBodyEvent = {
173+
"type": "http.response.body",
174+
"body": b"",
175+
"more_body": False,
176+
}
177+
self.send_queue.append(empty_body)
141178
self.loop.call_soon_threadsafe(self.send_event.set)

0 commit comments

Comments
 (0)