Skip to content

fix: mypy issues #983

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions uvicorn/_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import sys
from typing import Optional
from typing import Any, Awaitable, Callable, Dict, Optional, Type, Union

if sys.version_info < (3, 8):
from typing_extensions import Literal, TypedDict
from typing_extensions import Literal, Protocol, TypedDict
else:
from typing import Literal, TypedDict
from typing import Literal, Protocol, TypedDict


class ASGISpecInfo(TypedDict):
Expand All @@ -29,3 +29,23 @@ class LifespanSendMessage(TypedDict):
"lifespan.shutdown.failed",
]
message: Optional[str]


Scope = Dict[str, Any]
Message = Dict[str, Any]

Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]


class ASGI2Protocol(Protocol):
def __init__(self, scope: Scope) -> None:
...

async def __call__(self, receive: Receive, send: Send) -> None:
...


ASGI2App = Type[ASGI2Protocol]
ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]
ASGIApp = Union[ASGI2App, ASGI3App]
102 changes: 54 additions & 48 deletions uvicorn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import socket
import ssl
import sys
from typing import List, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import click

Expand Down Expand Up @@ -102,8 +102,14 @@


def create_ssl_context(
certfile, keyfile, password, ssl_version, cert_reqs, ca_certs, ciphers
):
certfile: Optional[str],
keyfile: Optional[str],
password: Optional[str],
ssl_version: int,
cert_reqs: int,
ca_certs: Optional[str],
ciphers: str,
) -> ssl.SSLContext:
ctx = ssl.SSLContext(ssl_version)
get_password = (lambda: password) if password else None
ctx.load_cert_chain(certfile, keyfile, get_password)
Expand All @@ -118,44 +124,44 @@ def create_ssl_context(
class Config:
def __init__(
self,
app,
host="127.0.0.1",
port=8000,
uds=None,
fd=None,
loop="auto",
http="auto",
ws="auto",
lifespan="auto",
env_file=None,
log_config=LOGGING_CONFIG,
log_level=None,
access_log=True,
use_colors=None,
interface="auto",
debug=False,
reload=False,
reload_dirs=None,
reload_delay=None,
workers=None,
proxy_headers=True,
forwarded_allow_ips=None,
root_path="",
limit_concurrency=None,
limit_max_requests=None,
backlog=2048,
timeout_keep_alive=5,
timeout_notify=30,
callback_notify=None,
ssl_keyfile=None,
ssl_certfile=None,
ssl_keyfile_password=None,
ssl_version=SSL_PROTOCOL_VERSION,
ssl_cert_reqs=ssl.CERT_NONE,
ssl_ca_certs=None,
ssl_ciphers="TLSv1",
headers=None,
factory=False,
app: str,
host: str = "127.0.0.1",
port: int = 8000,
uds: Optional[str] = None,
fd: Optional[int] = None,
loop: str = "auto",
http: str = "auto",
ws: str = "auto",
lifespan: str = "auto",
env_file: Optional[str] = None,
log_config: Union[Dict[str, Any], str] = LOGGING_CONFIG,
log_level: Optional[str] = None,
access_log: bool = True,
use_colors: Optional[bool] = None,
interface: str = "auto",
debug: bool = False,
reload: bool = False,
reload_dirs: List[str] = None,
reload_delay: Optional[float] = None,
workers: Optional[int] = None,
proxy_headers: bool = True,
forwarded_allow_ips: Optional[str] = None,
root_path: str = "",
limit_concurrency: Optional[int] = None,
limit_max_requests: Optional[int] = None,
backlog: int = 2048,
timeout_keep_alive: int = 5,
timeout_notify: int = 30,
callback_notify: Optional[Callable] = None,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_keyfile_password: Optional[str] = None,
ssl_version: int = SSL_PROTOCOL_VERSION,
ssl_cert_reqs: int = ssl.CERT_NONE,
ssl_ca_certs: Optional[str] = None,
ssl_ciphers: str = "TLSv1",
headers: Optional[List[str]] = None,
factory: bool = False,
):
self.app = app
self.host = host
Expand Down Expand Up @@ -190,8 +196,8 @@ def __init__(
self.ssl_cert_reqs = ssl_cert_reqs
self.ssl_ca_certs = ssl_ca_certs
self.ssl_ciphers = ssl_ciphers
self.headers = headers if headers else [] # type: List[str]
self.encoded_headers = None # type: List[Tuple[bytes, bytes]]
self.headers: List[List[str]] = headers if headers else []
self.encoded_headers: List[Tuple[bytes, bytes]] = None
self.factory = factory

self.loaded = False
Expand Down Expand Up @@ -226,7 +232,7 @@ def asgi_version(self) -> str:
def is_ssl(self) -> bool:
return bool(self.ssl_keyfile or self.ssl_certfile)

def configure_logging(self):
def configure_logging(self) -> None:
logging.addLevelName(TRACE_LOG_LEVEL, "TRACE")

if self.log_config is not None:
Expand Down Expand Up @@ -266,7 +272,7 @@ def configure_logging(self):
logging.getLogger("uvicorn.access").handlers = []
logging.getLogger("uvicorn.access").propagate = False

def load(self):
def load(self) -> None:
assert not self.loaded

if self.is_ssl:
Expand Down Expand Up @@ -350,12 +356,12 @@ def load(self):

self.loaded = True

def setup_event_loop(self):
def setup_event_loop(self) -> None:
loop_setup = import_from_string(LOOP_SETUPS[self.loop])
if loop_setup is not None:
loop_setup()

def bind_socket(self):
def bind_socket(self) -> socket.socket:
family = socket.AF_INET
addr_format = "%s://%s:%d"

Expand Down Expand Up @@ -390,5 +396,5 @@ def bind_socket(self):
return sock

@property
def should_reload(self):
def should_reload(self) -> bool:
return isinstance(self.app, str) and (self.debug or self.reload)
3 changes: 2 additions & 1 deletion uvicorn/importer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import importlib
from typing import Any


class ImportFromStringError(Exception):
pass


def import_from_string(import_str):
def import_from_string(import_str: Any) -> Any:
if not isinstance(import_str, str):
return import_str

Expand Down
31 changes: 20 additions & 11 deletions uvicorn/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import sys
from copy import copy
from typing import Optional

import click

Expand All @@ -28,24 +29,30 @@ class ColourizedFormatter(logging.Formatter):
),
}

