|
| 1 | +# SPDX-FileCopyrightText: 2023 Vladimír Kotal |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Unlicense |
| 4 | + |
| 5 | +"""exponential back-off tests""" |
| 6 | + |
| 7 | +import socket |
| 8 | +import ssl |
| 9 | +import time |
| 10 | +from unittest import TestCase, main |
| 11 | +from unittest.mock import call, patch |
| 12 | + |
| 13 | +import adafruit_minimqtt.adafruit_minimqtt as MQTT |
| 14 | + |
| 15 | + |
| 16 | +class ExpBackOff(TestCase): |
| 17 | + """basic exponential back-off test""" |
| 18 | + |
| 19 | + connect_times = [] |
| 20 | + |
| 21 | + # pylint: disable=unused-argument |
| 22 | + def fake_connect(self, arg): |
| 23 | + """connect() replacement that records the call times and always raises OSError""" |
| 24 | + self.connect_times.append(time.monotonic()) |
| 25 | + raise OSError("this connect failed") |
| 26 | + |
| 27 | + def test_failing_connect(self) -> None: |
| 28 | + """test that exponential back-off is used when connect() always raises OSError""" |
| 29 | + # use RFC 1918 address to avoid dealing with IPv6 in the call list below |
| 30 | + host = "172.40.0.3" |
| 31 | + port = 1883 |
| 32 | + |
| 33 | + with patch.object(socket.socket, "connect") as mock_method: |
| 34 | + mock_method.side_effect = self.fake_connect |
| 35 | + |
| 36 | + connect_retries = 3 |
| 37 | + mqtt_client = MQTT.MQTT( |
| 38 | + broker=host, |
| 39 | + port=port, |
| 40 | + socket_pool=socket, |
| 41 | + ssl_context=ssl.create_default_context(), |
| 42 | + connect_retries=connect_retries, |
| 43 | + ) |
| 44 | + print("connecting") |
| 45 | + with self.assertRaises(MQTT.MMQTTException) as context: |
| 46 | + mqtt_client.connect() |
| 47 | + self.assertTrue("Repeated connect failures" in str(context.exception)) |
| 48 | + |
| 49 | + mock_method.assert_called() |
| 50 | + calls = [call((host, port)) for _ in range(0, connect_retries)] |
| 51 | + mock_method.assert_has_calls(calls) |
| 52 | + |
| 53 | + print(f"connect() call times: {self.connect_times}") |
| 54 | + for i in range(1, connect_retries): |
| 55 | + assert self.connect_times[i] >= 2**i |
| 56 | + |
| 57 | + |
| 58 | +if __name__ == "__main__": |
| 59 | + main() |
0 commit comments