KIP-74: Manage assigned partition order in consumer #2562

Merged: 8 commits, Mar 24, 2025

116 changes: 64 additions & 52 deletions kafka/consumer/fetcher.py
@@ -4,7 +4,6 @@
import copy
import itertools
import logging
import random
import sys
import time

@@ -57,7 +56,6 @@ class Fetcher(six.Iterator):
'max_partition_fetch_bytes': 1048576,
'max_poll_records': sys.maxsize,
'check_crcs': True,
'iterator_refetch_records': 1, # undocumented -- interface may change
'metric_group_prefix': 'consumer',
'retry_backoff_ms': 100,
'enable_incremental_fetch_sessions': True,
@@ -380,10 +378,13 @@ def _append(self, drained, part, max_records, update_offsets):
# as long as the partition is still assigned
position = self._subscriptions.assignment[tp].position
if part.next_fetch_offset == position.offset:
part_records = part.take(max_records)
log.debug("Returning fetched records at offset %d for assigned"
" partition %s", position.offset, tp)
drained[tp].extend(part_records)
part_records = part.take(max_records)
# list.extend([]) is a noop, but because drained is a defaultdict
# we should avoid initializing the default list unless there are records
if part_records:
drained[tp].extend(part_records)
# We want to increment subscription position if (1) we're using consumer.poll(),
# or (2) we didn't return any records (consumer iterator will update position
# when each message is yielded). There may be edge cases where we re-fetch records
@@ -562,13 +563,11 @@ def _handle_list_offsets_response(self, future, response):
def _fetchable_partitions(self):
fetchable = self._subscriptions.fetchable_partitions()
# do not fetch a partition if we have a pending fetch response to process
discard = {fetch.topic_partition for fetch in self._completed_fetches}
current = self._next_partition_records
pending = copy.copy(self._completed_fetches)
if current:
fetchable.discard(current.topic_partition)
for fetch in pending:
fetchable.discard(fetch.topic_partition)
return fetchable
discard.add(current.topic_partition)
return [tp for tp in fetchable if tp not in discard]

def _create_fetch_requests(self):
"""Create fetch requests for all assigned partitions, grouped by node.
@@ -581,7 +580,7 @@ def _create_fetch_requests(self):
# create the fetch info as a dict of lists of partition info tuples
# which can be passed to FetchRequest() via .items()
version = self._client.api_version(FetchRequest, max_version=10)
fetchable = collections.defaultdict(dict)
fetchable = collections.defaultdict(collections.OrderedDict)

for partition in self._fetchable_partitions():
node_id = self._client.cluster.leader_for_partition(partition)
@@ -695,10 +694,7 @@ def _handle_fetch_response(self, node_id, fetch_offsets, send_time, response):
for partition_data in partitions])
metric_aggregator = FetchResponseMetricAggregator(self._sensors, partitions)

# randomized ordering should improve balance for short-lived consumers
random.shuffle(response.topics)
for topic, partitions in response.topics:
random.shuffle(partitions)
for partition_data in partitions:
tp = TopicPartition(topic, partition_data[0])
fetch_offset = fetch_offsets[tp]
@@ -733,8 +729,6 @@ def _parse_fetched_data(self, completed_fetch):
" since it is no longer fetchable", tp)

elif error_type is Errors.NoError:
self._subscriptions.assignment[tp].highwater = highwater

# we are interested in this fetch only if the beginning
# offset (of the *request*) matches the current consumed position
# Note that the *response* may return a messageset that starts
@@ -748,30 +742,35 @@
return None

records = MemoryRecords(completed_fetch.partition_data[-1])
if records.has_next():
log.debug("Adding fetched record for partition %s with"
" offset %d to buffered record list", tp,
position.offset)
parsed_records = self.PartitionRecords(fetch_offset, tp, records,
self.config['key_deserializer'],
self.config['value_deserializer'],
self.config['check_crcs'],
completed_fetch.metric_aggregator)
return parsed_records
elif records.size_in_bytes() > 0:
# we did not read a single message from a non-empty
# buffer because that message's size is larger than
# fetch size, in this case record this exception
record_too_large_partitions = {tp: fetch_offset}
raise RecordTooLargeError(
"There are some messages at [Partition=Offset]: %s "
" whose size is larger than the fetch size %s"
" and hence cannot be ever returned."
" Increase the fetch size, or decrease the maximum message"
" size the broker will allow." % (
record_too_large_partitions,
self.config['max_partition_fetch_bytes']),
record_too_large_partitions)
log.debug("Preparing to read %s bytes of data for partition %s with offset %d",
records.size_in_bytes(), tp, fetch_offset)
parsed_records = self.PartitionRecords(fetch_offset, tp, records,
self.config['key_deserializer'],
self.config['value_deserializer'],
self.config['check_crcs'],
completed_fetch.metric_aggregator,
self._on_partition_records_drain)
if not records.has_next() and records.size_in_bytes() > 0:
if completed_fetch.response_version < 3:
# Implement the pre KIP-74 behavior of throwing a RecordTooLargeException.
record_too_large_partitions = {tp: fetch_offset}
raise RecordTooLargeError(
"There are some messages at [Partition=Offset]: %s "
" whose size is larger than the fetch size %s"
" and hence cannot be ever returned. Please condier upgrading your broker to 0.10.1.0 or"
" newer to avoid this issue. Alternatively, increase the fetch size on the client (using"
" max_partition_fetch_bytes)" % (
record_too_large_partitions,
self.config['max_partition_fetch_bytes']),
record_too_large_partitions)
else:
# This should not happen with brokers that support FetchRequest/Response V3 or higher (i.e. KIP-74)
raise Errors.KafkaError("Failed to make progress reading messages at %s=%s."
" Received a non-empty fetch response from the server, but no"
" complete records were found." % (tp, fetch_offset))

if highwater >= 0:
self._subscriptions.assignment[tp].highwater = highwater

elif error_type in (Errors.NotLeaderForPartitionError,
Errors.ReplicaNotAvailableError,
@@ -805,14 +804,25 @@ def _parse_fetched_data(self, completed_fetch):
if parsed_records is None:
completed_fetch.metric_aggregator.record(tp, 0, 0)

return None
if error_type is not Errors.NoError:
# we move the partition to the end if there was an error. This way, it's more likely that partitions for
# the same topic can remain together (allowing for more efficient serialization).
self._subscriptions.move_partition_to_end(tp)

return parsed_records

def _on_partition_records_drain(self, partition_records):
# we move the partition to the end if we received some bytes. This way, it's more likely that partitions
# for the same topic can remain together (allowing for more efficient serialization).
if partition_records.bytes_read > 0:
self._subscriptions.move_partition_to_end(partition_records.topic_partition)

def close(self):
if self._next_partition_records is not None:
self._next_partition_records.drain()

class PartitionRecords(object):
def __init__(self, fetch_offset, tp, records, key_deserializer, value_deserializer, check_crcs, metric_aggregator):
def __init__(self, fetch_offset, tp, records, key_deserializer, value_deserializer, check_crcs, metric_aggregator, on_drain):
self.fetch_offset = fetch_offset
self.topic_partition = tp
self.leader_epoch = -1
@@ -824,6 +834,7 @@ def __init__(self, fetch_offset, tp, records, key_deserializer, value_deserializ
self.record_iterator = itertools.dropwhile(
self._maybe_skip_record,
self._unpack_records(tp, records, key_deserializer, value_deserializer))
self.on_drain = on_drain

def _maybe_skip_record(self, record):
# When fetching an offset that is in the middle of a
@@ -845,6 +856,7 @@ def drain(self):
if self.record_iterator is not None:
self.record_iterator = None
self.metric_aggregator.record(self.topic_partition, self.bytes_read, self.records_read)
self.on_drain(self)

def take(self, n=None):
return list(itertools.islice(self.record_iterator, 0, n))
@@ -943,6 +955,13 @@ def __init__(self, node_id):
self.session_partitions = {}

def build_next(self, next_partitions):
"""
Arguments:
next_partitions (dict): TopicPartition -> TopicPartitionState

Returns:
FetchRequestData
"""
if self.next_metadata.is_full:
log.debug("Built full fetch %s for node %s with %s partition(s).",
self.next_metadata, self.node_id, len(next_partitions))
@@ -965,8 +984,8 @@ def build_next(self, next_partitions):
altered.add(tp)

log.debug("Built incremental fetch %s for node %s. Added %s, altered %s, removed %s out of %s",
self.next_metadata, self.node_id, added, altered, removed, self.session_partitions.keys())
to_send = {tp: next_partitions[tp] for tp in (added | altered)}
self.next_metadata, self.node_id, added, altered, removed, self.session_partitions.keys())
to_send = collections.OrderedDict({tp: next_partitions[tp] for tp in next_partitions if tp in (added | altered)})
return FetchRequestData(to_send, removed, self.next_metadata)

def handle_response(self, response):
@@ -1106,18 +1125,11 @@ def epoch(self):
@property
def to_send(self):
# Return as list of [(topic, [(partition, ...), ...]), ...]
# so it an be passed directly to encoder
# so it can be passed directly to encoder
partition_data = collections.defaultdict(list)
for tp, partition_info in six.iteritems(self._to_send):
partition_data[tp.topic].append(partition_info)
# As of version == 3 partitions will be returned in order as
# they are requested, so to avoid starvation with
# `fetch_max_bytes` option we need this shuffle
# NOTE: we do have partition_data in random order due to usage
# of unordered structures like dicts, but that does not
# guarantee equal distribution, and starting in Python3.6
# dicts retain insert order.
return random.sample(list(partition_data.items()), k=len(partition_data))
return list(partition_data.items())

@property
def to_forget(self):
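
Taken together, the fetcher changes above replace the old random shuffling with deterministic round-robin ordering: PartitionRecords.drain() now fires an on_drain callback, which the fetcher wires to SubscriptionState.move_partition_to_end() so that a fully drained (or errored) partition rotates to the back of the ordered assignment. A minimal standalone sketch of that rotation pattern, assuming Python 3's OrderedDict.move_to_end; the class below is illustrative and not part of the library's API:

```python
from collections import OrderedDict


class RotatingAssignment(object):
    """Illustrative stand-in for the OrderedDict-based assignment in this PR."""

    def __init__(self, partitions):
        # insertion order becomes the fetch order
        self._assignment = OrderedDict((tp, None) for tp in partitions)

    def fetchable_partitions(self):
        # an ordered list rather than an unordered set, so fetch requests
        # are built in a stable order
        return list(self._assignment)

    def move_partition_to_end(self, tp):
        # rotate a drained (or errored) partition to the back so the other
        # partitions are requested first on the next fetch
        if tp in self._assignment:
            self._assignment.move_to_end(tp)


assignment = RotatingAssignment(['t0-p0', 't0-p1', 't1-p0'])
assignment.move_partition_to_end('t0-p0')    # e.g. after its buffered records drain
print(assignment.fetchable_partitions())     # ['t0-p1', 't1-p0', 't0-p0']
```

Since brokers that honor KIP-74 (FetchRequest v3+) fill responses in the requested partition order, this rotation avoids starving trailing partitions under fetch_max_bytes without the random.shuffle calls removed in this diff.
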
45 changes: 27 additions & 18 deletions kafka/consumer/subscription_state.py
@@ -1,11 +1,13 @@
from __future__ import absolute_import

import abc
from collections import defaultdict, OrderedDict
try:
from collections import Sequence
except ImportError:
from collections.abc import Sequence
import logging
import random
import re

from kafka.vendor import six
@@ -68,7 +70,7 @@ def __init__(self, offset_reset_strategy='earliest'):
self.subscribed_pattern = None # regex str or None
self._group_subscription = set()
self._user_assignment = set()
self.assignment = dict()
self.assignment = OrderedDict()
self.listener = None

# initialize to true for the consumers to fetch offset upon starting up
@@ -200,14 +202,8 @@ def assign_from_user(self, partitions):

if self._user_assignment != set(partitions):
self._user_assignment = set(partitions)

for partition in partitions:
if partition not in self.assignment:
self._add_assigned_partition(partition)

for tp in set(self.assignment.keys()) - self._user_assignment:
del self.assignment[tp]

self._set_assignment({partition: self.assignment.get(partition, TopicPartitionState())
for partition in partitions})
self.needs_fetch_committed_offsets = True

def assign_from_subscribed(self, assignments):
@@ -229,13 +225,25 @@ def assign_from_subscribed(self, assignments):
if tp.topic not in self.subscription:
raise ValueError("Assigned partition %s for non-subscribed topic." % (tp,))

# after rebalancing, we always reinitialize the assignment state
self.assignment.clear()
for tp in assignments:
self._add_assigned_partition(tp)
# after rebalancing, we always reinitialize the assignment value
# randomized ordering should improve balance for short-lived consumers
self._set_assignment({partition: TopicPartitionState() for partition in assignments}, randomize=True)
self.needs_fetch_committed_offsets = True
log.info("Updated partition assignment: %s", assignments)

def _set_assignment(self, partition_states, randomize=False):
"""Batch partition assignment by topic (self.assignment is OrderedDict)"""
self.assignment.clear()
topics = [tp.topic for tp in six.iterkeys(partition_states)]
if randomize:
random.shuffle(topics)
topic_partitions = OrderedDict({topic: [] for topic in topics})
for tp in six.iterkeys(partition_states):
topic_partitions[tp.topic].append(tp)
for topic in six.iterkeys(topic_partitions):
for tp in topic_partitions[topic]:
self.assignment[tp] = partition_states[tp]

def unsubscribe(self):
"""Clear all topic subscriptions and partition assignments"""
self.subscription = None
@@ -283,11 +291,11 @@ def paused_partitions(self):
if self.is_paused(partition))

def fetchable_partitions(self):
"""Return set of TopicPartitions that should be Fetched."""
fetchable = set()
"""Return ordered list of TopicPartitions that should be Fetched."""
fetchable = list()
for partition, state in six.iteritems(self.assignment):
if state.is_fetchable():
fetchable.add(partition)
fetchable.append(partition)
return fetchable

def partitions_auto_assigned(self):
@@ -348,8 +356,9 @@ def pause(self, partition):
def resume(self, partition):
self.assignment[partition].resume()

def _add_assigned_partition(self, partition):
self.assignment[partition] = TopicPartitionState()
def move_partition_to_end(self, partition):
if partition in self.assignment:
self.assignment.move_to_end(partition)


class TopicPartitionState(object):
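
The _set_assignment helper above batches the assigned partitions by topic before writing them into the OrderedDict, optionally shuffling the topic order after a rebalance (preserving the old "randomized ordering should improve balance for short-lived consumers" rationale). A rough standalone sketch of that grouping step; the function name and the namedtuple standing in for the real TopicPartition and state objects are illustrative only:

```python
import random
from collections import OrderedDict, namedtuple

TopicPartition = namedtuple('TopicPartition', ['topic', 'partition'])


def group_by_topic(partitions, randomize=False):
    """Reorder partitions so partitions of the same topic sit next to each other.

    Mirrors the intent of _set_assignment above; the helper name and return
    type here are illustrative only.
    """
    topics = []
    for tp in partitions:
        if tp.topic not in topics:
            topics.append(tp.topic)
    if randomize:
        # after a rebalance the topic order is shuffled, which should
        # improve balance for short-lived consumers
        random.shuffle(topics)
    grouped = OrderedDict((topic, []) for topic in topics)
    for tp in partitions:
        grouped[tp.topic].append(tp)
    # flatten back into a single ordered sequence
    return [tp for topic in grouped for tp in grouped[topic]]


parts = [TopicPartition('a', 0), TopicPartition('b', 0), TopicPartition('a', 1)]
print(group_by_topic(parts))
# [TopicPartition(topic='a', partition=0), TopicPartition(topic='a', partition=1),
#  TopicPartition(topic='b', partition=0)]
```
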
10 changes: 5 additions & 5 deletions test/test_fetcher.py
@@ -451,7 +451,7 @@ def test__unpack_records(mocker):
(None, b"c", None),
]
memory_records = MemoryRecords(_build_record_batch(messages))
part_records = Fetcher.PartitionRecords(0, tp, memory_records, None, None, False, mocker.MagicMock())
part_records = Fetcher.PartitionRecords(0, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None)
records = list(part_records.record_iterator)
assert len(records) == 3
assert all(map(lambda x: isinstance(x, ConsumerRecord), records))
@@ -556,7 +556,7 @@ def test_partition_records_offset(mocker):
tp = TopicPartition('foo', 0)
messages = [(None, b'msg', None) for i in range(batch_start, batch_end)]
memory_records = MemoryRecords(_build_record_batch(messages, offset=batch_start))
records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock())
records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None)
assert records
assert records.next_fetch_offset == fetch_offset
msgs = records.take(1)
@@ -573,7 +573,7 @@ def test_partition_records_offset(mocker):
def test_partition_records_empty(mocker):
tp = TopicPartition('foo', 0)
memory_records = MemoryRecords(_build_record_batch([]))
records = Fetcher.PartitionRecords(0, tp, memory_records, None, None, False, mocker.MagicMock())
records = Fetcher.PartitionRecords(0, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None)
msgs = records.take()
assert len(msgs) == 0
assert not records
@@ -586,7 +586,7 @@ def test_partition_records_no_fetch_offset(mocker):
tp = TopicPartition('foo', 0)
messages = [(None, b'msg', None) for i in range(batch_start, batch_end)]
memory_records = MemoryRecords(_build_record_batch(messages, offset=batch_start))
records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock())
records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None)
msgs = records.take()
assert len(msgs) == 0
assert not records
@@ -610,7 +610,7 @@ def test_partition_records_compacted_offset(mocker):
builder.append(key=None, value=b'msg', timestamp=None, headers=[])
builder.close()
memory_records = MemoryRecords(builder.buffer())
records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock())
records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None)
msgs = records.take()
assert len(msgs) == batch_end - fetch_offset - 1
assert msgs[0].offset == fetch_offset + 1
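
The test changes above only thread the new on_drain argument through the PartitionRecords constructor. A hypothetical extra test, not part of this PR, could assert the drain callback contract directly; it assumes the same helpers visible above (TopicPartition, MemoryRecords, _build_record_batch, Fetcher, mocker):

```python
def test_partition_records_on_drain_callback(mocker):
    # hypothetical test, not part of this PR; reuses helpers from the tests above
    tp = TopicPartition('foo', 0)
    messages = [(None, b'msg', None) for _ in range(5)]
    memory_records = MemoryRecords(_build_record_batch(messages))
    on_drain = mocker.MagicMock()
    records = Fetcher.PartitionRecords(0, tp, memory_records, None, None, False,
                                       mocker.MagicMock(), on_drain)
    records.take()
    records.drain()
    # drain() clears the record iterator and is a no-op afterwards, so the
    # callback should fire exactly once with the PartitionRecords instance
    on_drain.assert_called_once_with(records)
```
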