Skip to content

Fix key deserialization propagation in windows #848

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions quixstreams/dataframe/windows/sliding.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def process_window(
# build a complete list otherwise expired windows could be deleted
# in state.delete_windows() and never be fetched.
expired_windows = list(
self._expired_windows(state, max_expired_window_start, collect)
self._expired_windows(key, state, max_expired_window_start, collect)
)

state.delete_windows(
Expand All @@ -289,14 +289,14 @@ def process_window(

return reversed(updated_windows), expired_windows

def _expired_windows(self, state, max_expired_window_start, collect):
def _expired_windows(self, key, state, max_expired_window_start, collect):
for window in state.expire_windows(
max_start_time=max_expired_window_start,
delete=False,
collect=collect,
end_inclusive=True,
):
(start, end), (max_timestamp, aggregated), collected, key = window
(start, end), (max_timestamp, aggregated), collected = window
if end == max_timestamp:
yield key, self._results(aggregated, collected, start, end)

Expand Down
7 changes: 2 additions & 5 deletions quixstreams/dataframe/windows/time_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,18 +229,15 @@ def expire_by_partition(

def expire_by_key(
self,
key: bytes,
key: Any,
state: WindowedState,
max_expired_start: int,
collect: bool,
) -> Iterable[WindowKeyResult]:
start = time.monotonic()
count = 0

for (
window_start,
window_end,
), aggregated, collected, _ in state.expire_windows(
for (window_start, window_end), aggregated, collected in state.expire_windows(
max_start_time=max_expired_start,
collect=collect,
):
Expand Down
50 changes: 42 additions & 8 deletions quixstreams/state/rocksdb/windowed/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,20 @@
PartitionTransactionStatus,
validate_transaction_status,
)
from quixstreams.state.exceptions import StateSerializationError
from quixstreams.state.metadata import DEFAULT_PREFIX, SEPARATOR
from quixstreams.state.recovery import ChangelogProducer
from quixstreams.state.serialization import (
DumpsFunc,
LoadsFunc,
deserialize,
serialize,
)
from quixstreams.state.types import ExpiredWindowDetail, WindowDetail
from quixstreams.state.types import (
ExpiredWindowDetail,
ExpiredWindowDetailWithKey,
WindowDetail,
)

from .metadata import (
GLOBAL_COUNTER_CF_NAME,
Expand Down Expand Up @@ -297,7 +303,7 @@ def expire_windows(

# Collect values into windows
if collect:
for (start, end), aggregated, key in windows:
for (start, end), aggregated, _ in windows:
collected = self.get_from_collection(
start=start,
# Sliding windows are inclusive on both ends
Expand All @@ -307,11 +313,11 @@ def expire_windows(
end=end + 1 if end_inclusive else end,
prefix=prefix,
)
yield ((start, end), aggregated, collected, key)
yield (start, end), aggregated, collected

else:
for window, aggregated, key in windows:
yield (window, aggregated, [], key)
for window, aggregated, _ in windows:
yield window, aggregated, []

# Delete expired windows from the state
if delete:
Expand All @@ -326,7 +332,7 @@ def expire_all_windows(
step_ms: int,
delete: bool = True,
collect: bool = False,
) -> Iterable[ExpiredWindowDetail]:
) -> Iterable[ExpiredWindowDetailWithKey]:
"""
Get all expired windows for all prefix from RocksDB up to the specified `max_end_time` timestamp.

Expand Down Expand Up @@ -360,7 +366,8 @@ def expire_all_windows(
end=end,
prefix=prefix,
)
yield (start, end), aggregated, collected, prefix
deserialized_prefix = self._deserialize_prefix(prefix)
yield (start, end), aggregated, collected, deserialized_prefix

else:
# If we don't have a saved last_expired value it means one of two cases
Expand All @@ -382,7 +389,8 @@ def expire_all_windows(
prefix=prefix,
)

yield (start, end), aggregated, collected, prefix
deserialized_prefix = self._deserialize_prefix(prefix)
yield (start, end), aggregated, collected, deserialized_prefix

if delete:
for prefix, start, end in to_delete:
Expand All @@ -394,6 +402,32 @@ def expire_all_windows(
prefix=b"", cache=self._last_expired_timestamps, timestamp_ms=last_expired
)

def _deserialize_prefix(self, prefix: bytes) -> Any:
"""
Attempt to deserialize a window prefix.

Window prefixes can be provided either as raw bytes or as other types
(e.g., dict). The `as_state()` method conditionally serializes these
prefixes to bytes only if they are not already bytes before storing.

When retrieving a prefix during partition-level windows expiration, we
don't know its original type due to this conditional serialization.
Therefore, we must first *try* to deserialize it using the configured
`loads` function.
Comment on lines +413 to +416
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we store an indicator of the prefix type, i.e. whether it was originally bytes or something else?

For migration, existing windows that are missing that information can fall back to trying to deserialize.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but it's considerably more code; I'm not sure it's worth it.

Let's see what @daniil-quix thinks. But even if we decide to do this, it may be done as a separate optimisation; in the meantime, this will fix the bug.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm mostly worried about the performance hit of trying to deserialise every key.

Copy link
Collaborator

@daniil-quix daniil-quix May 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did some quick tests locally, and the performance hit actually depends on the value itself.

Two almost identical values, b'123123' and b'{123123', take very different amounts of time:

### Successful deserialization

%%timeit

try:
    orjson.loads(b'123123')
except Exception:
    ...

34.4 ns ± 0.182 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
### Deserialization failed
%%timeit

try:
    orjson.loads(b'{123123')
except Exception:
    ...

765 ns ± 4.51 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


If deserialization succeeds, it means the prefix was originally a
non-bytes type, and we return the deserialized object.
If deserialization fails with a `StateSerializationError`, it indicates
that the prefix was likely provided as raw bytes initially, so we
return the original `prefix` bytes.

:param prefix: The prefix bytes retrieved from storage.
"""
try:
return deserialize(prefix, loads=self._loads)
except StateSerializationError:
return prefix

def delete_windows(
self, max_start_time: int, delete_values: bool, prefix: bytes
) -> None:
Expand Down
7 changes: 5 additions & 2 deletions quixstreams/state/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
tuple[int, int], V, bytes
] # (start, end), aggregated, key
ExpiredWindowDetail: TypeAlias = tuple[
tuple[int, int], V, list[V], bytes
tuple[int, int], V, list[V]
] # (start, end), aggregated, collected
ExpiredWindowDetailWithKey: TypeAlias = tuple[
tuple[int, int], V, list[V], Any
] # (start, end), aggregated, collected, key


Expand Down Expand Up @@ -378,7 +381,7 @@ def expire_all_windows(
step_ms: int,
delete: bool = True,
collect: bool = False,
) -> Iterable[ExpiredWindowDetail[V]]:
) -> Iterable[ExpiredWindowDetailWithKey[V]]:
"""
Get all expired windows for all prefix from RocksDB up to the specified `max_start_time` timestamp.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def test_expire_windows(transaction_state, delete):

assert len(expired) == 2
assert expired == [
((0, 10), 1, [], b"__key__"),
((10, 20), 2, [], b"__key__"),
((0, 10), 1, []),
((10, 20), 2, []),
]

with transaction_state() as state:
Expand Down Expand Up @@ -112,8 +112,8 @@ def test_expire_windows_with_collect(transaction_state, end_inclusive):
window_1_value = ["a", "b"] if end_inclusive else ["a"]
window_2_value = ["b", "c"] if end_inclusive else ["b"]
assert expired == [
((0, 10), None, window_1_value, b"__key__"),
((10, 20), [777, None], window_2_value, b"__key__"),
((0, 10), None, window_1_value),
((10, 20), [777, None], window_2_value),
]


Expand All @@ -132,7 +132,7 @@ def test_same_keys_in_db_and_update_cache(transaction_state):
expired = list(state.expire_windows(max_start_time=max_start_time))

# Value from the cache takes precedence over the value in the db
assert expired == [((0, 10), 3, [], b"__key__")]
assert expired == [((0, 10), 3, [])]


def test_get_latest_timestamp(windowed_rocksdb_store_factory):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def test_expire_windows_expired(self, windowed_rocksdb_store_factory, delete):

assert len(expired) == 2
assert expired == [
((0, 10), 1, [], prefix),
((10, 20), 2, [], prefix),
((0, 10), 1, []),
((10, 20), 2, []),
]

with store.start_partition_transaction(0) as tx:
Expand Down Expand Up @@ -131,8 +131,8 @@ def test_expire_windows_cached(self, windowed_rocksdb_store_factory, delete):
)
assert len(expired) == 2
assert expired == [
((0, 10), 1, [], prefix),
((10, 20), 2, [], prefix),
((0, 10), 1, []),
((10, 20), 2, []),
]
assert (
tx.get_window(start_ms=0, end_ms=10, prefix=prefix) == None
Expand Down Expand Up @@ -193,7 +193,7 @@ def test_expire_windows_with_grace_expired(self, windowed_rocksdb_store_factory)
)

assert len(expired) == 1
assert expired == [((0, 10), 1, [], prefix)]
assert expired == [((0, 10), 1, [])]

def test_expire_windows_with_grace_empty(self, windowed_rocksdb_store_factory):
store = windowed_rocksdb_store_factory()
Expand Down Expand Up @@ -328,9 +328,9 @@ def test_expire_windows_multiple_windows(self, windowed_rocksdb_store_factory):
)

assert len(expired) == 3
assert expired[0] == ((0, 10), 1, [], prefix)
assert expired[1] == ((10, 20), 1, [], prefix)
assert expired[2] == ((20, 30), 1, [], prefix)
Comment on lines -331 to -333
Copy link
Collaborator

@daniil-quix daniil-quix May 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was kinda testing that the same prefix is returned :)
Let's return the key from expire_windows() to make the tests legit, even though we don't use it in the actual code.

assert expired[0] == ((0, 10), 1, [])
assert expired[1] == ((10, 20), 1, [])
assert expired[2] == ((20, 30), 1, [])

def test_get_latest_timestamp_update(self, windowed_rocksdb_store_factory):
store = windowed_rocksdb_store_factory()
Expand Down
Loading