Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 92 additions & 9 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@

from .matcher import MQTTMatcher

try:
import select
except ImportError:
select = None

__version__ = "0.0.0+auto.0"
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_MiniMQTT.git"

Expand Down Expand Up @@ -93,6 +98,13 @@
_default_sock = None
_fake_context = None

_SELECT_POLLIN = getattr(select, "POLLIN", 1) if select else 1
_SELECT_POLL_ERROR_FLAGS = 0
if select:
_SELECT_POLL_ERROR_FLAGS |= getattr(select, "POLLERR", 0)
_SELECT_POLL_ERROR_FLAGS |= getattr(select, "POLLHUP", 0)
_SELECT_POLL_ERROR_FLAGS |= getattr(select, "POLLNVAL", 0)


class MMQTTException(Exception):
"""
Expand Down Expand Up @@ -174,6 +186,8 @@ def __init__( # noqa: PLR0915, PLR0913, Too many statements, Too many arguments
self._socket_pool = socket_pool
self._ssl_context = ssl_context
self._sock = None
self._socket_poller = None
self._poll_socket = None
self._backwards_compatible_sock = False
self._use_binary_mode = use_binary_mode

Expand Down Expand Up @@ -546,6 +560,7 @@ def _connect( # noqa: PLR0912, PLR0913, PLR0915, Too many branches, Too many ar
)
self.session_id = session_id
self._backwards_compatible_sock = not hasattr(self._sock, "recv_into")
self._open_socket_poller()

fixed_header = bytearray([0x10])

Expand Down Expand Up @@ -611,9 +626,52 @@ def _connect( # noqa: PLR0912, PLR0913, PLR0915, Too many branches, Too many ar
def _close_socket(self):
if self._sock:
self.logger.debug("Closing socket")
self._close_socket_poller()
self._connection_manager.close_socket(self._sock)
self._sock = None

def _open_socket_poller(self) -> None:
"""Create and register a socket poller, if this socket supports select.poll."""
self._socket_poller = None
self._poll_socket = None
if not select or not hasattr(select, "poll"):
return

try:
poller = select.poll()
poller.register(self._sock, _SELECT_POLLIN)
self._socket_poller = poller
# Prefer ipoll when available; real hardware measurements showed near-zero
# allocation and faster idle polls.
self._poll_socket = getattr(poller, "ipoll", None) or poller.poll
except (OSError, TypeError):
# This socket cannot be registered with select.poll(); use timeout-based reads.
pass

def _close_socket_poller(self) -> None:
"""Release socket poll resources."""
poller = self._socket_poller
self._socket_poller = None
self._poll_socket = None
if poller is None:
return

poller.unregister(self._sock)

def _socket_poller_readable(self, timeout_ms: Optional[int]) -> bool:
"""Return whether the socket poller has a readable socket event."""
poll_socket = self._poll_socket
if poll_socket is None:
raise RuntimeError("Socket poller is unavailable")

poll_timeout = -1 if timeout_ms is None else max(0, timeout_ms)
events = poll_socket(poll_timeout)
for event_data in events:
event = event_data[1]
if event & (_SELECT_POLL_ERROR_FLAGS | _SELECT_POLLIN):
return True
return False

def _encode_remaining_length(self, fixed_header: bytearray, remaining_length: int) -> None:
"""Encode Remaining Length [2.2.3]"""
if remaining_length > 268_435_455:
Expand Down Expand Up @@ -980,38 +1038,63 @@ def loop(self, timeout: float = 1.0) -> Optional[list[int]]:
"""Non-blocking message loop. Use this method to check for incoming messages.
Returns list of packet types of any messages received or None.

:param float timeout: return after this timeout, in seconds.
:param float timeout: return if no message is received before this timeout,
in seconds. Reading a message that has begun may take longer.

"""
if timeout < self._socket_timeout:
raise ValueError(
f"loop timeout ({timeout}) must be >= "
+ f"socket timeout ({self._socket_timeout}))"
)
if timeout < 0:
raise ValueError("loop timeout must be >= 0")

self._connected()
self.logger.debug(f"waiting for messages for {timeout} seconds")

timeout_ms = int(timeout * 1000)
keep_alive_ms = self.keep_alive * 1000
socket_timeout_ms = int(self._socket_timeout * 1000)
socket_poller = self._socket_poller
if socket_poller is None and timeout_ms < socket_timeout_ms:
message = f"loop timeout ({timeout}) must be >= socket timeout "
message += f"({self._socket_timeout})"
message += " without socket readiness support"
raise ValueError(message)

stamp = ticks_ms()
rcs = []

while True:
if ticks_diff(ticks_ms(), self._last_msg_sent_timestamp) / 1000 >= self.keep_alive:
if ticks_diff(ticks_ms(), self._last_msg_sent_timestamp) >= keep_alive_ms:
# Handle KeepAlive by expecting a PINGREQ/PINGRESP from the server
self.logger.debug(
"KeepAlive period elapsed - requesting a PINGRESP from the server..."
)
rcs.extend(self.ping())
# ping() itself contains a _wait_for_msg() loop which might have taken a while,
# so check here as well.
if ticks_diff(ticks_ms(), stamp) / 1000 > timeout:
if ticks_diff(ticks_ms(), stamp) > timeout_ms:
self.logger.debug(f"Loop timed out after {timeout} seconds")
break

loop_timeout_remaining_ms = max(0, timeout_ms - ticks_diff(ticks_ms(), stamp))
keep_alive_elapsed_ms = ticks_diff(ticks_ms(), self._last_msg_sent_timestamp)
keep_alive_remaining_ms = max(0, keep_alive_ms - keep_alive_elapsed_ms)
socket_wait_timeout_ms = min(loop_timeout_remaining_ms, keep_alive_remaining_ms)
if socket_poller is not None:
socket_readable = self._socket_poller_readable(socket_wait_timeout_ms)
if not socket_readable:
if ticks_diff(ticks_ms(), self._last_msg_sent_timestamp) >= keep_alive_ms:
continue
if ticks_diff(ticks_ms(), stamp) >= timeout_ms:
self.logger.debug(f"Loop timed out after {timeout} seconds")
break
continue
elif loop_timeout_remaining_ms == 0:
self.logger.debug(f"Loop timed out after {timeout} seconds")
break

rc = self._wait_for_msg()
if rc is not None:
rcs.append(rc)
if ticks_diff(ticks_ms(), stamp) / 1000 > timeout:
if ticks_diff(ticks_ms(), stamp) > timeout_ms:
self.logger.debug(f"Loop timed out after {timeout} seconds")
break

Expand Down
Loading
Loading