Skip to content

Commit bfcaea6

Browse files
authored
Merge pull request #145 from vladak/tls_port_detection
allow to use any port as TLS port
2 parents 342b8c9 + 768c046 commit bfcaea6

File tree

6 files changed

+159
-11
lines changed

6 files changed

+159
-11
lines changed

adafruit_minimqtt/adafruit_minimqtt.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def __init__(
171171
username=None,
172172
password=None,
173173
client_id=None,
174-
is_ssl=True,
174+
is_ssl=None,
175175
keep_alive=60,
176176
recv_timeout=10,
177177
socket_pool=None,
@@ -220,13 +220,19 @@ def __init__(
220220
): # [MQTT-3.1.3.5]
221221
raise MMQTTException("Password length is too large.")
222222

223+
# The connection will be insecure unless is_ssl is set to True.
224+
# If the port is not specified, the security will be set based on the is_ssl parameter.
225+
# If the port is specified, the is_ssl parameter will be honored.
223226
self.port = MQTT_TCP_PORT
224-
if is_ssl:
227+
if is_ssl is None:
228+
is_ssl = False
229+
self._is_ssl = is_ssl
230+
if self._is_ssl:
225231
self.port = MQTT_TLS_PORT
226232
if port:
227233
self.port = port
228234

229-
# define client identifer
235+
# define client identifier
230236
if client_id:
231237
# user-defined client_id MAY allow client_id's > 23 bytes or
232238
# non-alpha-numeric characters
@@ -282,12 +288,12 @@ def _get_connect_socket(self, host, port, *, timeout=1):
282288
if not isinstance(port, int):
283289
raise RuntimeError("Port must be an integer")
284290

285-
if port == MQTT_TLS_PORT and not self._ssl_context:
291+
if self._is_ssl and not self._ssl_context:
286292
raise RuntimeError(
287293
"ssl_context must be set before using adafruit_mqtt for secure MQTT."
288294
)
289295

290-
if port == MQTT_TLS_PORT:
296+
if self._is_ssl:
291297
self.logger.info(f"Establishing a SECURE SSL connection to {host}:{port}")
292298
else:
293299
self.logger.info(f"Establishing an INSECURE connection to {host}:{port}")
@@ -306,7 +312,7 @@ def _get_connect_socket(self, host, port, *, timeout=1):
306312
raise TemporaryError from exc
307313

308314
connect_host = addr_info[-1][0]
309-
if port == MQTT_TLS_PORT:
315+
if self._is_ssl:
310316
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
311317
connect_host = host
312318
sock.settimeout(timeout)

examples/cpython/minimqtt_adafruitio_cpython.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# SPDX-FileCopyrightText: 2021 ladyada for Adafruit Industries
22
# SPDX-License-Identifier: MIT
33

4-
import time
54
import socket
5+
import ssl
6+
import time
7+
68
import adafruit_minimqtt.adafruit_minimqtt as MQTT
79

810
### Secrets File Setup ###
@@ -46,11 +48,12 @@ def message(client, topic, message):
4648

4749
# Set up a MiniMQTT Client
4850
mqtt_client = MQTT.MQTT(
49-
broker=secrets["broker"],
50-
port=1883,
51+
broker="io.adafruit.com",
5152
username=secrets["aio_username"],
5253
password=secrets["aio_key"],
5354
socket_pool=socket,
55+
is_ssl=True,
56+
ssl_context=ssl.create_default_context(),
5457
)
5558

5659
# Setup the callback methods above

examples/ethernet/minimqtt_simpletest_eth.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,12 @@ def publish(client, userdata, topic, pid):
6767
MQTT.set_socket(socket, eth)
6868

6969
# Set up a MiniMQTT Client
70+
# NOTE: We'll need to connect insecurely for ethernet configurations.
7071
client = MQTT.MQTT(
71-
broker=secrets["broker"], username=secrets["user"], password=secrets["pass"]
72+
broker=secrets["broker"],
73+
username=secrets["user"],
74+
password=secrets["pass"],
75+
is_ssl=False,
7276
)
7377

7478
# Connect callback handlers to client

examples/native_networking/minimqtt_adafruitio_native_networking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def message(client, topic, message):
6262

6363
# Set up a MiniMQTT Client
6464
mqtt_client = MQTT.MQTT(
65-
broker=secrets["broker"],
65+
broker="io.adafruit.com",
6666
port=secrets["port"],
6767
username=secrets["aio_username"],
6868
password=secrets["aio_key"],

tests/test_port_ssl.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# SPDX-FileCopyrightText: 2023 Vladimír Kotal
2+
#
3+
# SPDX-License-Identifier: Unlicense
4+
5+
"""tests that verify the connect behavior w.r.t. port number and TLS"""
6+
7+
import socket
8+
import ssl
9+
from unittest import TestCase, main
10+
from unittest.mock import Mock, call, patch
11+
12+
import adafruit_minimqtt.adafruit_minimqtt as MQTT
13+
14+
15+
class PortSslSetup(TestCase):
16+
"""This class contains tests that verify how host/port and TLS is set for connect().
17+
These tests assume that there is no MQTT broker running on the hosts/ports they connect to.
18+
"""
19+
20+
def test_default_port(self) -> None:
21+
"""verify default port value and that TLS is not used"""
22+
host = "127.0.0.1"
23+
port = 1883
24+
25+
with patch.object(socket.socket, "connect") as connect_mock:
26+
ssl_context = ssl.create_default_context()
27+
mqtt_client = MQTT.MQTT(
28+
broker=host,
29+
socket_pool=socket,
30+
ssl_context=ssl_context,
31+
connect_retries=1,
32+
)
33+
34+
ssl_mock = Mock()
35+
ssl_context.wrap_socket = ssl_mock
36+
37+
with self.assertRaises(MQTT.MMQTTException):
38+
expected_port = port
39+
mqtt_client.connect()
40+
41+
ssl_mock.assert_not_called()
42+
connect_mock.assert_called()
43+
# Assuming the repeated calls will have the same arguments.
44+
connect_mock.assert_has_calls([call((host, expected_port))])
45+
46+
def test_connect_override(self):
47+
"""Test that connect() can override host and port."""
48+
host = "127.0.0.1"
49+
port = 1883
50+
51+
with patch.object(socket.socket, "connect") as connect_mock:
52+
connect_mock.side_effect = OSError("artificial error")
53+
mqtt_client = MQTT.MQTT(
54+
broker=host,
55+
port=port,
56+
socket_pool=socket,
57+
connect_retries=1,
58+
)
59+
60+
with self.assertRaises(MQTT.MMQTTException):
61+
expected_host = "127.0.0.2"
62+
expected_port = 1884
63+
self.assertNotEqual(expected_port, port, "port override should differ")
64+
self.assertNotEqual(expected_host, host, "host override should differ")
65+
mqtt_client.connect(host=expected_host, port=expected_port)
66+
67+
connect_mock.assert_called()
68+
# Assuming the repeated calls will have the same arguments.
69+
connect_mock.assert_has_calls([call((expected_host, expected_port))])
70+
71+
def test_tls_port(self) -> None:
72+
"""verify that when is_ssl=True is set, the default port is 8883
73+
and the socket is TLS wrapped. Also test that the TLS port can be overridden."""
74+
host = "127.0.0.1"
75+
76+
for port in [None, 8884]:
77+
if port is None:
78+
expected_port = 8883
79+
else:
80+
expected_port = port
81+
with self.subTest():
82+
ssl_mock = Mock()
83+
mqtt_client = MQTT.MQTT(
84+
broker=host,
85+
port=port,
86+
socket_pool=socket,
87+
is_ssl=True,
88+
ssl_context=ssl_mock,
89+
connect_retries=1,
90+
)
91+
92+
socket_mock = Mock()
93+
connect_mock = Mock(side_effect=OSError)
94+
socket_mock.connect = connect_mock
95+
ssl_mock.wrap_socket = Mock(return_value=socket_mock)
96+
97+
with self.assertRaises(MQTT.MMQTTException):
98+
mqtt_client.connect()
99+
100+
ssl_mock.wrap_socket.assert_called()
101+
102+
connect_mock.assert_called()
103+
# Assuming the repeated calls will have the same arguments.
104+
connect_mock.assert_has_calls([call((host, expected_port))])
105+
106+
def test_tls_without_ssl_context(self) -> None:
107+
"""verify that when is_ssl=True is set, the code will check that ssl_context is not None"""
108+
host = "127.0.0.1"
109+
110+
mqtt_client = MQTT.MQTT(
111+
broker=host,
112+
socket_pool=socket,
113+
is_ssl=True,
114+
ssl_context=None,
115+
connect_retries=1,
116+
)
117+
118+
with self.assertRaises(RuntimeError) as context:
119+
mqtt_client.connect()
120+
self.assertTrue("ssl_context must be set" in str(context))
121+
122+
123+
if __name__ == "__main__":
124+
main()

tox.ini

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# SPDX-FileCopyrightText: 2023 Vladimír Kotal
2+
#
3+
# SPDX-License-Identifier: MIT
4+
5+
[tox]
6+
envlist = py39
7+
8+
[testenv]
9+
changedir = {toxinidir}/tests
10+
deps = pytest==6.2.5
11+
commands = pytest -v

0 commit comments

Comments
 (0)