1
- from ssl import SSLContext
2
- from typing import Optional
1
+ import select
2
+ import socket
3
+ from ssl import SSLContext , SSLSocket
4
+ from typing import Dict , Optional , Type , Union
3
5
4
6
import curio
5
7
import curio .io
10
12
ConnectTimeout ,
11
13
ReadError ,
12
14
ReadTimeout ,
15
+ TimeoutException ,
13
16
WriteError ,
14
17
WriteTimeout ,
15
18
map_exceptions ,
@@ -50,7 +53,13 @@ def semaphore(self) -> curio.Semaphore:
50
53
return self ._semaphore
51
54
52
55
async def acquire (self , timeout : float = None ) -> None :
53
- await self .semaphore .acquire ()
56
+ exc_map : Dict [Type [Exception ], Type [Exception ]] = {
57
+ curio .TaskTimeout : TimeoutException ,
58
+ }
59
+ acquire_timeout : int = convert_timeout (timeout )
60
+
61
+ with map_exceptions (exc_map ):
62
+ return await curio .timeout_after (acquire_timeout , self .semaphore .acquire ())
54
63
55
64
async def release (self ) -> None :
56
65
await self .semaphore .release ()
@@ -64,10 +73,14 @@ def __init__(self, socket: curio.io.Socket) -> None:
64
73
self .stream = socket .as_stream ()
65
74
66
75
def get_http_version (self ) -> str :
67
- if hasattr (self .socket , "_socket" ) and hasattr (self .socket ._socket , "_sslobj" ):
68
- ident = self .socket ._socket ._sslobj .selected_alpn_protocol ()
69
- else :
70
- ident = "http/1.1"
76
+ ident : Optional [str ] = "http/1.1"
77
+
78
+ if hasattr (self .socket , "_socket" ):
79
+ raw_socket : Union [SSLSocket , socket .socket ] = self .socket ._socket
80
+
81
+ if isinstance (raw_socket , SSLSocket ):
82
+ ident = raw_socket .selected_alpn_protocol ()
83
+
71
84
return "HTTP/2" if ident == "h2" else "HTTP/1.1"
72
85
73
86
async def start_tls (
@@ -118,11 +131,13 @@ async def write(self, data: bytes, timeout: TimeoutDict) -> None:
118
131
await curio .timeout_after (write_timeout , self .stream .write (data ))
119
132
120
133
async def aclose (self ) -> None :
121
- # we dont need to close the self.socket, since it's closed by stream closing
122
134
await self .stream .close ()
135
+ await self .socket .close ()
123
136
124
137
def is_connection_dropped (self ) -> bool :
125
- return self .socket ._closed
138
+ rready , _ , _ = select .select ([self .socket .fileno ()], [], [], 0 )
139
+
140
+ return bool (rready )
126
141
127
142
128
143
class CurioBackend (AsyncBackend ):
0 commit comments