Skip to content

Make PythonParser resumable (alternative) #2512

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Make PythonParser resumable in case of error (#2512)
* Add `timeout=None` in `SentinelConnectionManager.read_response`
* Documentation fix: password protected socket connection (#2374)
* Allow `timeout=None` in `PubSub.get_message()` to wait forever
Expand Down
81 changes: 54 additions & 27 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,13 @@ async def read_response(
class PythonParser(BaseParser):
"""Plain Python parsing class"""

__slots__ = BaseParser.__slots__ + ("encoder",)
__slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos")

def __init__(self, socket_read_size: int):
    super().__init__(socket_read_size)
    # Encoder for bulk-string decoding; assigned after connect and
    # cleared again in on_disconnect (see L61).
    self.encoder: Optional[Encoder] = None
    # Bytes received from the stream but not yet consumed by the
    # parser.  Kept across calls so an interrupted read_response()
    # can resume without losing data (the point of this PR).
    self._buffer = b""
    # Current parse offset into _buffer.
    self._pos = 0

def on_connect(self, connection: "Connection"):
"""Called when the stream connects"""
Expand All @@ -227,8 +229,11 @@ def on_disconnect(self):
if self._stream is not None:
self._stream = None
self.encoder = None
self._buffer = b""

async def can_read_destructive(self) -> bool:
if self._buffer:
return True
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
Expand All @@ -237,14 +242,38 @@ async def can_read_destructive(self) -> bool:
except asyncio.TimeoutError:
return False

async def read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None]:
async def read_response(self, disable_decoding: bool = False):
    """
    Read and parse one complete reply from the connection.

    Parsing runs against the in-memory byte buffer only; whenever the
    parser runs out of buffered data (signalled by EOFError) more bytes
    are awaited from the stream and the parse restarts from the top of
    the buffer.  Since all IO is confined to _fill_buffer(), a cancelled
    call leaves the buffer intact and a later call can resume cleanly.
    """
    if not self._stream or not self.encoder:
        raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

    if not self._buffer:
        await self._fill_buffer()
    while True:
        self._pos = 0
        try:
            reply = self._read_response(disable_decoding=disable_decoding)
            break
        except EOFError:
            # Buffer does not yet hold a full reply; fetch more bytes.
            await self._fill_buffer()
    # A complete reply was parsed, so the consumed bytes can be dropped.
    self._buffer = self._buffer[self._pos :]
    return reply

async def _fill_buffer(self):
"""
IO is performed here
"""
buffer = await self._stream.read(self._read_size)
if not buffer or not isinstance(buffer, bytes):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
self._buffer += buffer

def _read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None]:
raw = self._readline()
response: Any
byte, response = raw[:1], raw[1:]

Expand All @@ -258,6 +287,7 @@ async def read_response(
# if the error is a ConnectionError, raise immediately so the user
# is notified
if isinstance(error, ConnectionError):
self._buffer = self._buffer[self._pos :] # Successful parse
raise error
# otherwise, we're dealing with a ResponseError that might belong
# inside a pipeline response. the connection's read_response()
Expand All @@ -275,43 +305,40 @@ async def read_response(
length = int(response)
if length == -1:
return None
response = await self._read(length)
response = self._read(length)
# multi-bulk response
elif byte == b"*":
length = int(response)
if length == -1:
return None
response = [
(await self.read_response(disable_decoding)) for _ in range(length)
]
response = [(self._read_response(disable_decoding)) for _ in range(length)]
if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response

async def _read(self, length: int) -> bytes:
def _read(self, length: int) -> bytes:
"""
Read `length` bytes of data. These are assumed to be followed
by a '\r\n' terminator which is subsequently discarded.
"""
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
data = await self._stream.readexactly(length + 2)
except asyncio.IncompleteReadError as error:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
return data[:-2]

async def _readline(self) -> bytes:
end = self._pos + length + 2
if len(self._buffer) < end:
raise EOFError() # Signal that we need more data
result = self._buffer[self._pos : end - 2]
self._pos = end
return result

def _readline(self) -> bytes:
"""
read an unknown number of bytes up to the next '\r\n'
line separator, which is discarded.
"""
if self._stream is None:
raise RedisError("Buffer is closed.")
data = await self._stream.readline()
if not data.endswith(b"\r\n"):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
return data[:-2]
found = self._buffer.find(b"\r\n", self._pos)
if found < 0:
raise EOFError() # signal that we need more data
result = self._buffer[self._pos : found]
self._pos = found + 2
return result


class HiredisParser(BaseParser):
Expand Down
58 changes: 42 additions & 16 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,6 @@ def read(self, length):
self._buffer.seek(self.bytes_read)
data = self._buffer.read(length)
self.bytes_read += len(data)

# purge the buffer when we've consumed it all so it doesn't
# grow forever
if self.bytes_read == self.bytes_written:
self.purge()

return data[:-2]

def readline(self):
Expand All @@ -251,23 +245,44 @@ def readline(self):
data = buf.readline()

self.bytes_read += len(data)
return data[:-2]

# purge the buffer when we've consumed it all so it doesn't
# grow forever
if self.bytes_read == self.bytes_written:
self.purge()
def get_pos(self):
    """Return the current read position (bytes consumed so far)."""
    return self.bytes_read

return data[:-2]
def rewind(self, pos):
    """
    Reset the read cursor to ``pos`` so previously consumed bytes can
    be read again (used to re-start reading after an interrupted parse).
    """
    self.bytes_read = pos

def purge(self):
    """
    After a successful read, purge the read part of buffer.

    Only when every written byte has been consumed do we actually
    truncate the underlying buffer, to reduce the amount of memory
    thrashing.  While unread data remains this is a no-op.  This
    heuristic can be changed or removed later.
    """
    unread = self.bytes_written - self.bytes_read

    # Nothing to do while unread data remains in the buffer.
    if unread > 0:
        return

    # BUG FIX: a second `if unread > 0:` branch used to live here
    # (compacting unread data to the front of the buffer); it was
    # unreachable dead code because of the early return above and has
    # been removed.  At this point unread == 0, so we simply reset.
    self._buffer.truncate(0)
    self.bytes_written = 0
    self.bytes_read = 0
    self._buffer.seek(0)

def close(self):
try:
self.purge()
self.bytes_written = self.bytes_read = 0
self._buffer.close()
except Exception:
# issue #633 suggests the purge/close somehow raised a
Expand Down Expand Up @@ -315,6 +330,17 @@ def can_read(self, timeout):
return self._buffer and self._buffer.can_read(timeout)

def read_response(self, disable_decoding=False):
    """
    Parse a single reply from the buffer.

    The buffer's read position is checkpointed first; if parsing is
    interrupted by any exception (including BaseExceptions such as
    cancellation/KeyboardInterrupt) the position is rewound so a later
    call can re-parse the same reply.  On success the consumed bytes
    are purged from the buffer.
    """
    checkpoint = self._buffer.get_pos()
    try:
        reply = self._read_response(disable_decoding=disable_decoding)
    except BaseException:
        # Roll back so the partially-consumed reply stays readable.
        self._buffer.rewind(checkpoint)
        raise
    self._buffer.purge()
    return reply

def _read_response(self, disable_decoding=False):
raw = self._buffer.readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
Expand Down Expand Up @@ -355,7 +381,7 @@ def read_response(self, disable_decoding=False):
if length == -1:
return None
response = [
self.read_response(disable_decoding=disable_decoding)
self._read_response(disable_decoding=disable_decoding)
for i in range(length)
]
if isinstance(response, bytes) and disable_decoding is False:
Expand Down
41 changes: 41 additions & 0 deletions tests/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Various mocks for testing


class MockSocket:
    """
    Fake readable socket for parser tests.

    Reads are capped at 5 bytes so callers are forced into multiple
    reads.  When ``interrupt_every`` is non-zero, every N-th read raises
    ``TestError`` — a BaseException subclass, so ordinary exception
    handlers in the code under test will not swallow it — simulating an
    interrupted read.
    """

    class TestError(BaseException):
        pass

    def __init__(self, data, interrupt_every=0):
        self.data = data
        self.counter = 0
        self.pos = 0
        self.interrupt_every = interrupt_every

    def tick(self):
        """Advance the read counter, raising on every N-th call."""
        self.counter += 1
        if self.interrupt_every and (self.counter % self.interrupt_every) == 0:
            raise self.TestError()

    def recv(self, bufsize):
        self.tick()
        capped = min(5, bufsize)  # truncate the read size
        chunk = self.data[self.pos : self.pos + capped]
        self.pos += len(chunk)
        return chunk

    def recv_into(self, buffer, nbytes=0, flags=0):
        self.tick()
        # nbytes == 0 means "fill the whole buffer", as with real sockets.
        capped = min(5, nbytes or len(buffer))  # truncate the read size
        chunk = self.data[self.pos : self.pos + capped]
        self.pos += len(chunk)
        buffer[: len(chunk)] = chunk
        return len(chunk)
51 changes: 51 additions & 0 deletions tests/test_asyncio/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import asyncio

# Helper Mocking classes for the tests.


class MockStream:
    """
    A class simulating an asyncio input buffer, optionally raising a
    special exception every other read.

    TestError derives from BaseException so that ordinary exception
    handlers in the code under test do not swallow it, mimicking an
    interrupted (e.g. cancelled) read.
    """

    class TestError(BaseException):
        pass

    def __init__(self, data, interrupt_every=0):
        self.data = data
        self.counter = 0
        self.pos = 0
        self.interrupt_every = interrupt_every

    def tick(self):
        # Raise TestError on every `interrupt_every`-th read (0 disables).
        self.counter += 1
        if not self.interrupt_every:
            return
        if (self.counter % self.interrupt_every) == 0:
            raise self.TestError()

    async def read(self, want):
        self.tick()
        # BUG FIX: `want` was previously overwritten unconditionally with
        # 5, so a caller asking for fewer bytes could receive more than
        # requested.  Cap at 5 (consistent with MockSocket) instead.
        want = min(5, want)  # truncate the read size
        result = self.data[self.pos : self.pos + want]
        self.pos += len(result)
        return result

    async def readline(self):
        self.tick()
        # Return up to and including the next newline; if none remains,
        # return the rest of the data (like StreamReader.readline at EOF).
        find = self.data.find(b"\n", self.pos)
        if find >= 0:
            result = self.data[self.pos : find + 1]
        else:
            result = self.data[self.pos :]
        self.pos += len(result)
        return result

    async def readexactly(self, length):
        self.tick()
        result = self.data[self.pos : self.pos + length]
        if len(result) < length:
            # Not enough data left; report the partial read.
            raise asyncio.IncompleteReadError(result, None)
        self.pos += len(result)
        return result
Loading