Skip to content

Commit 70faa4f

Browse files
authored
Merge pull request #187 from vladak/subscribe_vs_remaining_len
encode/decode remaining length properly for {,UN}SUBSCRIBE/SUBACK
2 parents 4a52082 + 279387e commit 70faa4f

File tree

4 files changed

+499
-64
lines changed

4 files changed

+499
-64
lines changed

adafruit_minimqtt/adafruit_minimqtt.py

Lines changed: 77 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@
6060
MQTT_PINGREQ = b"\xc0\0"
6161
MQTT_PINGRESP = const(0xD0)
6262
MQTT_PUBLISH = const(0x30)
63-
MQTT_SUB = b"\x82"
64-
MQTT_UNSUB = b"\xA2"
63+
MQTT_SUB = const(0x82)
64+
MQTT_UNSUB = const(0xA2)
6565
MQTT_DISCONNECT = b"\xe0\0"
6666

6767
MQTT_PKT_TYPE_MASK = const(0xF0)
@@ -597,13 +597,12 @@ def _connect(
597597
self.broker, self.port, timeout=self._socket_timeout
598598
)
599599

600-
# Fixed Header
601600
fixed_header = bytearray([0x10])
602601

603602
# Variable CONNECT header [MQTT 3.1.2]
604603
# The byte array is used as a template.
605-
var_header = bytearray(b"\x04MQTT\x04\x02\0\0")
606-
var_header[6] = clean_session << 1
604+
var_header = bytearray(b"\x00\x04MQTT\x04\x02\0\0")
605+
var_header[7] = clean_session << 1
607606

608607
# Set up variable header and remaining_length
609608
remaining_length = 12 + len(self.client_id.encode("utf-8"))
@@ -614,36 +613,19 @@ def _connect(
614613
+ 2
615614
+ len(self._password.encode("utf-8"))
616615
)
617-
var_header[6] |= 0xC0
616+
var_header[7] |= 0xC0
618617
if self.keep_alive:
619618
assert self.keep_alive < MQTT_TOPIC_LENGTH_LIMIT
620-
var_header[7] |= self.keep_alive >> 8
621-
var_header[8] |= self.keep_alive & 0x00FF
619+
var_header[8] |= self.keep_alive >> 8
620+
var_header[9] |= self.keep_alive & 0x00FF
622621
if self._lw_topic:
623622
remaining_length += (
624623
2 + len(self._lw_topic.encode("utf-8")) + 2 + len(self._lw_msg)
625624
)
626-
var_header[6] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3
627-
var_header[6] |= self._lw_retain << 5
628-
629-
# Remaining length calculation
630-
large_rel_length = False
631-
if remaining_length > 0x7F:
632-
large_rel_length = True
633-
# Calculate Remaining Length [2.2.3]
634-
while remaining_length > 0:
635-
encoded_byte = remaining_length % 0x80
636-
remaining_length = remaining_length // 0x80
637-
# if there is more data to encode, set the top bit of the byte
638-
if remaining_length > 0:
639-
encoded_byte |= 0x80
640-
fixed_header.append(encoded_byte)
641-
if large_rel_length:
642-
fixed_header.append(0x00)
643-
else:
644-
fixed_header.append(remaining_length)
645-
fixed_header.append(0x00)
625+
var_header[7] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3
626+
var_header[7] |= self._lw_retain << 5
646627

628+
self._encode_remaining_length(fixed_header, remaining_length)
647629
self.logger.debug("Sending CONNECT to broker...")
648630
self.logger.debug(f"Fixed Header: {fixed_header}")
649631
self.logger.debug(f"Variable Header: {var_header}")
@@ -680,6 +662,26 @@ def _connect(
680662
f"No data received from broker for {self._recv_timeout} seconds."
681663
)
682664

