Skip to content

Re-try in more cases when socket cannot first be created #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/pycqa/pylint
rev: v2.17.4
rev: v3.1.0
hooks:
- id: pylint
name: pylint (library code)
Expand Down
122 changes: 55 additions & 67 deletions adafruit_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

"""

# imports

__version__ = "0.0.0+auto.0"
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager.git"

Expand All @@ -31,9 +29,6 @@

WIZNET5K_SSL_SUPPORT_VERSION = (9, 1)

# typing


if not sys.implementation.name == "circuitpython":
from typing import List, Optional, Tuple

Expand All @@ -46,9 +41,6 @@
)


# ssl and pool helpers


class _FakeSSLSocket:
def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None:
self._socket = socket
Expand Down Expand Up @@ -82,7 +74,7 @@ def wrap_socket( # pylint: disable=unused-argument
if hasattr(self._iface, "TLS_MODE"):
return _FakeSSLSocket(socket, self._iface.TLS_MODE)

raise AttributeError("This radio does not support TLS/HTTPS")
raise ValueError("This radio does not support TLS/HTTPS")


def create_fake_ssl_context(
Expand Down Expand Up @@ -167,7 +159,7 @@ def get_radio_socketpool(radio):
ssl_context = create_fake_ssl_context(pool, radio)

else:
raise AttributeError(f"Unsupported radio class: {class_name}")
raise ValueError(f"Unsupported radio class: {class_name}")

_global_key_by_socketpool[pool] = key
_global_socketpools[key] = pool
Expand All @@ -189,11 +181,8 @@ def get_radio_ssl_context(radio):
return _global_ssl_contexts[_get_radio_hash_key(radio)]


# main class


class ConnectionManager:
"""A library for managing sockets accross libraries."""
"""A library for managing sockets across multiple hardware platforms and libraries."""

def __init__(
self,
Expand All @@ -215,6 +204,11 @@ def _free_sockets(self, force: bool = False) -> None:
for socket in open_sockets:
self.close_socket(socket)

def _register_connected_socket(self, key, socket):
"""Register a socket as managed."""
self._key_by_managed_socket[socket] = key
self._managed_socket_by_key[key] = socket

def _get_connected_socket( # pylint: disable=too-many-arguments
self,
addr_info: List[Tuple[int, int, int, str, Tuple[str, int]]],
Expand All @@ -224,23 +218,24 @@ def _get_connected_socket( # pylint: disable=too-many-arguments
is_ssl: bool,
ssl_context: Optional[SSLContextType] = None,
):
try:
socket = self._socket_pool.socket(addr_info[0], addr_info[1])
except (OSError, RuntimeError) as exc:
return exc

socket = self._socket_pool.socket(addr_info[0], addr_info[1])

if is_ssl:
socket = ssl_context.wrap_socket(socket, server_hostname=host)
connect_host = host
else:
connect_host = addr_info[-1][0]
socket.settimeout(timeout) # socket read timeout

# Set socket read and connect timeout.
socket.settimeout(timeout)

try:
socket.connect((connect_host, port))
except (MemoryError, OSError) as exc:
except (MemoryError, OSError):
# If any connect problems, clean up and re-raise the problem exception.
socket.close()
return exc
raise

return socket

Expand Down Expand Up @@ -269,82 +264,78 @@ def close_socket(self, socket: SocketType) -> None:
self._available_sockets.remove(socket)

def free_socket(self, socket: SocketType) -> None:
"""Mark a managed socket as available so it can be reused."""
"""Mark a managed socket as available so it can be reused. The socket is not closed."""
if socket not in self._managed_socket_by_key.values():
raise RuntimeError("Socket not managed")
self._available_sockets.add(socket)

# pylint: disable=too-many-arguments
def get_socket(
self,
host: str,
port: int,
proto: str,
session_id: Optional[str] = None,
*,
timeout: float = 1,
timeout: float = 1.0,
is_ssl: bool = False,
ssl_context: Optional[SSLContextType] = None,
) -> CircuitPythonSocketType:
"""
Get a new socket and connect.

