From 243f5a70053d7edeb566efd9abbff40118da4248 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Sun, 23 Mar 2025 10:10:21 -0700 Subject: [PATCH 1/8] KIP-74: Manage assigned partition order in consumer --- kafka/consumer/fetcher.py | 31 ++++++++++---------- kafka/consumer/subscription_state.py | 44 ++++++++++++++++------------ 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py index 90dfdbbbc..068ee8def 100644 --- a/kafka/consumer/fetcher.py +++ b/kafka/consumer/fetcher.py @@ -581,7 +581,7 @@ def _create_fetch_requests(self): # create the fetch info as a dict of lists of partition info tuples # which can be passed to FetchRequest() via .items() version = self._client.api_version(FetchRequest, max_version=10) - fetchable = collections.defaultdict(dict) + fetchable = collections.defaultdict(collections.OrderedDict) for partition in self._fetchable_partitions(): node_id = self._client.cluster.leader_for_partition(partition) @@ -695,8 +695,6 @@ def _handle_fetch_response(self, node_id, fetch_offsets, send_time, response): for partition_data in partitions]) metric_aggregator = FetchResponseMetricAggregator(self._sensors, partitions) - # randomized ordering should improve balance for short-lived consumers - random.shuffle(response.topics) for topic, partitions in response.topics: random.shuffle(partitions) for partition_data in partitions: @@ -757,7 +755,6 @@ def _parse_fetched_data(self, completed_fetch): self.config['value_deserializer'], self.config['check_crcs'], completed_fetch.metric_aggregator) - return parsed_records elif records.size_in_bytes() > 0: # we did not read a single message from a non-empty # buffer because that message's size is larger than @@ -805,7 +802,9 @@ def _parse_fetched_data(self, completed_fetch): if parsed_records is None: completed_fetch.metric_aggregator.record(tp, 0, 0) - return None + if parsed_records is None or parsed_records.bytes_read > 0: + self._subscriptions.move_partition_to_end(tp) + return parsed_records def close(self): if self._next_partition_records is not None: @@ -943,6 +942,13 @@ def __init__(self, node_id): self.session_partitions = {} def build_next(self, next_partitions): + """ + Arguments: + next_partitions (dict): TopicPartition -> TopicPartitionState + + Returns: + FetchRequestData + """ if self.next_metadata.is_full: log.debug("Built full fetch %s for node %s with %s partition(s).", self.next_metadata, self.node_id, len(next_partitions)) @@ -965,8 +971,8 @@ def build_next(self, next_partitions): altered.add(tp) log.debug("Built incremental fetch %s for node %s. Added %s, altered %s, removed %s out of %s", - self.next_metadata, self.node_id, added, altered, removed, self.session_partitions.keys()) - to_send = {tp: next_partitions[tp] for tp in (added | altered)} + self.next_metadata, self.node_id, added, altered, removed, self.session_partitions.keys()) + to_send = collections.OrderedDict({tp: next_partitions[tp] for tp in next_partitions if tp in (added | altered)}) return FetchRequestData(to_send, removed, self.next_metadata) def handle_response(self, response): @@ -1106,18 +1112,11 @@ def epoch(self): @property def to_send(self): # Return as list of [(topic, [(partition, ...), ...]), ...] 
- # so it an be passed directly to encoder + # so it can be passed directly to encoder partition_data = collections.defaultdict(list) for tp, partition_info in six.iteritems(self._to_send): partition_data[tp.topic].append(partition_info) - # As of version == 3 partitions will be returned in order as - # they are requested, so to avoid starvation with - # `fetch_max_bytes` option we need this shuffle - # NOTE: we do have partition_data in random order due to usage - # of unordered structures like dicts, but that does not - # guarantee equal distribution, and starting in Python3.6 - # dicts retain insert order. - return random.sample(list(partition_data.items()), k=len(partition_data)) + return list(partition_data.items()) @property def to_forget(self): diff --git a/kafka/consumer/subscription_state.py b/kafka/consumer/subscription_state.py index a1675c724..47baed964 100644 --- a/kafka/consumer/subscription_state.py +++ b/kafka/consumer/subscription_state.py @@ -1,11 +1,13 @@ from __future__ import absolute_import import abc +from collections import defaultdict, OrderedDict try: from collections import Sequence except ImportError: from collections.abc import Sequence import logging +import random import re from kafka.vendor import six @@ -68,7 +70,7 @@ def __init__(self, offset_reset_strategy='earliest'): self.subscribed_pattern = None # regex str or None self._group_subscription = set() self._user_assignment = set() - self.assignment = dict() + self.assignment = OrderedDict() self.listener = None # initialize to true for the consumers to fetch offset upon starting up @@ -200,14 +202,8 @@ def assign_from_user(self, partitions): if self._user_assignment != set(partitions): self._user_assignment = set(partitions) - - for partition in partitions: - if partition not in self.assignment: - self._add_assigned_partition(partition) - - for tp in set(self.assignment.keys()) - self._user_assignment: - del self.assignment[tp] - + self._set_assignment({partition: self.assignment.get(partition, TopicPartitionState()) + for partition in partitions}) self.needs_fetch_committed_offsets = True def assign_from_subscribed(self, assignments): @@ -229,13 +225,25 @@ def assign_from_subscribed(self, assignments): if tp.topic not in self.subscription: raise ValueError("Assigned partition %s for non-subscribed topic." 
% (tp,)) - # after rebalancing, we always reinitialize the assignment state - self.assignment.clear() - for tp in assignments: - self._add_assigned_partition(tp) + # after rebalancing, we always reinitialize the assignment value + # randomized ordering should improve balance for short-lived consumers + self._set_assignment({partition: TopicPartitionState() for partition in assignments}, randomize=True) self.needs_fetch_committed_offsets = True log.info("Updated partition assignment: %s", assignments) + def _set_assignment(self, partition_states, randomize=False): + """Batch partition assignment by topic (self.assignment is OrderedDict)""" + self.assignment.clear() + topics = [tp.topic for tp in six.iterkeys(partition_states)] + if randomize: + random.shuffle(topics) + topic_partitions = OrderedDict({topic: [] for topic in topics}) + for tp in six.iterkeys(partition_states): + topic_partitions[tp.topic].append(tp) + for topic in six.iterkeys(topic_partitions): + for tp in topic_partitions[topic]: + self.assignment[tp] = partition_states[tp] + def unsubscribe(self): """Clear all topic subscriptions and partition assignments""" self.subscription = None @@ -283,11 +291,11 @@ def paused_partitions(self): if self.is_paused(partition)) def fetchable_partitions(self): - """Return set of TopicPartitions that should be Fetched.""" - fetchable = set() + """Return ordered list of TopicPartitions that should be Fetched.""" + fetchable = list() for partition, state in six.iteritems(self.assignment): if state.is_fetchable(): - fetchable.add(partition) + fetchable.append(partition) return fetchable def partitions_auto_assigned(self): @@ -348,8 +356,8 @@ def pause(self, partition): def resume(self, partition): self.assignment[partition].resume() - def _add_assigned_partition(self, partition): - self.assignment[partition] = TopicPartitionState() + def move_partition_to_end(self, partition): + self.assignment.move_to_end(partition) class TopicPartitionState(object): From 3310475050f72d5b6a7ba06f3f9989cb1827fd45 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 24 Mar 2025 09:46:35 -0700 Subject: [PATCH 2/8] avoid KeyError in move_partition_to_end --- kafka/consumer/subscription_state.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kafka/consumer/subscription_state.py b/kafka/consumer/subscription_state.py index 47baed964..07a1a109d 100644 --- a/kafka/consumer/subscription_state.py +++ b/kafka/consumer/subscription_state.py @@ -357,7 +357,8 @@ def resume(self, partition): self.assignment[partition].resume() def move_partition_to_end(self, partition): - self.assignment.move_to_end(partition) + if partition in self.assignment: + self.assignment.move_to_end(partition) class TopicPartitionState(object): From a593a58bccaa1d5448e4fbf5fd7f1f5757577c8c Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 24 Mar 2025 09:46:59 -0700 Subject: [PATCH 3/8] Remove unused iterator_refetch_records config --- kafka/consumer/fetcher.py | 1 - 1 file changed, 1 deletion(-) diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py index 068ee8def..06c1472e5 100644 --- a/kafka/consumer/fetcher.py +++ b/kafka/consumer/fetcher.py @@ -57,7 +57,6 @@ class Fetcher(six.Iterator): 'max_partition_fetch_bytes': 1048576, 'max_poll_records': sys.maxsize, 'check_crcs': True, - 'iterator_refetch_records': 1, # undocumented -- interface may change 'metric_group_prefix': 'consumer', 'retry_backoff_ms': 100, 'enable_incremental_fetch_sessions': True, From 3875ce5f7c8ec70afb8e8e23ef0d18fb15c615c4 Mon Sep 17 
00:00:00 2001 From: Dana Powers Date: Mon, 24 Mar 2025 09:47:34 -0700 Subject: [PATCH 4/8] fix fetchable_partitions discard --- kafka/consumer/fetcher.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py index 06c1472e5..3fcbeec94 100644 --- a/kafka/consumer/fetcher.py +++ b/kafka/consumer/fetcher.py @@ -561,13 +561,11 @@ def _handle_list_offsets_response(self, future, response): def _fetchable_partitions(self): fetchable = self._subscriptions.fetchable_partitions() # do not fetch a partition if we have a pending fetch response to process + discard = {fetch.topic_partition for fetch in self._completed_fetches} current = self._next_partition_records - pending = copy.copy(self._completed_fetches) if current: - fetchable.discard(current.topic_partition) - for fetch in pending: - fetchable.discard(fetch.topic_partition) - return fetchable + discard.add(current.topic_partition) + return [tp for tp in fetchable if tp not in discard] def _create_fetch_requests(self): """Create fetch requests for all assigned partitions, grouped by node. From e19ca5b0ba906567533841ddb29eb8236763487d Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 24 Mar 2025 09:47:53 -0700 Subject: [PATCH 5/8] do not shuffle response partition order --- kafka/consumer/fetcher.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py index 3fcbeec94..62981e36c 100644 --- a/kafka/consumer/fetcher.py +++ b/kafka/consumer/fetcher.py @@ -4,7 +4,6 @@ import copy import itertools import logging -import random import sys import time @@ -693,7 +692,6 @@ def _handle_fetch_response(self, node_id, fetch_offsets, send_time, response): metric_aggregator = FetchResponseMetricAggregator(self._sensors, partitions) for topic, partitions in response.topics: - random.shuffle(partitions) for partition_data in partitions: tp = TopicPartition(topic, partition_data[0]) fetch_offset = fetch_offsets[tp] From aaf015ca0a6ee552a867fd1150d5180f462f764c Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 24 Mar 2025 09:48:32 -0700 Subject: [PATCH 6/8] move_partition_to_end on error or drain --- kafka/consumer/fetcher.py | 71 ++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py index 62981e36c..813bcb048 100644 --- a/kafka/consumer/fetcher.py +++ b/kafka/consumer/fetcher.py @@ -726,8 +726,6 @@ def _parse_fetched_data(self, completed_fetch): " since it is no longer fetchable", tp) elif error_type is Errors.NoError: - self._subscriptions.assignment[tp].highwater = highwater - # we are interested in this fetch only if the beginning # offset (of the *request*) matches the current consumed position # Note that the *response* may return a messageset that starts @@ -741,29 +739,35 @@ def _parse_fetched_data(self, completed_fetch): return None records = MemoryRecords(completed_fetch.partition_data[-1]) - if records.has_next(): - log.debug("Adding fetched record for partition %s with" - " offset %d to buffered record list", tp, - position.offset) - parsed_records = self.PartitionRecords(fetch_offset, tp, records, - self.config['key_deserializer'], - self.config['value_deserializer'], - self.config['check_crcs'], - completed_fetch.metric_aggregator) - elif records.size_in_bytes() > 0: - # we did not read a single message from a non-empty - # buffer because that message's size is larger than - # fetch size, in this case record this 
exception - record_too_large_partitions = {tp: fetch_offset} - raise RecordTooLargeError( - "There are some messages at [Partition=Offset]: %s " - " whose size is larger than the fetch size %s" - " and hence cannot be ever returned." - " Increase the fetch size, or decrease the maximum message" - " size the broker will allow." % ( - record_too_large_partitions, - self.config['max_partition_fetch_bytes']), - record_too_large_partitions) + log.debug("Preparing to read %s bytes of data for partition %s with offset %d", + records.size_in_bytes(), tp, fetch_offset) + parsed_records = self.PartitionRecords(fetch_offset, tp, records, + self.config['key_deserializer'], + self.config['value_deserializer'], + self.config['check_crcs'], + completed_fetch.metric_aggregator, + self._on_partition_records_drain) + if not records.has_next() and records.size_in_bytes() > 0: + if completed_fetch.response_version < 3: + # Implement the pre KIP-74 behavior of throwing a RecordTooLargeException. + record_too_large_partitions = {tp: fetch_offset} + raise RecordTooLargeError( + "There are some messages at [Partition=Offset]: %s " + " whose size is larger than the fetch size %s" + " and hence cannot be ever returned. Please condier upgrading your broker to 0.10.1.0 or" + " newer to avoid this issue. Alternatively, increase the fetch size on the client (using" + " max_partition_fetch_bytes)" % ( + record_too_large_partitions, + self.config['max_partition_fetch_bytes']), + record_too_large_partitions) + else: + # This should not happen with brokers that support FetchRequest/Response V3 or higher (i.e. KIP-74) + raise Errors.KafkaError("Failed to make progress reading messages at %s=%s." + " Received a non-empty fetch response from the server, but no" + " complete records were found." % (tp, fetch_offset)) + + if highwater >= 0: + self._subscriptions.assignment[tp].highwater = highwater elif error_type in (Errors.NotLeaderForPartitionError, Errors.ReplicaNotAvailableError, @@ -797,16 +801,25 @@ def _parse_fetched_data(self, completed_fetch): if parsed_records is None: completed_fetch.metric_aggregator.record(tp, 0, 0) - if parsed_records is None or parsed_records.bytes_read > 0: - self._subscriptions.move_partition_to_end(tp) + if error_type is not Errors.NoError: + # we move the partition to the end if there was an error. This way, it's more likely that partitions for + # the same topic can remain together (allowing for more efficient serialization). + self._subscriptions.move_partition_to_end(tp) + return parsed_records + def _on_partition_records_drain(self, partition_records): + # we move the partition to the end if we received some bytes. This way, it's more likely that partitions + # for the same topic can remain together (allowing for more efficient serialization). 
+ if partition_records.bytes_read > 0: + self._subscriptions.move_partition_to_end(partition_records.topic_partition) + def close(self): if self._next_partition_records is not None: self._next_partition_records.drain() class PartitionRecords(object): - def __init__(self, fetch_offset, tp, records, key_deserializer, value_deserializer, check_crcs, metric_aggregator): + def __init__(self, fetch_offset, tp, records, key_deserializer, value_deserializer, check_crcs, metric_aggregator, on_drain): self.fetch_offset = fetch_offset self.topic_partition = tp self.leader_epoch = -1 @@ -818,6 +831,7 @@ def __init__(self, fetch_offset, tp, records, key_deserializer, value_deserializ self.record_iterator = itertools.dropwhile( self._maybe_skip_record, self._unpack_records(tp, records, key_deserializer, value_deserializer)) + self.on_drain = on_drain def _maybe_skip_record(self, record): # When fetching an offset that is in the middle of a @@ -839,6 +853,7 @@ def drain(self): if self.record_iterator is not None: self.record_iterator = None self.metric_aggregator.record(self.topic_partition, self.bytes_read, self.records_read) + self.on_drain(self) def take(self, n=None): return list(itertools.islice(self.record_iterator, 0, n)) From 89c8dd3dfdabf20332d282a25d3a540eebe65ebe Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 24 Mar 2025 09:54:03 -0700 Subject: [PATCH 7/8] fetcher tests --- test/test_fetcher.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_fetcher.py b/test/test_fetcher.py index f6e1cf5f4..7822a6f1f 100644 --- a/test/test_fetcher.py +++ b/test/test_fetcher.py @@ -451,7 +451,7 @@ def test__unpack_records(mocker): (None, b"c", None), ] memory_records = MemoryRecords(_build_record_batch(messages)) - part_records = Fetcher.PartitionRecords(0, tp, memory_records, None, None, False, mocker.MagicMock()) + part_records = Fetcher.PartitionRecords(0, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None) records = list(part_records.record_iterator) assert len(records) == 3 assert all(map(lambda x: isinstance(x, ConsumerRecord), records)) @@ -556,7 +556,7 @@ def test_partition_records_offset(mocker): tp = TopicPartition('foo', 0) messages = [(None, b'msg', None) for i in range(batch_start, batch_end)] memory_records = MemoryRecords(_build_record_batch(messages, offset=batch_start)) - records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock()) + records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None) assert records assert records.next_fetch_offset == fetch_offset msgs = records.take(1) @@ -573,7 +573,7 @@ def test_partition_records_offset(mocker): def test_partition_records_empty(mocker): tp = TopicPartition('foo', 0) memory_records = MemoryRecords(_build_record_batch([])) - records = Fetcher.PartitionRecords(0, tp, memory_records, None, None, False, mocker.MagicMock()) + records = Fetcher.PartitionRecords(0, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None) msgs = records.take() assert len(msgs) == 0 assert not records @@ -586,7 +586,7 @@ def test_partition_records_no_fetch_offset(mocker): tp = TopicPartition('foo', 0) messages = [(None, b'msg', None) for i in range(batch_start, batch_end)] memory_records = MemoryRecords(_build_record_batch(messages, offset=batch_start)) - records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock()) + records = 
Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None) msgs = records.take() assert len(msgs) == 0 assert not records @@ -610,7 +610,7 @@ def test_partition_records_compacted_offset(mocker): builder.append(key=None, value=b'msg', timestamp=None, headers=[]) builder.close() memory_records = MemoryRecords(builder.buffer()) - records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock()) + records = Fetcher.PartitionRecords(fetch_offset, tp, memory_records, None, None, False, mocker.MagicMock(), lambda x: None) msgs = records.take() assert len(msgs) == batch_end - fetch_offset - 1 assert msgs[0].offset == fetch_offset + 1 From 2328cbdc7cf3ad0a01e788666d4bd63bacefb601 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 24 Mar 2025 11:14:30 -0700 Subject: [PATCH 8/8] Dont return empty record lists from fetcher.fetched_records --- kafka/consumer/fetcher.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py index 813bcb048..4d73ef435 100644 --- a/kafka/consumer/fetcher.py +++ b/kafka/consumer/fetcher.py @@ -378,10 +378,13 @@ def _append(self, drained, part, max_records, update_offsets): # as long as the partition is still assigned position = self._subscriptions.assignment[tp].position if part.next_fetch_offset == position.offset: - part_records = part.take(max_records) log.debug("Returning fetched records at offset %d for assigned" " partition %s", position.offset, tp) - drained[tp].extend(part_records) + part_records = part.take(max_records) + # list.extend([]) is a noop, but because drained is a defaultdict + # we should avoid initializing the default list unless there are records + if part_records: + drained[tp].extend(part_records) # We want to increment subscription position if (1) we're using consumer.poll(), # or (2) we didn't return any records (consumer iterator will update position # when each message is yielded). There may be edge cases where we re-fetch records
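For readers following the KIP-74 changes above, here is a minimal standalone sketch of the ordering behaviour the patches introduce: the consumer keeps its assignment in an OrderedDict so partitions are fetched in a stable order, and a partition is rotated to the end once its records are drained (or it hits an error), so no single partition can starve the others under fetch_max_bytes. This is an illustrative stand-in, not kafka-python API: OrderedAssignment, the plain (topic, partition) tuples, and the 'paused' flag are invented for the example; only the method names fetchable_partitions() and move_partition_to_end() mirror the SubscriptionState methods in the patch. It assumes Python 3, since OrderedDict.move_to_end is not available on Python 2.

    from collections import OrderedDict

    class OrderedAssignment(object):
        """Illustrative stand-in for SubscriptionState's ordered assignment."""

        def __init__(self, partitions):
            # Insertion order defines fetch priority (KIP-74).
            self._assignment = OrderedDict((tp, {'paused': False}) for tp in partitions)

        def fetchable_partitions(self):
            # Ordered list, analogous to SubscriptionState.fetchable_partitions()
            return [tp for tp, state in self._assignment.items() if not state['paused']]

        def move_partition_to_end(self, tp):
            # Guard against partitions revoked by a rebalance (avoids KeyError),
            # matching the check added in PATCH 2/8.
            if tp in self._assignment:
                self._assignment.move_to_end(tp)

    assignment = OrderedAssignment([('topic-a', 0), ('topic-a', 1), ('topic-b', 0)])
    # After records for ('topic-a', 0) are drained, rotate it behind the others.
    assignment.move_partition_to_end(('topic-a', 0))
    print(assignment.fetchable_partitions())
    # -> [('topic-a', 1), ('topic-b', 0), ('topic-a', 0)]

Because rotation happens per partition rather than by shuffling whole responses, partitions of the same topic tend to stay adjacent in the fetch order, which is why the patches can drop the earlier random.shuffle calls while still giving each partition a fair turn at the front of the request.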