diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 2d74dd92..519aae2f 100644 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -171,7 +171,7 @@ def __init__( username=None, password=None, client_id=None, - is_ssl=True, + is_ssl=None, keep_alive=60, recv_timeout=10, socket_pool=None, @@ -220,13 +220,19 @@ def __init__( ): # [MQTT-3.1.3.5] raise MMQTTException("Password length is too large.") + # The connection will be insecure unless is_ssl is set to True. + # If the port is not specified, the security will be set based on the is_ssl parameter. + # If the port is specified, the is_ssl parameter will be honored. self.port = MQTT_TCP_PORT - if is_ssl: + if is_ssl is None: + is_ssl = False + self._is_ssl = is_ssl + if self._is_ssl: self.port = MQTT_TLS_PORT if port: self.port = port - # define client identifer + # define client identifier if client_id: # user-defined client_id MAY allow client_id's > 23 bytes or # non-alpha-numeric characters @@ -282,12 +288,12 @@ def _get_connect_socket(self, host, port, *, timeout=1): if not isinstance(port, int): raise RuntimeError("Port must be an integer") - if port == MQTT_TLS_PORT and not self._ssl_context: + if self._is_ssl and not self._ssl_context: raise RuntimeError( "ssl_context must be set before using adafruit_mqtt for secure MQTT." ) - if port == MQTT_TLS_PORT: + if self._is_ssl: self.logger.info(f"Establishing a SECURE SSL connection to {host}:{port}") else: self.logger.info(f"Establishing an INSECURE connection to {host}:{port}") @@ -306,7 +312,7 @@ def _get_connect_socket(self, host, port, *, timeout=1): raise TemporaryError from exc connect_host = addr_info[-1][0] - if port == MQTT_TLS_PORT: + if self._is_ssl: sock = self._ssl_context.wrap_socket(sock, server_hostname=host) connect_host = host sock.settimeout(timeout) diff --git a/examples/cpython/minimqtt_adafruitio_cpython.py b/examples/cpython/minimqtt_adafruitio_cpython.py index 7eb4f5fb..3667329b 100644 --- a/examples/cpython/minimqtt_adafruitio_cpython.py +++ b/examples/cpython/minimqtt_adafruitio_cpython.py @@ -1,8 +1,10 @@ # SPDX-FileCopyrightText: 2021 ladyada for Adafruit Industries # SPDX-License-Identifier: MIT -import time import socket +import ssl +import time + import adafruit_minimqtt.adafruit_minimqtt as MQTT ### Secrets File Setup ### @@ -46,11 +48,12 @@ def message(client, topic, message): # Set up a MiniMQTT Client mqtt_client = MQTT.MQTT( - broker=secrets["broker"], - port=1883, + broker="io.adafruit.com", username=secrets["aio_username"], password=secrets["aio_key"], socket_pool=socket, + is_ssl=True, + ssl_context=ssl.create_default_context(), ) # Setup the callback methods above diff --git a/examples/ethernet/minimqtt_simpletest_eth.py b/examples/ethernet/minimqtt_simpletest_eth.py index c585cf78..3e91c8fc 100644 --- a/examples/ethernet/minimqtt_simpletest_eth.py +++ b/examples/ethernet/minimqtt_simpletest_eth.py @@ -67,8 +67,12 @@ def publish(client, userdata, topic, pid): MQTT.set_socket(socket, eth) # Set up a MiniMQTT Client +# NOTE: We'll need to connect insecurely for ethernet configurations. client = MQTT.MQTT( - broker=secrets["broker"], username=secrets["user"], password=secrets["pass"] + broker=secrets["broker"], + username=secrets["user"], + password=secrets["pass"], + is_ssl=False, ) # Connect callback handlers to client diff --git a/examples/native_networking/minimqtt_adafruitio_native_networking.py b/examples/native_networking/minimqtt_adafruitio_native_networking.py index a4f5ecaa..21661d31 100644 --- a/examples/native_networking/minimqtt_adafruitio_native_networking.py +++ b/examples/native_networking/minimqtt_adafruitio_native_networking.py @@ -62,7 +62,7 @@ def message(client, topic, message): # Set up a MiniMQTT Client mqtt_client = MQTT.MQTT( - broker=secrets["broker"], + broker="io.adafruit.com", port=secrets["port"], username=secrets["aio_username"], password=secrets["aio_key"], diff --git a/tests/test_port_ssl.py b/tests/test_port_ssl.py new file mode 100644 index 00000000..8474b56d --- /dev/null +++ b/tests/test_port_ssl.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: 2023 Vladimír Kotal +# +# SPDX-License-Identifier: Unlicense + +"""tests that verify the connect behavior w.r.t. port number and TLS""" + +import socket +import ssl +from unittest import TestCase, main +from unittest.mock import Mock, call, patch + +import adafruit_minimqtt.adafruit_minimqtt as MQTT + + +class PortSslSetup(TestCase): + """This class contains tests that verify how host/port and TLS is set for connect(). + These tests assume that there is no MQTT broker running on the hosts/ports they connect to. + """ + + def test_default_port(self) -> None: + """verify default port value and that TLS is not used""" + host = "127.0.0.1" + port = 1883 + + with patch.object(socket.socket, "connect") as connect_mock: + ssl_context = ssl.create_default_context() + mqtt_client = MQTT.MQTT( + broker=host, + socket_pool=socket, + ssl_context=ssl_context, + connect_retries=1, + ) + + ssl_mock = Mock() + ssl_context.wrap_socket = ssl_mock + + with self.assertRaises(MQTT.MMQTTException): + expected_port = port + mqtt_client.connect() + + ssl_mock.assert_not_called() + connect_mock.assert_called() + # Assuming the repeated calls will have the same arguments. + connect_mock.assert_has_calls([call((host, expected_port))]) + + def test_connect_override(self): + """Test that connect() can override host and port.""" + host = "127.0.0.1" + port = 1883 + + with patch.object(socket.socket, "connect") as connect_mock: + connect_mock.side_effect = OSError("artificial error") + mqtt_client = MQTT.MQTT( + broker=host, + port=port, + socket_pool=socket, + connect_retries=1, + ) + + with self.assertRaises(MQTT.MMQTTException): + expected_host = "127.0.0.2" + expected_port = 1884 + self.assertNotEqual(expected_port, port, "port override should differ") + self.assertNotEqual(expected_host, host, "host override should differ") + mqtt_client.connect(host=expected_host, port=expected_port) + + connect_mock.assert_called() + # Assuming the repeated calls will have the same arguments. + connect_mock.assert_has_calls([call((expected_host, expected_port))]) + + def test_tls_port(self) -> None: + """verify that when is_ssl=True is set, the default port is 8883 + and the socket is TLS wrapped. Also test that the TLS port can be overridden.""" + host = "127.0.0.1" + + for port in [None, 8884]: + if port is None: + expected_port = 8883 + else: + expected_port = port + with self.subTest(): + ssl_mock = Mock() + mqtt_client = MQTT.MQTT( + broker=host, + port=port, + socket_pool=socket, + is_ssl=True, + ssl_context=ssl_mock, + connect_retries=1, + ) + + socket_mock = Mock() + connect_mock = Mock(side_effect=OSError) + socket_mock.connect = connect_mock + ssl_mock.wrap_socket = Mock(return_value=socket_mock) + + with self.assertRaises(MQTT.MMQTTException): + mqtt_client.connect() + + ssl_mock.wrap_socket.assert_called() + + connect_mock.assert_called() + # Assuming the repeated calls will have the same arguments. + connect_mock.assert_has_calls([call((host, expected_port))]) + + def test_tls_without_ssl_context(self) -> None: + """verify that when is_ssl=True is set, the code will check that ssl_context is not None""" + host = "127.0.0.1" + + mqtt_client = MQTT.MQTT( + broker=host, + socket_pool=socket, + is_ssl=True, + ssl_context=None, + connect_retries=1, + ) + + with self.assertRaises(RuntimeError) as context: + mqtt_client.connect() + self.assertTrue("ssl_context must be set" in str(context)) + + +if __name__ == "__main__": + main() diff --git a/tox.ini b/tox.ini new file mode 100644 index 00000000..6a9584b3 --- /dev/null +++ b/tox.ini @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: 2023 Vladimír Kotal +# +# SPDX-License-Identifier: MIT + +[tox] +envlist = py39 + +[testenv] +changedir = {toxinidir}/tests +deps = pytest==6.2.5 +commands = pytest -v