665+
# pylint: disable=no-self-use
666+
def _encode_remaining_length(
667+
self, fixed_header: bytearray, remaining_length: int
668+
) -> None:
669+
"""Encode Remaining Length [2.2.3]"""
670+
if remaining_length > 268_435_455:
671+
raise MMQTTException("invalid remaining length")
672+
673+
# Remaining length calculation
674+
if remaining_length > 0x7F:
675+
while remaining_length > 0:
676+
encoded_byte = remaining_length % 0x80
677+
remaining_length = remaining_length // 0x80
678+
# if there is more data to encode, set the top bit of the byte
679+
if remaining_length > 0:
680+
encoded_byte |= 0x80
681+
fixed_header.append(encoded_byte)
682+
else:
683+
fixed_header.append(remaining_length)
684+
683685
def disconnect(self) -> None:
684686
"""Disconnects the MiniMQTT client from the MQTT broker."""
685687
self._connected()
@@ -766,16 +768,7 @@ def publish(
766768
pub_hdr_var.append(self._pid >> 8)
767769
pub_hdr_var.append(self._pid & 0xFF)
768770

769-
# Calculate remaining length [2.2.3]
770-
if remaining_length > 0x7F:
771-
while remaining_length > 0:
772-
encoded_byte = remaining_length % 0x80
773-
remaining_length = remaining_length // 0x80
774-
if remaining_length > 0:
775-
encoded_byte |= 0x80
776-
pub_hdr_fixed.append(encoded_byte)
777-
else:
778-
pub_hdr_fixed.append(remaining_length)
771+
self._encode_remaining_length(pub_hdr_fixed, remaining_length)
779772

780773
self.logger.debug(
781774
"Sending PUBLISH\nTopic: %s\nMsg: %s\
@@ -810,9 +803,9 @@ def publish(
810803
f"No data received from broker for {self._recv_timeout} seconds."
811804
)
812805

813-
def subscribe(self, topic: str, qos: int = 0) -> None:
806+
def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> None:
814807
"""Subscribes to a topic on the MQTT Broker.
815-
This method can subscribe to one topics or multiple topics.
808+
This method can subscribe to one topic or multiple topics.
816809
817810
:param str|tuple|list topic: Unique MQTT topic identifier string. If
818811
this is a `tuple`, then the tuple should
@@ -842,21 +835,28 @@ def subscribe(self, topic: str, qos: int = 0) -> None:
842835
self._valid_topic(t)
843836
topics.append((t, q))
844837
# Assemble packet
838+
self.logger.debug("Sending SUBSCRIBE to broker...")
839+
fixed_header = bytearray([MQTT_SUB])
845840
packet_length = 2 + (2 * len(topics)) + (1 * len(topics))
846841
packet_length += sum(len(topic.encode("utf-8")) for topic, qos in topics)
847-
packet_length_byte = packet_length.to_bytes(1, "big")
842+
self._encode_remaining_length(fixed_header, remaining_length=packet_length)
843+
self.logger.debug(f"Fixed Header: {fixed_header}")
844+
self._sock.send(fixed_header)
848845
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
849846
packet_id_bytes = self._pid.to_bytes(2, "big")
850-
# Packet with variable and fixed headers
851-
packet = MQTT_SUB + packet_length_byte + packet_id_bytes
847+
var_header = packet_id_bytes
848+
self.logger.debug(f"Variable Header: {var_header}")
849+
self._sock.send(var_header)
852850
# attaching topic and QOS level to the packet
851+
payload = bytes()
853852
for t, q in topics:
854853
topic_size = len(t.encode("utf-8")).to_bytes(2, "big")
855854
qos_byte = q.to_bytes(1, "big")
856-
packet += topic_size + t.encode() + qos_byte
855+
payload += topic_size + t.encode() + qos_byte
857856
for t, q in topics:
858-
self.logger.debug("SUBSCRIBING to topic %s with QoS %d", t, q)
859-
self._sock.send(packet)
857+
self.logger.debug(f"SUBSCRIBING to topic {t} with QoS {q}")
858+
self.logger.debug(f"payload: {payload}")
859+
self._sock.send(payload)
860860
stamp = self.get_monotonic_time()
861861
while True:
862862
op = self._wait_for_msg()
@@ -867,13 +867,13 @@ def subscribe(self, topic: str, qos: int = 0) -> None:
867867
)
868868
else:
869869
if op == 0x90:
870-
rc = self._sock_exact_recv(3)
871-
# Check packet identifier.
872-
assert rc[1] == packet[2] and rc[2] == packet[3]
873-
remaining_len = rc[0] - 2
870+
remaining_len = self._decode_remaining_length()
874871
assert remaining_len > 0
875-
rc = self._sock_exact_recv(remaining_len)
876-
for i in range(0, remaining_len):
872+
rc = self._sock_exact_recv(2)
873+
# Check packet identifier.
874+
assert rc[0] == var_header[0] and rc[1] == var_header[1]
875+
rc = self._sock_exact_recv(remaining_len - 2)
876+
for i in range(0, remaining_len - 2):
877877
if rc[i] not in [0, 1, 2]:
878878
raise MMQTTException(
879879
f"SUBACK Failure for topic {topics[i][0]}: {hex(rc[i])}"
@@ -883,13 +883,17 @@ def subscribe(self, topic: str, qos: int = 0) -> None:
883883
if self.on_subscribe is not None:
884884
self.on_subscribe(self, self.user_data, t, q)
885885
self._subscribed_topics.append(t)
886+
886887
return
887888

888-
raise MMQTTException(
889-
f"invalid message received as response to SUBSCRIBE: {hex(op)}"
890-
)
889+
if op != MQTT_PUBLISH:
890+
# [3.8.4] The Server is permitted to start sending PUBLISH packets
891+
# matching the Subscription before the Server sends the SUBACK Packet.
892+
raise MMQTTException(
893+
f"invalid message received as response to SUBSCRIBE: {hex(op)}"
894+
)
891895

892-
def unsubscribe(self, topic: str) -> None:
896+
def unsubscribe(self, topic: Optional[Union[str, list]]) -> None:
893897
"""Unsubscribes from a MQTT topic.
894898
895899
:param str|list topic: Unique MQTT topic identifier string or list.
@@ -910,18 +914,25 @@ def unsubscribe(self, topic: str) -> None:
910914
"Topic must be subscribed to before attempting unsubscribe."
911915
)
912916
# Assemble packet
917+
self.logger.debug("Sending UNSUBSCRIBE to broker...")
918+
fixed_header = bytearray([MQTT_UNSUB])
913919
packet_length = 2 + (2 * len(topics))
914920
packet_length += sum(len(topic.encode("utf-8")) for topic in topics)
915-
packet_length_byte = packet_length.to_bytes(1, "big")
921+
self._encode_remaining_length(fixed_header, remaining_length=packet_length)
922+
self.logger.debug(f"Fixed Header: {fixed_header}")
923+
self._sock.send(fixed_header)
916924
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
917925
packet_id_bytes = self._pid.to_bytes(2, "big")
918-
packet = MQTT_UNSUB + packet_length_byte + packet_id_bytes
926+
var_header = packet_id_bytes
927+
self.logger.debug(f"Variable Header: {var_header}")
928+
self._sock.send(var_header)
929+
payload = bytes()
919930
for t in topics:
920931
topic_size = len(t.encode("utf-8")).to_bytes(2, "big")
921-
packet += topic_size + t.encode()
932+
payload += topic_size + t.encode()
922933
for t in topics:
923-
self.logger.debug("UNSUBSCRIBING from topic %s", t)
924-
self._sock.send(packet)
934+
self.logger.debug(f"UNSUBSCRIBING from topic {t}")
935+
self._sock.send(payload)
925936
self.logger.debug("Waiting for UNSUBACK...")
926937
while True:
927938
stamp = self.get_monotonic_time()
@@ -1082,7 +1093,7 @@ def _wait_for_msg(self) -> Optional[int]:
10821093
return pkt_type
10831094