- **host** *(str)* – The host you are want to connect to: "www.adaftuit.com"
- **port** *(int)* – The port you want to connect to: 80
- **proto** *(str)* – The protocal you want to use: "http:"
- **session_id** *(Optional[str])* – A unique Session ID, when wanting to have multiple open
connections to the same host
- **timeout** *(float)* – Time timeout used for connecting
- **is_ssl** *(bool)* – If the connection is to be over SSL (auto set when proto is
"https:")
- **ssl_context** *(Optional[SSLContextType])* – The SSL context to use when making SSL
requests
Get a new socket and connect to the given host.

:param str host: host to connect to, such as ``"www.example.org"``
:param int port: port to use for connection, such as ``80`` or ``443``
:param str proto: connection protocol: ``"http:"``, ``"https:"``, etc.
:param Optional[str]: unique session ID,
used for multiple simultaneous connections to the same host
:param float timeout: how long to wait to connect
:param bool is_ssl: ``True`` If the connection is to be over SSL;
automatically set when ``proto`` is ``"https:"``
:param Optional[SSLContextType]: SSL context to use when making SSL requests
"""
if session_id:
session_id = str(session_id)
key = (host, port, proto, session_id)

# Do we have already have a socket available for the requested connection?
if key in self._managed_socket_by_key:
socket = self._managed_socket_by_key[key]
if socket in self._available_sockets:
self._available_sockets.remove(socket)
return socket

raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}")
raise RuntimeError(
f"An existing socket is already connected to {proto}//{host}:{port}"
)

if proto == "https:":
is_ssl = True
if is_ssl and not ssl_context:
raise AttributeError(
"ssl_context must be set before using adafruit_requests for https"
)
raise ValueError("ssl_context must be provided if using ssl")

addr_info = self._socket_pool.getaddrinfo(
host, port, 0, self._socket_pool.SOCK_STREAM
)[0]

first_exception = None
result = self._get_connected_socket(
addr_info, host, port, timeout, is_ssl, ssl_context
)
if isinstance(result, Exception):
# Got an error, if there are any available sockets, free them and try again
try:
socket = self._get_connected_socket(
addr_info, host, port, timeout, is_ssl, ssl_context
)
self._register_connected_socket(key, socket)
return socket
except (MemoryError, OSError, RuntimeError):
# Could not get a new socket (or two, if SSL).
# If there are any available sockets, free them all and try again.
if self.available_socket_count:
first_exception = result
self._free_sockets()
result = self._get_connected_socket(
socket = self._get_connected_socket(
addr_info, host, port, timeout, is_ssl, ssl_context
)
if isinstance(result, Exception):
last_result = f", first error: {first_exception}" if first_exception else ""
raise RuntimeError(
f"Error connecting socket: {result}{last_result}"
) from result

self._key_by_managed_socket[result] = key
self._managed_socket_by_key[key] = result
return result


# global helpers
self._register_connected_socket(key, socket)
return socket
# Re-raise exception if no sockets could be freed.
raise


def connection_manager_close_all(
Expand All @@ -353,10 +344,10 @@ def connection_manager_close_all(
"""
Close all open sockets for pool, optionally release references.

- **socket_pool** *(Optional[SocketpoolModuleType])* – A specifc SocketPool you want to close
sockets for, leave blank for all SocketPools
- **release_references** *(bool)* – Set to True if you want to also clear stored references to
the SocketPool and SSL contexts
:param Optional[SocketpoolModuleType] socket_pool:
a specific socket pool whose sockets you want to close; ``None`` means all socket pools
:param bool release_references: ``True`` if you also want the `ConnectionManager` to forget
all the socket pools and SSL contexts it knows about
"""
if socket_pool:
socket_pools = [socket_pool]
Expand All @@ -383,10 +374,7 @@ def connection_manager_close_all(

def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager:
"""
Get the ConnectionManager singleton for the given pool.

