@@ -144,14 +144,14 @@ def __init__(
144144 self ._opcode : int = OP_CODE_NOT_SET
145145 self ._frame_fin = False
146146 self ._frame_opcode : int = OP_CODE_NOT_SET
147- self ._frame_payload : Union [bytes , bytearray ] = b""
147+ self ._payload_fragments : list [bytes ] = []
148148 self ._frame_payload_len = 0
149149
150150 self ._tail : bytes = b""
151151 self ._has_mask = False
152152 self ._frame_mask : Optional [bytes ] = None
153- self ._payload_length = 0
154- self ._payload_length_flag = 0
153+ self ._payload_bytes_to_read = 0
154+ self ._payload_len_flag = 0
155155 self ._compressed : int = COMPRESSED_NOT_SET
156156 self ._decompressobj : Optional [ZLibDecompressor ] = None
157157 self ._compress = compress
@@ -317,13 +317,13 @@ def _feed_data(self, data: bytes) -> None:
317317 data , self ._tail = self ._tail + data , b""
318318
319319 start_pos : int = 0
320- data_length = len (data )
320+ data_len = len (data )
321321 data_cstr = data
322322
323323 while True :
324324 # read header
325325 if self ._state == READ_HEADER :
326- if data_length - start_pos < 2 :
326+ if data_len - start_pos < 2 :
327327 break
328328 first_byte = data_cstr [start_pos ]
329329 second_byte = data_cstr [start_pos + 1 ]
@@ -382,77 +382,88 @@ def _feed_data(self, data: bytes) -> None:
382382 self ._frame_fin = bool (fin )
383383 self ._frame_opcode = opcode
384384 self ._has_mask = bool (has_mask )
385- self ._payload_length_flag = length
385+ self ._payload_len_flag = length
386386 self ._state = READ_PAYLOAD_LENGTH
387387
388388 # read payload length
389389 if self ._state == READ_PAYLOAD_LENGTH :
390- length_flag = self ._payload_length_flag
391- if length_flag == 126 :
392- if data_length - start_pos < 2 :
390+ len_flag = self ._payload_len_flag
391+ if len_flag == 126 :
392+ if data_len - start_pos < 2 :
393393 break
394394 first_byte = data_cstr [start_pos ]
395395 second_byte = data_cstr [start_pos + 1 ]
396396 start_pos += 2
397- self ._payload_length = first_byte << 8 | second_byte
398- elif length_flag > 126 :
399- if data_length - start_pos < 8 :
397+ self ._payload_bytes_to_read = first_byte << 8 | second_byte
398+ elif len_flag > 126 :
399+ if data_len - start_pos < 8 :
400400 break
401- self ._payload_length = UNPACK_LEN3 (data , start_pos )[0 ]
401+ self ._payload_bytes_to_read = UNPACK_LEN3 (data , start_pos )[0 ]
402402 start_pos += 8
403403 else :
404- self ._payload_length = length_flag
404+ self ._payload_bytes_to_read = len_flag
405405
406406 self ._state = READ_PAYLOAD_MASK if self ._has_mask else READ_PAYLOAD
407407
408408 # read payload mask
409409 if self ._state == READ_PAYLOAD_MASK :
410- if data_length - start_pos < 4 :
410+ if data_len - start_pos < 4 :
411411 break
412412 self ._frame_mask = data_cstr [start_pos : start_pos + 4 ]
413413 start_pos += 4
414414 self ._state = READ_PAYLOAD
415415
416416 if self ._state == READ_PAYLOAD :
417- chunk_len = data_length - start_pos
418- if self ._payload_length >= chunk_len :
419- end_pos = data_length
420- self ._payload_length -= chunk_len
417+ chunk_len = data_len - start_pos
418+ if self ._payload_bytes_to_read >= chunk_len :
419+ f_end_pos = data_len
420+ self ._payload_bytes_to_read -= chunk_len
421421 else :
422- end_pos = start_pos + self ._payload_length
423- self ._payload_length = 0
424-
425- if self ._frame_payload_len :
426- if type (self ._frame_payload ) is not bytearray :
427- self ._frame_payload = bytearray (self ._frame_payload )
428- self ._frame_payload += data_cstr [start_pos :end_pos ]
429- else :
430- # Fast path for the first frame
431- self ._frame_payload = data_cstr [start_pos :end_pos ]
432-
433- self ._frame_payload_len += end_pos - start_pos
434- start_pos = end_pos
435-
436- if self ._payload_length != 0 :
422+ f_end_pos = start_pos + self ._payload_bytes_to_read
423+ self ._payload_bytes_to_read = 0
424+
425+ had_fragments = self ._frame_payload_len
426+ self ._frame_payload_len += f_end_pos - start_pos
427+ f_start_pos = start_pos
428+ start_pos = f_end_pos
429+
430+ if self ._payload_bytes_to_read != 0 :
431+ # If we don't have a complete frame, we need to save the
432+ # data for the next call to feed_data.
433+ self ._payload_fragments .append (data_cstr [f_start_pos :f_end_pos ])
437434 break
438435
439- if self ._has_mask :
436+ payload : Union [bytes , bytearray ]
437+ if had_fragments :
438+ # We have to join the payload fragments get the payload
439+ self ._payload_fragments .append (data_cstr [f_start_pos :f_end_pos ])
440+ if self ._has_mask :
441+ assert self ._frame_mask is not None
442+ payload_bytearray = bytearray ()
443+ payload_bytearray .join (self ._payload_fragments )
444+ websocket_mask (self ._frame_mask , payload_bytearray )
445+ payload = payload_bytearray
446+ else :
447+ payload = b"" .join (self ._payload_fragments )
448+ self ._payload_fragments .clear ()
449+ elif self ._has_mask :
440450 assert self ._frame_mask is not None
441- if type (self ._frame_payload ) is not bytearray :
442- self ._frame_payload = bytearray (self ._frame_payload )
443- websocket_mask (self ._frame_mask , self ._frame_payload )
451+ payload_bytearray = data_cstr [f_start_pos :f_end_pos ] # type: ignore[assignment]
452+ if type (payload_bytearray ) is not bytearray : # pragma: no branch
453+ # Cython will do the conversion for us
454+ # but we need to do it for Python and we
455+ # will always get here in Python
456+ payload_bytearray = bytearray (payload_bytearray )
457+ websocket_mask (self ._frame_mask , payload_bytearray )
458+ payload = payload_bytearray
459+ else :
460+ payload = data_cstr [f_start_pos :f_end_pos ]
444461
445462 self ._handle_frame (
446- self ._frame_fin ,
447- self ._frame_opcode ,
448- self ._frame_payload ,
449- self ._compressed ,
463+ self ._frame_fin , self ._frame_opcode , payload , self ._compressed
450464 )
451- self ._frame_payload = b""
452465 self ._frame_payload_len = 0
453466 self ._state = READ_HEADER
454467
455468 # XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
456- self ._tail = (
457- data_cstr [start_pos :data_length ] if start_pos < data_length else b""
458- )
469+ self ._tail = data_cstr [start_pos :data_len ] if start_pos < data_len else b""
0 commit comments