|
43 | 43 | import base64
|
44 | 44 | import json
|
45 | 45 | import sys
|
46 |
| -from typing import Any, Dict, Optional, Tuple |
| 46 | +from typing import Any, Dict, Mapping, Optional, Tuple, Union |
47 | 47 | from urllib import parse as urlparse
|
48 | 48 |
|
49 | 49 | import requests
|
@@ -75,7 +75,7 @@ def encode_canonical_json(value: object) -> bytes:
|
75 | 75 | value,
|
76 | 76 | # Encode code-points outside of ASCII as UTF-8 rather than \u escapes
|
77 | 77 | ensure_ascii=False,
|
78 |
| - # Remove unecessary white space. |
| 78 | + # Remove unnecessary white space. |
79 | 79 | separators=(",", ":"),
|
80 | 80 | # Sort the keys of dictionaries.
|
81 | 81 | sort_keys=True,
|
@@ -298,12 +298,23 @@ def send(
|
298 | 298 |
|
299 | 299 | return super().send(request, *args, **kwargs)
|
300 | 300 |
|
301 |
| - def get_connection( |
302 |
| - self, url: str, proxies: Optional[Dict[str, str]] = None |
| 301 | + def get_connection_with_tls_context( |
| 302 | + self, |
| 303 | + request: PreparedRequest, |
| 304 | + verify: Optional[Union[bool, str]], |
| 305 | + proxies: Optional[Mapping[str, str]] = None, |
| 306 | + cert: Optional[Union[Tuple[str, str], str]] = None, |
303 | 307 | ) -> HTTPConnectionPool:
|
304 |
| - # overrides the get_connection() method in the base class |
305 |
| - parsed = urlparse.urlsplit(url) |
306 |
| - (host, port, ssl_server_name) = self._lookup(parsed.netloc) |
| 308 | + # overrides the get_connection_with_tls_context() method in the base class |
| 309 | + parsed = urlparse.urlsplit(request.url) |
| 310 | + |
| 311 | + # Extract the server name from the request URL, and ensure it's a str. |
| 312 | + hostname = parsed.netloc |
| 313 | + if isinstance(hostname, bytes): |
| 314 | + hostname = hostname.decode("utf-8") |
| 315 | + assert isinstance(hostname, str) |
| 316 | + |
| 317 | + (host, port, ssl_server_name) = self._lookup(hostname) |
307 | 318 | print(
|
308 | 319 | f"Connecting to {host}:{port} with SNI {ssl_server_name}", file=sys.stderr
|
309 | 320 | )
|
|
0 commit comments