KAFKA-4160: Ensure rebalance listener not called with coordinator lock #1438

Status: Merged (1 commit, Mar 13, 2025)
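Background for this change (KAFKA-4160): the rebalance listener runs arbitrary user code, and if it is invoked while the coordinator lock is held, anything else that needs that lock, in particular the heartbeat thread, is blocked for the duration, so a slow listener can cause the member to miss heartbeats and be evicted from the group. A toy sketch of that failure mode, using plain threading primitives rather than any kafka-python classes:

# Toy illustration of the problem this PR avoids: a slow user callback
# executed while holding a shared lock starves a heartbeat thread that
# needs the same lock.
import threading
import time

lock = threading.RLock()
heartbeats_sent = 0
stop = threading.Event()

def heartbeat_loop():
    global heartbeats_sent
    while not stop.is_set():
        with lock:                     # needs the same lock the listener call holds
            heartbeats_sent += 1
        time.sleep(0.1)

def slow_rebalance_listener():
    time.sleep(2)                      # e.g. the user commits offsets or flushes state

threading.Thread(target=heartbeat_loop, daemon=True).start()
time.sleep(0.3)                        # let a few heartbeats go out first

before = heartbeats_sent
with lock:                             # lock held across the listener call
    slow_rebalance_listener()          # heartbeat loop is blocked the whole time
stop.set()
print('heartbeats during the 2s listener call:', heartbeats_sent - before)  # 1 or 2 instead of ~20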
214 changes: 116 additions & 98 deletions kafka/coordinator/base.py
@@ -338,7 +338,36 @@ def time_to_next_heartbeat(self):
return float('inf')
return self.heartbeat.time_to_next_heartbeat()

def _reset_join_group_future(self):
with self._lock:
self.join_future = None

def _initiate_join_group(self):
with self._lock:
# we store the join future in case we are woken up by the user
# after beginning the rebalance in the call to poll below.
# This ensures that we do not mistakenly attempt to rejoin
# before the pending rebalance has completed.
if self.join_future is None:
self.state = MemberState.REBALANCING
self.join_future = self._send_join_group_request()

# handle join completion in the callback so that the
# callback will be invoked even if the consumer is woken up
# before finishing the rebalance
self.join_future.add_callback(self._handle_join_success)

# we handle failures below after the request finishes.
# If the join completes after having been woken up, the
# exception is ignored and we will rejoin
self.join_future.add_errback(self._handle_join_failure)

return self.join_future

def _handle_join_success(self, member_assignment_bytes):
# handle join completion in the callback so that the callback
# will be invoked even if the consumer is woken up before
# finishing the rebalance
with self._lock:
log.info("Successfully joined group %s with generation %s",
self.group_id, self._generation.generation_id)
@@ -347,6 +376,9 @@ def _handle_join_success(self, member_assignment_bytes):
self._heartbeat_thread.enable()

def _handle_join_failure(self, _):
# we handle failures below after the request finishes.
# if the join completes after having been woken up,
# the exception is ignored and we will rejoin
with self._lock:
self.state = MemberState.UNJOINED
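The refactored join path hands the JoinGroup outcome to _handle_join_success and _handle_join_failure through callbacks registered on the request future, so the result is applied even if the consumer is woken up before poll finishes waiting on it. A minimal sketch of such a callback-style future, a simplified stand-in rather than the actual kafka.future.Future implementation:

# Simplified stand-in for the callback-style future used above; callbacks
# and errbacks fire when the result arrives, or immediately if it is
# already available.
class MiniFuture(object):
    def __init__(self):
        self.is_done = False
        self.value = None
        self.exception = None
        self._callbacks = []
        self._errbacks = []

    def add_callback(self, f):
        if self.is_done and self.exception is None:
            f(self.value)
        else:
            self._callbacks.append(f)

    def add_errback(self, f):
        if self.is_done and self.exception is not None:
            f(self.exception)
        else:
            self._errbacks.append(f)

    def success(self, value):
        self.is_done, self.value = True, value
        for f in self._callbacks:
            f(value)

    def failure(self, exception):
        self.is_done, self.exception = True, exception
        for f in self._errbacks:
            f(exception)

    def succeeded(self):
        return self.is_done and self.exception is None

# Usage mirroring _initiate_join_group: the handlers run whenever the
# response lands, even if the caller has stopped waiting on the future.
future = MiniFuture()
future.add_callback(lambda assignment: print('joined, assignment:', assignment))
future.add_errback(lambda exc: print('join failed:', exc))
future.success(b'member-assignment-bytes')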

@@ -360,92 +392,67 @@ def ensure_active_group(self, timeout_ms=None):
Raises: KafkaTimeoutError if timeout_ms is not None
"""
inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout attempting to join consumer group')
with self._client._lock, self._lock:
if self._heartbeat_thread is None:
self._start_heartbeat_thread()

while self.need_rejoin() or self._rejoin_incomplete():
self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())

# call on_join_prepare if needed. We set a flag
# to make sure that we do not call it a second
# time if the client is woken up before a pending
# rebalance completes. This must be called on each
# iteration of the loop because an event requiring
# a rebalance (such as a metadata refresh which
# changes the matched subscription set) can occur
# while another rebalance is still in progress.
if not self.rejoining:
self._on_join_prepare(self._generation.generation_id,
self._generation.member_id)
self.rejoining = True

# ensure that there are no pending requests to the coordinator.
# This is important in particular to avoid resending a pending
# JoinGroup request.
while not self.coordinator_unknown():
if not self._client.in_flight_request_count(self.coordinator_id):
break
self._client.poll(timeout_ms=inner_timeout_ms(200))
else:
continue

# we store the join future in case we are woken up by the user
# after beginning the rebalance in the call to poll below.
# This ensures that we do not mistakenly attempt to rejoin
# before the pending rebalance has completed.
if self.join_future is None:
# Fence off the heartbeat thread explicitly so that it cannot
# interfere with the join group. Note that this must come after
# the call to _on_join_prepare since we must be able to continue
# sending heartbeats if that callback takes some time.
self._heartbeat_thread.disable()

self.state = MemberState.REBALANCING
future = self._send_join_group_request()

self.join_future = future # this should happen before adding callbacks

# handle join completion in the callback so that the
# callback will be invoked even if the consumer is woken up
# before finishing the rebalance
future.add_callback(self._handle_join_success)

# we handle failures below after the request finishes.
# If the join completes after having been woken up, the
# exception is ignored and we will rejoin
future.add_errback(self._handle_join_failure)

else:
future = self.join_future

self._client.poll(future=future, timeout_ms=inner_timeout_ms())

if not future.is_done:
raise Errors.KafkaTimeoutError()
self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())
self._start_heartbeat_thread()
self.join_group(timeout_ms=inner_timeout_ms())

if future.succeeded():
self._on_join_complete(self._generation.generation_id,
self._generation.member_id,
self._generation.protocol,
future.value)
self.join_future = None
self.rejoining = False
self.rejoin_needed = False
def join_group(self, timeout_ms=None):
inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout attempting to join consumer group')
while self.need_rejoin():
self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())

# call on_join_prepare if needed. We set a flag
# to make sure that we do not call it a second
# time if the client is woken up before a pending
# rebalance completes. This must be called on each
# iteration of the loop because an event requiring
# a rebalance (such as a metadata refresh which
# changes the matched subscription set) can occur
# while another rebalance is still in progress.
if not self.rejoining:
self._on_join_prepare(self._generation.generation_id,
self._generation.member_id)
self.rejoining = True

# fence off the heartbeat thread explicitly so that it cannot
# interfere with the join group. Note that this must come after
# the call to _on_join_prepare since we must be able to continue
# sending heartbeats if that callback takes some time.
self._disable_heartbeat_thread()

# ensure that there are no pending requests to the coordinator.
# This is important in particular to avoid resending a pending
# JoinGroup request.
while not self.coordinator_unknown():
if not self._client.in_flight_request_count(self.coordinator_id):
break
self._client.poll(timeout_ms=inner_timeout_ms(200))
else:
continue

else:
self.join_future = None
exception = future.exception
if isinstance(exception, (Errors.UnknownMemberIdError,
Errors.RebalanceInProgressError,
Errors.IllegalGenerationError)):
continue
elif not future.retriable():
raise exception # pylint: disable-msg=raising-bad-type
time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)

def _rejoin_incomplete(self):
return self.join_future is not None
future = self._initiate_join_group()
self._client.poll(future=future, timeout_ms=inner_timeout_ms())
if future.is_done:
self._reset_join_group_future()
else:
raise Errors.KafkaTimeoutError()

if future.succeeded():
self.rejoining = False
self.rejoin_needed = False
self._on_join_complete(self._generation.generation_id,
self._generation.member_id,
self._generation.protocol,
future.value)
else:
exception = future.exception
if isinstance(exception, (Errors.UnknownMemberIdError,
Errors.RebalanceInProgressError,
Errors.IllegalGenerationError)):
continue
elif not future.retriable():
raise exception # pylint: disable-msg=raising-bad-type
time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)
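Note that on a timeout join_group leaves self.join_future in place (it only calls _reset_join_group_future once the future is done), so the next call resumes waiting on the same pending JoinGroup request instead of sending a new one. A hypothetical caller-side helper built on that behavior (the function name and retry policy are illustrative, not part of this diff):

from kafka.errors import KafkaTimeoutError

def wait_for_active_group(coordinator, attempts=3, timeout_ms=5000):
    # Retry a bounded number of times; a timed-out attempt keeps the
    # pending join future, so the next attempt picks it up where it left off.
    for _ in range(attempts):
        try:
            coordinator.ensure_active_group(timeout_ms=timeout_ms)
            return True
        except KafkaTimeoutError:
            continue
    return False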

def _send_join_group_request(self):
"""Join the group and return the assignment for the next generation.
@@ -751,23 +758,31 @@ def request_rejoin(self):
self.rejoin_needed = True

def _start_heartbeat_thread(self):
if self._heartbeat_thread is None:
log.info('Starting new heartbeat thread')
self._heartbeat_thread = HeartbeatThread(weakref.proxy(self))
self._heartbeat_thread.daemon = True
self._heartbeat_thread.start()
with self._lock:
if self._heartbeat_thread is None:
log.info('Starting new heartbeat thread')
self._heartbeat_thread = HeartbeatThread(weakref.proxy(self))
self._heartbeat_thread.daemon = True
self._heartbeat_thread.start()

def _disable_heartbeat_thread(self):
with self._lock:
if self._heartbeat_thread is not None:
self._heartbeat_thread.disable()

def _close_heartbeat_thread(self):
if hasattr(self, '_heartbeat_thread') and self._heartbeat_thread is not None:
log.info('Stopping heartbeat thread')
try:
self._heartbeat_thread.close()
except ReferenceError:
pass
self._heartbeat_thread = None
with self._lock:
if self._heartbeat_thread is not None:
log.info('Stopping heartbeat thread')
try:
self._heartbeat_thread.close()
except ReferenceError:
pass
self._heartbeat_thread = None

def __del__(self):
self._close_heartbeat_thread()
if hasattr(self, '_heartbeat_thread'):
self._close_heartbeat_thread()
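The hasattr check added to __del__ matters because a finalizer can run even when __init__ raised before _heartbeat_thread was ever assigned. A small standalone example of that failure mode, unrelated to the kafka-python classes:

# If __init__ fails before the attribute is set, __del__ still runs and
# would otherwise raise AttributeError (reported as an ignored exception).
class Fragile(object):
    def __init__(self):
        raise RuntimeError('failed before _heartbeat_thread was assigned')

    def __del__(self):
        if hasattr(self, '_heartbeat_thread'):
            print('cleaning up heartbeat thread')
        else:
            print('nothing to clean up')   # taken when __init__ failed early

try:
    Fragile()
except RuntimeError:
    pass   # the finalizer prints 'nothing to clean up' instead of raising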

def close(self):
"""Close the coordinator, leave the current group,
@@ -926,12 +941,15 @@ def __init__(self, coordinator):

def enable(self):
with self.coordinator._lock:
log.debug('Enabling heartbeat thread')
self.enabled = True
self.coordinator.heartbeat.reset_timeouts()
self.coordinator._lock.notify()

def disable(self):
self.enabled = False
with self.coordinator._lock:
log.debug('Disabling heartbeat thread')
self.enabled = False
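With this change enable() logs, resets the heartbeat timeouts, and notifies the coordinator lock so a parked heartbeat thread wakes up immediately instead of sleeping out its interval, and disable() now flips the flag under that same lock. A condensed sketch of the enable/disable pattern with a plain threading.Condition (illustrative, not the HeartbeatThread implementation):

import threading

class PausableWorker(threading.Thread):
    # A worker that parks while disabled and is woken promptly by notify().
    def __init__(self):
        super(PausableWorker, self).__init__(daemon=True)
        self.cond = threading.Condition()
        self.enabled = False
        self.closed = False

    def enable(self):
        with self.cond:
            self.enabled = True
            self.cond.notify()          # wake the worker right away

    def disable(self):
        with self.cond:
            self.enabled = False        # worker parks on its next iteration

    def close(self):
        with self.cond:
            self.closed = True
            self.cond.notify()

    def run(self):
        with self.cond:
            while not self.closed:
                if not self.enabled:
                    self.cond.wait()            # parked until enable() or close()
                    continue
                # ... do one unit of work (e.g. send a heartbeat) ...
                self.cond.wait(timeout=0.5)     # then sleep for the interval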

def close(self):
if self.closed: