Skip to content

Commit 9e14ea1

Browse files
xiangyan99pre-commit-ci[bot]webknjazDreamsorcererbdraco
authored
Tightening the runtime type check for ssl (#7698)
Currently, the valid types of ssl parameter are SSLContext, Literal[False], Fingerprint or None. If user sets ssl = False, we disable ssl certificate validation which makes total sense. But if user set ssl = True by mistake, instead of enabling ssl certificate validation or raising errors, we silently disable the validation too which is a little subtle but weird. In this PR, we added a check that if user sets ssl=True, we enable certificate validation by treating it as using Default SSL Context. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sviatoslav Sydorenko <[email protected]> Co-authored-by: Sam Bull <[email protected]> Co-authored-by: J. Nick Koston <[email protected]> Co-authored-by: Sam Bull <[email protected]>
1 parent 2670e7b commit 9e14ea1

File tree

9 files changed

+47
-39
lines changed

9 files changed

+47
-39
lines changed

CHANGES/7698.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added support for passing `True` to `ssl` while deprecating `None`. -- by :user:`xiangyan99`

aiohttp/client.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
Generic,
2525
Iterable,
2626
List,
27-
Literal,
2827
Mapping,
2928
Optional,
3029
Set,
@@ -364,7 +363,7 @@ async def _request(
364363
proxy: Optional[StrOrURL] = None,
365364
proxy_auth: Optional[BasicAuth] = None,
366365
timeout: Union[ClientTimeout, _SENTINEL, None] = sentinel,
367-
ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None,
366+
ssl: Union[SSLContext, bool, Fingerprint] = True,
368367
server_hostname: Optional[str] = None,
369368
proxy_headers: Optional[LooseHeaders] = None,
370369
trace_request_ctx: Optional[SimpleNamespace] = None,
@@ -382,8 +381,8 @@ async def _request(
382381

383382
if not isinstance(ssl, SSL_ALLOWED_TYPES):
384383
raise TypeError(
385-
"ssl should be SSLContext, bool, Fingerprint, "
386-
"or None, got {!r} instead.".format(ssl)
384+
"ssl should be SSLContext, Fingerprint, or bool, "
385+
"got {!r} instead.".format(ssl)
387386
)
388387

389388
if data is not None and json is not None:
@@ -513,7 +512,7 @@ async def _request(
513512
proxy_auth=proxy_auth,
514513
timer=timer,
515514
session=self,
516-
ssl=ssl,
515+
ssl=ssl if ssl is not None else True, # type: ignore[redundant-expr]
517516
server_hostname=server_hostname,
518517
proxy_headers=proxy_headers,
519518
traces=traces,
@@ -702,7 +701,7 @@ def ws_connect(
702701
headers: Optional[LooseHeaders] = None,
703702
proxy: Optional[StrOrURL] = None,
704703
proxy_auth: Optional[BasicAuth] = None,
705-
ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None,
704+
ssl: Union[SSLContext, bool, Fingerprint] = True,
706705
server_hostname: Optional[str] = None,
707706
proxy_headers: Optional[LooseHeaders] = None,
708707
compress: int = 0,
@@ -725,7 +724,7 @@ def ws_connect(
725724
headers=headers,
726725
proxy=proxy,
727726
proxy_auth=proxy_auth,
728-
ssl=ssl,
727+
ssl=ssl if ssl is not None else True, # type: ignore[redundant-expr]
729728
server_hostname=server_hostname,
730729
proxy_headers=proxy_headers,
731730
compress=compress,
@@ -750,7 +749,7 @@ async def _ws_connect(
750749
headers: Optional[LooseHeaders] = None,
751750
proxy: Optional[StrOrURL] = None,
752751
proxy_auth: Optional[BasicAuth] = None,
753-
ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None,
752+
ssl: Union[SSLContext, bool, Fingerprint] = True,
754753
server_hostname: Optional[str] = None,
755754
proxy_headers: Optional[LooseHeaders] = None,
756755
compress: int = 0,
@@ -806,10 +805,19 @@ async def _ws_connect(
806805
extstr = ws_ext_gen(compress=compress)
807806
real_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr
808807

808+
# For the sake of backward compatibility, if user passes in None, convert it to True
809+
if ssl is None:
810+
warnings.warn(
811+
"ssl=None is deprecated, please use ssl=True",
812+
DeprecationWarning,
813+
stacklevel=2,
814+
)
815+
ssl = True
816+
809817
if not isinstance(ssl, SSL_ALLOWED_TYPES):
810818
raise TypeError(
811-
"ssl should be SSLContext, bool, Fingerprint, "
812-
"or None, got {!r} instead.".format(ssl)
819+
"ssl should be SSLContext, Fingerprint, or bool, "
820+
"got {!r} instead.".format(ssl)
813821
)
814822

815823
# send request

aiohttp/client_exceptions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,12 @@ def port(self) -> Optional[int]:
149149
return self._conn_key.port
150150

151151
@property
152-
def ssl(self) -> Union[SSLContext, None, bool, "Fingerprint"]:
152+
def ssl(self) -> Union[SSLContext, bool, "Fingerprint"]:
153153
return self._conn_key.ssl
154154

155155
def __str__(self) -> str:
156156
return "Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]".format(
157-
self, self.ssl if self.ssl is not None else "default", self.strerror
157+
self, "default" if self.ssl is True else self.ssl, self.strerror
158158
)
159159

160160
# OSError.__reduce__ does too much black magick
@@ -188,7 +188,7 @@ def path(self) -> str:
188188

189189
def __str__(self) -> str:
190190
return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format(
191-
self, self.ssl if self.ssl is not None else "default", self.strerror
191+
self, "default" if self.ssl is True else self.ssl, self.strerror
192192
)
193193

194194

aiohttp/client_reqrep.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
Dict,
1919
Iterable,
2020
List,
21-
Literal,
2221
Mapping,
2322
Optional,
2423
Tuple,
@@ -149,7 +148,7 @@ def check(self, transport: asyncio.Transport) -> None:
149148
if ssl is not None:
150149
SSL_ALLOWED_TYPES = (ssl.SSLContext, bool, Fingerprint, type(None))
151150
else: # pragma: no cover
152-
SSL_ALLOWED_TYPES = type(None)
151+
SSL_ALLOWED_TYPES = (bool, type(None))
153152

154153

155154
@dataclasses.dataclass(frozen=True)
@@ -159,7 +158,7 @@ class ConnectionKey:
159158
host: str
160159
port: Optional[int]
161160
is_ssl: bool
162-
ssl: Union[SSLContext, None, Literal[False], Fingerprint]
161+
ssl: Union[SSLContext, bool, Fingerprint]
163162
proxy: Optional[URL]
164163
proxy_auth: Optional[BasicAuth]
165164
proxy_headers_hash: Optional[int] # hash(CIMultiDict)
@@ -213,7 +212,7 @@ def __init__(
213212
proxy_auth: Optional[BasicAuth] = None,
214213
timer: Optional[BaseTimerContext] = None,
215214
session: Optional["ClientSession"] = None,
216-
ssl: Union[SSLContext, Literal[False], Fingerprint, None] = None,
215+
ssl: Union[SSLContext, bool, Fingerprint] = True,
217216
proxy_headers: Optional[LooseHeaders] = None,
218217
traces: Optional[List["Trace"]] = None,
219218
trust_env: bool = False,
@@ -248,7 +247,7 @@ def __init__(
248247
real_response_class = response_class
249248
self.response_class: Type[ClientResponse] = real_response_class
250249
self._timer = timer if timer is not None else TimerNoop()
251-
self._ssl = ssl
250+
self._ssl = ssl if ssl is not None else True # type: ignore[redundant-expr]
252251
self.server_hostname = server_hostname
253252

254253
if loop.get_debug():
@@ -290,7 +289,7 @@ def is_ssl(self) -> bool:
290289
return self.url.scheme in ("https", "wss")
291290

292291
@property
293-
def ssl(self) -> Union["SSLContext", None, Literal[False], Fingerprint]:
292+
def ssl(self) -> Union["SSLContext", bool, Fingerprint]:
294293
return self._ssl
295294

296295
@property

aiohttp/connector.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ def __init__(
746746
use_dns_cache: bool = True,
747747
ttl_dns_cache: Optional[int] = 10,
748748
family: int = 0,
749-
ssl: Union[None, Literal[False], Fingerprint, SSLContext] = None,
749+
ssl: Union[bool, Fingerprint, SSLContext] = True,
750750
local_addr: Optional[Tuple[str, int]] = None,
751751
resolver: Optional[AbstractResolver] = None,
752752
keepalive_timeout: Union[None, float, _SENTINEL] = sentinel,
@@ -769,8 +769,8 @@ def __init__(
769769

770770
if not isinstance(ssl, SSL_ALLOWED_TYPES):
771771
raise TypeError(
772-
"ssl should be SSLContext, bool, Fingerprint, "
773-
"or None, got {!r} instead.".format(ssl)
772+
"ssl should be SSLContext, Fingerprint, or bool, "
773+
"got {!r} instead.".format(ssl)
774774
)
775775
self._ssl = ssl
776776
if resolver is None:
@@ -942,13 +942,13 @@ def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
942942
sslcontext = req.ssl
943943
if isinstance(sslcontext, ssl.SSLContext):
944944
return sslcontext
945-
if sslcontext is not None:
945+
if sslcontext is not True:
946946
# not verified or fingerprinted
947947
return self._make_ssl_context(False)
948948
sslcontext = self._ssl
949949
if isinstance(sslcontext, ssl.SSLContext):
950950
return sslcontext
951-
if sslcontext is not None:
951+
if sslcontext is not True:
952952
# not verified or fingerprinted
953953
return self._make_ssl_context(False)
954954
return self._make_ssl_context(True)

tests/test_client_exceptions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class TestClientConnectorError:
8989
host="example.com",
9090
port=8080,
9191
is_ssl=False,
92-
ssl=None,
92+
ssl=True,
9393
proxy=None,
9494
proxy_auth=None,
9595
proxy_headers_hash=None,
@@ -106,7 +106,7 @@ def test_ctor(self) -> None:
106106
assert err.os_error.strerror == "No such file"
107107
assert err.host == "example.com"
108108
assert err.port == 8080
109-
assert err.ssl is None
109+
assert err.ssl is True
110110

111111
def test_pickle(self) -> None:
112112
err = client.ClientConnectorError(
@@ -123,7 +123,7 @@ def test_pickle(self) -> None:
123123
assert err2.os_error.strerror == "No such file"
124124
assert err2.host == "example.com"
125125
assert err2.port == 8080
126-
assert err2.ssl is None
126+
assert err2.ssl is True
127127
assert err2.foo == "bar"
128128

129129
def test_repr(self) -> None:
@@ -141,7 +141,7 @@ def test_str(self) -> None:
141141
os_error=OSError(errno.ENOENT, "No such file"),
142142
)
143143
assert str(err) == (
144-
"Cannot connect to host example.com:8080 ssl:" "default [No such file]"
144+
"Cannot connect to host example.com:8080 ssl:default [No such file]"
145145
)
146146

147147

@@ -150,7 +150,7 @@ class TestClientConnectorCertificateError:
150150
host="example.com",
151151
port=8080,
152152
is_ssl=False,
153-
ssl=None,
153+
ssl=True,
154154
proxy=None,
155155
proxy_auth=None,
156156
proxy_headers_hash=None,

tests/test_client_request.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def test_host_port_default_http(make_request: Any) -> None:
164164
req = make_request("get", "http://python.org/")
165165
assert req.host == "python.org"
166166
assert req.port == 80
167-
assert not req.ssl
167+
assert not req.is_ssl()
168168

169169

170170
def test_host_port_default_https(make_request: Any) -> None:
@@ -391,7 +391,7 @@ def test_ipv6_default_http_port(make_request: Any) -> None:
391391
req = make_request("get", "http://[2001:db8::1]/")
392392
assert req.host == "2001:db8::1"
393393
assert req.port == 80
394-
assert not req.ssl
394+
assert not req.is_ssl()
395395

396396

397397
def test_ipv6_default_https_port(make_request: Any) -> None:

tests/test_connector.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,19 @@
3131
@pytest.fixture()
3232
def key():
3333
# Connection key
34-
return ConnectionKey("localhost", 80, False, None, None, None, None)
34+
return ConnectionKey("localhost", 80, False, True, None, None, None)
3535

3636

3737
@pytest.fixture
3838
def key2():
3939
# Connection key
40-
return ConnectionKey("localhost", 80, False, None, None, None, None)
40+
return ConnectionKey("localhost", 80, False, True, None, None, None)
4141

4242

4343
@pytest.fixture
4444
def ssl_key():
4545
# Connection key
46-
return ConnectionKey("localhost", 80, True, None, None, None, None)
46+
return ConnectionKey("localhost", 80, True, True, None, None, None)
4747

4848

4949
@pytest.fixture
@@ -1478,7 +1478,7 @@ async def test_cleanup_closed_disabled(loop: Any, mocker: Any) -> None:
14781478

14791479
async def test_tcp_connector_ctor(loop: Any) -> None:
14801480
conn = aiohttp.TCPConnector()
1481-
assert conn._ssl is None
1481+
assert conn._ssl is True
14821482

14831483
assert conn.use_dns_cache
14841484
assert conn.family == 0
@@ -1565,7 +1565,7 @@ async def test___get_ssl_context3(loop: Any) -> None:
15651565
conn = aiohttp.TCPConnector(ssl=ctx)
15661566
req = mock.Mock()
15671567
req.is_ssl.return_value = True
1568-
req.ssl = None
1568+
req.ssl = True
15691569
assert conn._get_ssl_context(req) is ctx
15701570

15711571

@@ -1591,7 +1591,7 @@ async def test___get_ssl_context6(loop: Any) -> None:
15911591
conn = aiohttp.TCPConnector()
15921592
req = mock.Mock()
15931593
req.is_ssl.return_value = True
1594-
req.ssl = None
1594+
req.ssl = True
15951595
assert conn._get_ssl_context(req) is conn._make_ssl_context(True)
15961596

15971597

tests/test_proxy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ async def make_conn():
8888
auth=None,
8989
headers={"Host": "www.python.org"},
9090
loop=self.loop,
91-
ssl=None,
91+
ssl=True,
9292
)
9393

9494
conn.close()
@@ -146,7 +146,7 @@ async def make_conn():
146146
auth=None,
147147
headers={"Host": "www.python.org", "Foo": "Bar"},
148148
loop=self.loop,
149-
ssl=None,
149+
ssl=True,
150150
)
151151

152152
conn.close()

0 commit comments

Comments
 (0)