def __init__(self, fmt=None, datefmt=None, style="%", use_colors=None):
if use_colors in (True, False):
def __init__(
self,
fmt: Optional[str] = None,
datefmt: Optional[str] = None,
style: str = "%",
use_colors: Optional[bool] = None,
) -> None:
if isinstance(use_colors, bool):
self.use_colors = use_colors
else:
self.use_colors = sys.stdout.isatty()
super().__init__(fmt=fmt, datefmt=datefmt, style=style)

def color_level_name(self, level_name, level_no):
def default(level_name):
def color_level_name(self, level_name: str, level_no: int) -> str:
def default(level_name: str) -> str:
return str(level_name)

func = self.level_name_colors.get(level_no, default)
return func(level_name)

def should_use_colors(self):
def should_use_colors(self) -> bool:
return True

def formatMessage(self, record):
def formatMessage(self, record: logging.LogRecord) -> str:
recordcopy = copy(record)
levelname = recordcopy.levelname
seperator = " " * (8 - len(recordcopy.levelname))
Expand All @@ -59,7 +66,7 @@ def formatMessage(self, record):


class DefaultFormatter(ColourizedFormatter):
def should_use_colors(self):
def should_use_colors(self) -> bool:
return sys.stderr.isatty()


Expand All @@ -72,22 +79,22 @@ class AccessFormatter(ColourizedFormatter):
5: lambda code: click.style(str(code), fg="bright_red"),
}

def get_status_code(self, status_code: int):
def get_status_code(self, status_code: int) -> str:
try:
status_phrase = http.HTTPStatus(status_code).phrase
except ValueError:
status_phrase = ""
status_and_phrase = "%s %s" % (status_code, status_phrase)
if self.use_colors:

def default(code):
def default(code: int) -> str:
return status_and_phrase

func = self.status_code_colours.get(status_code // 100, default)
return func(status_and_phrase)
return status_and_phrase

def formatMessage(self, record):
def formatMessage(self, record: logging.LogRecord) -> str:
recordcopy = copy(record)
(
client_addr,
Expand All @@ -96,7 +103,9 @@ def formatMessage(self, record):
http_version,
status_code,
) = recordcopy.args
status_code = self.get_status_code(status_code)
# NOTE: error: Argument 1 to "get_status_code" of "AccessFormatter" has
# incompatible type "Union[Any, str]"; expected "int"
status_code = self.get_status_code(status_code) # type: ignore
request_line = "%s %s HTTP/%s" % (method, full_path, http_version)
if self.use_colors:
request_line = click.style(request_line, bold=True)
Expand Down
4 changes: 2 additions & 2 deletions uvicorn/loops/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys


def asyncio_setup():
def asyncio_setup() -> None:
if (
sys.version_info.major >= 3
and sys.version_info.minor >= 8
Expand All @@ -14,5 +14,5 @@ def asyncio_setup():
loop = asyncio.SelectorEventLoop(selector)
asyncio.set_event_loop(loop)
else:
loop = asyncio.new_event_loop()
loop = asyncio.new_event_loop() # type: ignore
asyncio.set_event_loop(loop)
2 changes: 1 addition & 1 deletion uvicorn/loops/auto.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
def auto_loop_setup():
def auto_loop_setup() -> None:
try:
import uvloop # noqa
except ImportError: # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion uvicorn/loops/uvloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
import uvloop


def uvloop_setup():
def uvloop_setup() -> None:
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
17 changes: 7 additions & 10 deletions uvicorn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
logger = logging.getLogger("uvicorn.error")


def print_version(ctx, param, value):
def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> None:
if not value or ctx.resilient_parsing:
return
click.echo(
Expand Down Expand Up @@ -281,7 +281,7 @@ def print_version(ctx, param, value):
show_default=True,
)
def main(
app,
app: str,
host: str,
port: int,
uds: str,
Expand Down Expand Up @@ -318,11 +318,10 @@ def main(
use_colors: bool,
app_dir: str,
factory: bool,
):
) -> None:
sys.path.insert(0, app_dir)

kwargs = {
"app": app,
"host": host,
"port": port,
"uds": uds,
Expand Down Expand Up @@ -359,10 +358,10 @@ def main(
"use_colors": use_colors,
"factory": factory,
}
run(**kwargs)
run(app, **kwargs)


def run(app, **kwargs):
def run(app: str, **kwargs: typing.Any) -> None:
config = Config(app, **kwargs)
server = Server(config=config)

Expand All @@ -376,12 +375,10 @@ def run(app, **kwargs):

if config.should_reload:
sock = config.bind_socket()
supervisor = ChangeReload(config, target=server.run, sockets=[sock])
supervisor.run()
ChangeReload(config, target=server.run, sockets=[sock]).run()
elif config.workers > 1:
sock = config.bind_socket()
supervisor = Multiprocess(config, target=server.run, sockets=[sock])
supervisor.run()
Multiprocess(config, target=server.run, sockets=[sock]).run()
else:
server.run()

Expand Down
Loading