Skip to content

Commit dc306d0

Browse files
committed
Fixed PR remarks (#94)
Added timeout handling to Semaphore::acquire and tried to avoid private API usage in SocketStream::get_http_version, also changed is_connection_dropped behaviour
1 parent 659bd94 commit dc306d0

File tree

2 files changed

+25
-10
lines changed

2 files changed

+25
-10
lines changed

httpcore/_backends/curio.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
from ssl import SSLContext
2-
from typing import Optional
1+
import select
2+
import socket
3+
from ssl import SSLContext, SSLSocket
4+
from typing import Dict, Optional, Type, Union
35

46
import curio
57
import curio.io
@@ -10,6 +12,7 @@
1012
ConnectTimeout,
1113
ReadError,
1214
ReadTimeout,
15+
TimeoutException,
1316
WriteError,
1417
WriteTimeout,
1518
map_exceptions,
@@ -50,7 +53,13 @@ def semaphore(self) -> curio.Semaphore:
5053
return self._semaphore
5154

5255
async def acquire(self, timeout: float = None) -> None:
53-
await self.semaphore.acquire()
56+
exc_map: Dict[Type[Exception], Type[Exception]] = {
57+
curio.TaskTimeout: TimeoutException,
58+
}
59+
acquire_timeout: int = convert_timeout(timeout)
60+
61+
with map_exceptions(exc_map):
62+
return await curio.timeout_after(acquire_timeout, self.semaphore.acquire())
5463

5564
async def release(self) -> None:
5665
await self.semaphore.release()
@@ -64,10 +73,14 @@ def __init__(self, socket: curio.io.Socket) -> None:
6473
self.stream = socket.as_stream()
6574

6675
def get_http_version(self) -> str:
67-
if hasattr(self.socket, "_socket") and hasattr(self.socket._socket, "_sslobj"):
68-
ident = self.socket._socket._sslobj.selected_alpn_protocol()
69-
else:
70-
ident = "http/1.1"
76+
ident: Optional[str] = "http/1.1"
77+
78+
if hasattr(self.socket, "_socket"):
79+
raw_socket: Union[SSLSocket, socket.socket] = self.socket._socket
80+
81+
if isinstance(raw_socket, SSLSocket):
82+
ident = raw_socket.selected_alpn_protocol()
83+
7184
return "HTTP/2" if ident == "h2" else "HTTP/1.1"
7285

7386
async def start_tls(
@@ -118,11 +131,13 @@ async def write(self, data: bytes, timeout: TimeoutDict) -> None:
118131
await curio.timeout_after(write_timeout, self.stream.write(data))
119132

120133
async def aclose(self) -> None:
121-
# we dont need to close the self.socket, since it's closed by stream closing
122134
await self.stream.close()
135+
await self.socket.close()
123136

124137
def is_connection_dropped(self) -> bool:
125-
return self.socket._closed
138+
rready, _, _ = select.select([self.socket.fileno()], [], [], 0)
139+
140+
return bool(rready)
126141

127142

128143
class CurioBackend(AsyncBackend):

tests/marks/curio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def curio_pytest_pycollect_makeitem(collector, name, obj):
1919
if collector.funcnamefilter(name) and _is_coroutine(obj):
2020
item = pytest.Function.from_parent(collector, name=name)
2121
if "curio" in item.keywords:
22-
return list(collector._genfunctions(name, obj))
22+
return list(collector._genfunctions(name, obj)) # pragma: nocover
2323

2424

2525
@pytest.hookimpl(tryfirst=True, hookwrapper=True)

0 commit comments

Comments
 (0)