10841095
# Handle only the PUBLISH packet type from now on.
1085-
sz = self._recv_len()
1096+
sz = self._decode_remaining_length()
10861097
# topic length MSB & LSB
10871098
topic_len_buf = self._sock_exact_recv(2)
10881099
topic_len = int((topic_len_buf[0] << 8) | topic_len_buf[1])
@@ -1115,11 +1126,13 @@ def _wait_for_msg(self) -> Optional[int]:
11151126

11161127
return pkt_type
11171128

1118-
def _recv_len(self) -> int:
1119-
"""Unpack MQTT message length."""
1129+
def _decode_remaining_length(self) -> int:
1130+
"""Decode Remaining Length [2.2.3]"""
11201131
n = 0
11211132
sh = 0
11221133
while True:
1134+
if sh > 28:
1135+
raise MMQTTException("invalid remaining length encoding")
11231136
b = self._sock_exact_recv(1)[0]
11241137
n |= (b & 0x7F) << sh
11251138
if not b & 0x80:

tests/mocket.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# SPDX-FileCopyrightText: 2023 Vladimír Kotal
2+
#
3+
# SPDX-License-Identifier: Unlicense
4+
5+
"""fake socket class for protocol level testing"""
6+
7+
from unittest import mock
8+
9+
10+
class Mocket:
11+
"""
12+
Mock Socket tailored for MiniMQTT testing. Records sent data,
13+
hands out pre-recorded reply.
14+
15+
Inspired by the Mocket class from Adafruit_CircuitPython_Requests
16+
"""
17+
18+
def __init__(self, to_send):
19+
self._to_send = to_send
20+
21+
self.sent = bytearray()
22+
23+
self.timeout = mock.Mock()
24+
self.connect = mock.Mock()
25+
self.close = mock.Mock()
26+
27+
def send(self, bytes_to_send):
28+
"""merely record the bytes. return the length of this bytearray."""
29+
self.sent.extend(bytes_to_send)
30+
return len(bytes_to_send)
31+
32+
# MiniMQTT checks for the presence of "recv_into" and switches behavior based on that.
33+
def recv_into(self, retbuf, bufsize):
34+
"""return data from internal buffer"""
35+
size = min(bufsize, len(self._to_send))
36+
if size == 0:
37+
return size
38+
chop = self._to_send[0:size]
39+
retbuf[0:] = chop
40+
self._to_send = self._to_send[size:]
41+
return size

0 commit comments

Comments
 (0)