- **socket_pool** *(Optional[SocketpoolModuleType])* – The SocketPool you want the
ConnectionManager for
Get or create the ConnectionManager singleton for the given pool.
"""
if socket_pool not in _global_connection_managers:
_global_connection_managers[socket_pool] = ConnectionManager(socket_pool)
Expand Down
4 changes: 2 additions & 2 deletions tests/get_radio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument

def test_get_radio_socketpool_unsupported():
radio = mocket.MockRadio.Unsupported()
with pytest.raises(AttributeError) as context:
with pytest.raises(ValueError) as context:
adafruit_connection_manager.get_radio_socketpool(radio)
assert "Unsupported radio class" in str(context)

Expand Down Expand Up @@ -100,7 +100,7 @@ def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument

def test_get_radio_ssl_context_unsupported():
radio = mocket.MockRadio.Unsupported()
with pytest.raises(AttributeError) as context:
with pytest.raises(ValueError) as context:
adafruit_connection_manager.get_radio_ssl_context(radio)
assert "Unsupported radio class" in str(context)

Expand Down
20 changes: 7 additions & 13 deletions tests/get_socket_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_get_socket_not_flagged_free():
# get a socket for the same host, should be a different one
with pytest.raises(RuntimeError) as context:
socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert "Socket already connected" in str(context)
assert "An existing socket is already connected" in str(context)


def test_get_socket_os_error():
Expand All @@ -105,9 +105,8 @@ def test_get_socket_os_error():
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)

# try to get a socket that returns a OSError
with pytest.raises(RuntimeError) as context:
with pytest.raises(OSError):
connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert "Error connecting socket: OSError" in str(context)


def test_get_socket_runtime_error():
Expand All @@ -121,9 +120,8 @@ def test_get_socket_runtime_error():
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)

# try to get a socket that returns a RuntimeError
with pytest.raises(RuntimeError) as context:
with pytest.raises(RuntimeError):
connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert "Error connecting socket: RuntimeError" in str(context)


def test_get_socket_connect_memory_error():
Expand All @@ -139,9 +137,8 @@ def test_get_socket_connect_memory_error():
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)

# try to connect a socket that returns a MemoryError
with pytest.raises(RuntimeError) as context:
with pytest.raises(MemoryError):
connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert "Error connecting socket: MemoryError" in str(context)


def test_get_socket_connect_os_error():
Expand All @@ -157,9 +154,8 @@ def test_get_socket_connect_os_error():
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)

# try to connect a socket that returns a OSError
with pytest.raises(RuntimeError) as context:
with pytest.raises(OSError):
connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert "Error connecting socket: OSError" in str(context)


def test_get_socket_runtime_error_ties_again_at_least_one_free():
Expand Down Expand Up @@ -211,9 +207,8 @@ def test_get_socket_runtime_error_ties_again_only_once():
free_sockets_mock.assert_not_called()

# try to get a socket that returns a RuntimeError twice
with pytest.raises(RuntimeError) as context:
with pytest.raises(RuntimeError):
connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:")
assert "Error connecting socket: error 2, first error: error 1" in str(context)
free_sockets_mock.assert_called_once()


Expand Down Expand Up @@ -248,8 +243,7 @@ def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument
ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio)
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)

with pytest.raises(RuntimeError) as context:
with pytest.raises(OSError):
connection_manager.get_socket(
mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context
)
assert "Error connecting socket: [Errno 12] RuntimeError" in str(context)
4 changes: 2 additions & 2 deletions tests/protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def test_get_https_no_ssl():
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)

# verify not sending in a SSL context for a HTTPS call errors
with pytest.raises(AttributeError) as context:
with pytest.raises(ValueError) as context:
connection_manager.get_socket(mocket.MOCK_HOST_1, 443, "https:")
assert "ssl_context must be set" in str(context)
assert "ssl_context must be provided if using ssl" in str(context)


def test_connect_https():
Expand Down
2 changes: 1 addition & 1 deletion tests/ssl_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_connect_wiznet5k_https_not_supported( # pylint: disable=unused-argumen
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)

# verify a HTTPS call for a board without built in WiFi and SSL support errors
with pytest.raises(AttributeError) as context:
with pytest.raises(ValueError) as context:
connection_manager.get_socket(
mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context
)
Expand Down
Loading