Skip to content

Commit b714fba

Browse files
authored
Merge pull request #153 from vladak/back_off_tests
add basic back-off test
2 parents 988332f + f60ff0c commit b714fba

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,6 @@ _build
4646
.idea
4747
.vscode
4848
*~
49+
50+
# tox local cache
51+
.tox

tests/backoff_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# SPDX-FileCopyrightText: 2023 Vladimír Kotal
2+
#
3+
# SPDX-License-Identifier: Unlicense
4+
5+
"""exponential back-off tests"""
6+
7+
import socket
8+
import ssl
9+
import time
10+
from unittest import TestCase, main
11+
from unittest.mock import call, patch
12+
13+
import adafruit_minimqtt.adafruit_minimqtt as MQTT
14+
15+
16+
class ExpBackOff(TestCase):
17+
"""basic exponential back-off test"""
18+
19+
connect_times = []
20+
21+
# pylint: disable=unused-argument
22+
def fake_connect(self, arg):
23+
"""connect() replacement that records the call times and always raises OSError"""
24+
self.connect_times.append(time.monotonic())
25+
raise OSError("this connect failed")
26+
27+
def test_failing_connect(self) -> None:
28+
"""test that exponential back-off is used when connect() always raises OSError"""
29+
# use RFC 1918 address to avoid dealing with IPv6 in the call list below
30+
host = "172.40.0.3"
31+
port = 1883
32+
33+
with patch.object(socket.socket, "connect") as mock_method:
34+
mock_method.side_effect = self.fake_connect
35+
36+
connect_retries = 3
37+
mqtt_client = MQTT.MQTT(
38+
broker=host,
39+
port=port,
40+
socket_pool=socket,
41+
ssl_context=ssl.create_default_context(),
42+
connect_retries=connect_retries,
43+
)
44+
print("connecting")
45+
with self.assertRaises(MQTT.MMQTTException) as context:
46+
mqtt_client.connect()
47+
self.assertTrue("Repeated connect failures" in str(context.exception))
48+
49+
mock_method.assert_called()
50+
calls = [call((host, port)) for _ in range(0, connect_retries)]
51+
mock_method.assert_has_calls(calls)
52+
53+
print(f"connect() call times: {self.connect_times}")
54+
for i in range(1, connect_retries):
55+
assert self.connect_times[i] >= 2**i
56+
57+
58+
if __name__ == "__main__":
59+
main()

0 commit comments

Comments
 (0)