diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 600ad245..e74d3910 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -226,7 +226,7 @@ def __init__( self._is_connected = False self._msg_size_lim = MQTT_MSG_SZ_LIM self._pid = 0 - self._timestamp: float = 0 + self._last_msg_sent_timestamp: float = 0 self.logger = NullLogger() """An optional logging attribute that can be set with with a Logger to enable debug logging.""" @@ -640,6 +640,7 @@ def _connect( if self._username is not None: self._send_str(self._username) self._send_str(self._password) + self._last_msg_sent_timestamp = self.get_monotonic_time() self.logger.debug("Receiving CONNACK packet from broker") stamp = self.get_monotonic_time() while True: @@ -694,6 +695,7 @@ def disconnect(self) -> None: self._sock.close() self._is_connected = False self._subscribed_topics = [] + self._last_msg_sent_timestamp = 0 if self.on_disconnect is not None: self.on_disconnect(self, self.user_data, 0) @@ -707,6 +709,7 @@ def ping(self) -> list[int]: self._sock.send(MQTT_PINGREQ) ping_timeout = self.keep_alive stamp = self.get_monotonic_time() + self._last_msg_sent_timestamp = stamp rc, rcs = None, [] while rc != MQTT_PINGRESP: rc = self._wait_for_msg() @@ -781,6 +784,7 @@ def publish( self._sock.send(pub_hdr_fixed) self._sock.send(pub_hdr_var) self._sock.send(msg) + self._last_msg_sent_timestamp = self.get_monotonic_time() if qos == 0 and self.on_publish is not None: self.on_publish(self, self.user_data, topic, self._pid) if qos == 1: @@ -858,6 +862,7 @@ def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> N self.logger.debug(f"payload: {payload}") self._sock.send(payload) stamp = self.get_monotonic_time() + self._last_msg_sent_timestamp = stamp while True: op = self._wait_for_msg() if op is None: @@ -933,6 +938,7 @@ def unsubscribe(self, topic: Optional[Union[str, list]]) -> None: for t in topics: self.logger.debug(f"UNSUBSCRIBING from topic {t}") self._sock.send(payload) + self._last_msg_sent_timestamp = self.get_monotonic_time() self.logger.debug("Waiting for UNSUBACK...") while True: stamp = self.get_monotonic_time() @@ -1022,7 +1028,6 @@ def reconnect(self, resub_topics: bool = True) -> int: return ret def loop(self, timeout: float = 0) -> Optional[list[int]]: - # pylint: disable = too-many-return-statements """Non-blocking message loop. Use this method to check for incoming messages. Returns list of packet types of any messages received or None. @@ -1038,23 +1043,27 @@ def loop(self, timeout: float = 0) -> Optional[list[int]]: self._connected() self.logger.debug(f"waiting for messages for {timeout} seconds") - if self._timestamp == 0: - self._timestamp = self.get_monotonic_time() - current_time = self.get_monotonic_time() - if current_time - self._timestamp >= self.keep_alive: - self._timestamp = 0 - # Handle KeepAlive by expecting a PINGREQ/PINGRESP from the server - self.logger.debug( - "KeepAlive period elapsed - requesting a PINGRESP from the server..." - ) - rcs = self.ping() - return rcs stamp = self.get_monotonic_time() rcs = [] while True: - rc = self._wait_for_msg(timeout=timeout) + if ( + self.get_monotonic_time() - self._last_msg_sent_timestamp + >= self.keep_alive + ): + # Handle KeepAlive by expecting a PINGREQ/PINGRESP from the server + self.logger.debug( + "KeepAlive period elapsed - requesting a PINGRESP from the server..." + ) + rcs.extend(self.ping()) + # ping() itself contains a _wait_for_msg() loop which might have taken a while, + # so check here as well. + if self.get_monotonic_time() - stamp > timeout: + self.logger.debug(f"Loop timed out after {timeout} seconds") + break + + rc = self._wait_for_msg() if rc is not None: rcs.append(rc) if self.get_monotonic_time() - stamp > timeout: diff --git a/tests/test_loop.py b/tests/test_loop.py index 4728ac11..173fe526 100644 --- a/tests/test_loop.py +++ b/tests/test_loop.py @@ -8,12 +8,99 @@ import socket import ssl import time +import errno + from unittest import TestCase, main from unittest.mock import patch +from unittest import mock import adafruit_minimqtt.adafruit_minimqtt as MQTT +class Nulltet: + """ + Mock Socket that does nothing. + + Inspired by the Mocket class from Adafruit_CircuitPython_Requests + """ + + def __init__(self): + self.sent = bytearray() + + self.timeout = mock.Mock() + self.connect = mock.Mock() + self.close = mock.Mock() + + def send(self, bytes_to_send): + """ + Record the bytes. return the length of this bytearray. + """ + self.sent.extend(bytes_to_send) + return len(bytes_to_send) + + # MiniMQTT checks for the presence of "recv_into" and switches behavior based on that. + # pylint: disable=unused-argument,no-self-use + def recv_into(self, retbuf, bufsize): + """Always raise timeout exception.""" + exc = OSError() + exc.errno = errno.ETIMEDOUT + raise exc + + +class Pingtet: + """ + Mock Socket tailored for PINGREQ testing. + Records sent data, hands out PINGRESP for each PINGREQ received. + + Inspired by the Mocket class from Adafruit_CircuitPython_Requests + """ + + PINGRESP = bytearray([0xD0, 0x00]) + + def __init__(self): + self._to_send = self.PINGRESP + + self.sent = bytearray() + + self.timeout = mock.Mock() + self.connect = mock.Mock() + self.close = mock.Mock() + + self._got_pingreq = False + + def send(self, bytes_to_send): + """ + Recognize PINGREQ and record the indication that it was received. + Assumes it was sent in one chunk (of 2 bytes). + Also record the bytes. return the length of this bytearray. + """ + self.sent.extend(bytes_to_send) + if bytes_to_send == b"\xc0\0": + self._got_pingreq = True + return len(bytes_to_send) + + # MiniMQTT checks for the presence of "recv_into" and switches behavior based on that. + def recv_into(self, retbuf, bufsize): + """ + If the PINGREQ indication is on, return PINGRESP, otherwise raise timeout exception. + """ + if self._got_pingreq: + size = min(bufsize, len(self._to_send)) + if size == 0: + return size + chop = self._to_send[0:size] + retbuf[0:] = chop + self._to_send = self._to_send[size:] + if len(self._to_send) == 0: + self._got_pingreq = False + self._to_send = self.PINGRESP + return size + + exc = OSError() + exc.errno = errno.ETIMEDOUT + raise exc + + class Loop(TestCase): """basic loop() test""" @@ -54,6 +141,8 @@ def test_loop_basic(self) -> None: time_before = time.monotonic() timeout = random.randint(3, 8) + # pylint: disable=protected-access + mqtt_client._last_msg_sent_timestamp = mqtt_client.get_monotonic_time() rcs = mqtt_client.loop(timeout=timeout) time_after = time.monotonic() @@ -64,6 +153,7 @@ def test_loop_basic(self) -> None: assert rcs is not None assert len(rcs) >= 1 expected_rc = self.INITIAL_RCS_VAL + # pylint: disable=not-an-iterable for ret_code in rcs: assert ret_code == expected_rc expected_rc += 1 @@ -104,6 +194,71 @@ def test_loop_is_connected(self): assert "not connected" in str(context.exception) + # pylint: disable=no-self-use + def test_loop_ping_timeout(self): + """Verify that ping will be sent even with loop timeout bigger than keep alive timeout + and no outgoing messages are sent.""" + + recv_timeout = 2 + keep_alive_timeout = recv_timeout * 2 + mqtt_client = MQTT.MQTT( + broker="localhost", + port=1883, + ssl_context=ssl.create_default_context(), + connect_retries=1, + socket_timeout=1, + recv_timeout=recv_timeout, + keep_alive=keep_alive_timeout, + ) + + # patch is_connected() to avoid CONNECT/CONNACK handling. + mqtt_client.is_connected = lambda: True + mocket = Pingtet() + # pylint: disable=protected-access + mqtt_client._sock = mocket + + start = time.monotonic() + res = mqtt_client.loop(timeout=2 * keep_alive_timeout) + assert time.monotonic() - start >= 2 * keep_alive_timeout + assert len(mocket.sent) > 0 + assert len(res) == 2 + assert set(res) == {int(0xD0)} + + # pylint: disable=no-self-use + def test_loop_ping_vs_msgs_sent(self): + """Verify that ping will not be sent unnecessarily.""" + + recv_timeout = 2 + keep_alive_timeout = recv_timeout * 2 + mqtt_client = MQTT.MQTT( + broker="localhost", + port=1883, + ssl_context=ssl.create_default_context(), + connect_retries=1, + socket_timeout=1, + recv_timeout=recv_timeout, + keep_alive=keep_alive_timeout, + ) + + # patch is_connected() to avoid CONNECT/CONNACK handling. + mqtt_client.is_connected = lambda: True + + # With QoS=0 no PUBACK message is sent, so Nulltet can be used. + mocket = Nulltet() + # pylint: disable=protected-access + mqtt_client._sock = mocket + + i = 0 + topic = "foo" + message = "bar" + for _ in range(3 * keep_alive_timeout): + mqtt_client.publish(topic, message, qos=0) + mqtt_client.loop(1) + i += 1 + + # This means no other messages than the PUBLISH messages generated by the code above. + assert len(mocket.sent) == i * (2 + 2 + len(topic) + len(message)) + if __name__ == "__main__": main()