Skip to content

Commit 1c87c82

Browse files
authored
KAFKA-4160: Ensure rebalance listener not called with coordinator lock (#1438)
1 parent 215b626 commit 1c87c82

File tree

1 file changed

+116
-98
lines changed

1 file changed

+116
-98
lines changed

kafka/coordinator/base.py

+116-98
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,36 @@ def time_to_next_heartbeat(self):
338338
return float('inf')
339339
return self.heartbeat.time_to_next_heartbeat()
340340

341+
def _reset_join_group_future(self):
342+
with self._lock:
343+
self.join_future = None
344+
345+
def _initiate_join_group(self):
346+
with self._lock:
347+
# we store the join future in case we are woken up by the user
348+
# after beginning the rebalance in the call to poll below.
349+
# This ensures that we do not mistakenly attempt to rejoin
350+
# before the pending rebalance has completed.
351+
if self.join_future is None:
352+
self.state = MemberState.REBALANCING
353+
self.join_future = self._send_join_group_request()
354+
355+
# handle join completion in the callback so that the
356+
# callback will be invoked even if the consumer is woken up
357+
# before finishing the rebalance
358+
self.join_future.add_callback(self._handle_join_success)
359+
360+
# we handle failures below after the request finishes.
361+
# If the join completes after having been woken up, the
362+
# exception is ignored and we will rejoin
363+
self.join_future.add_errback(self._handle_join_failure)
364+
365+
return self.join_future
366+
341367
def _handle_join_success(self, member_assignment_bytes):
368+
# handle join completion in the callback so that the callback
369+
# will be invoked even if the consumer is woken up before
370+
# finishing the rebalance
342371
with self._lock:
343372
log.info("Successfully joined group %s with generation %s",
344373
self.group_id, self._generation.generation_id)
@@ -347,6 +376,9 @@ def _handle_join_success(self, member_assignment_bytes):
347376
self._heartbeat_thread.enable()
348377

349378
def _handle_join_failure(self, _):
379+
# we handle failures below after the request finishes.
380+
# if the join completes after having been woken up,
381+
# the exception is ignored and we will rejoin
350382
with self._lock:
351383
self.state = MemberState.UNJOINED
352384

@@ -360,92 +392,67 @@ def ensure_active_group(self, timeout_ms=None):
360392
Raises: KafkaTimeoutError if timeout_ms is not None
361393
"""
362394
inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout attempting to join consumer group')
363-
with self._client._lock, self._lock:
364-
if self._heartbeat_thread is None:
365-
self._start_heartbeat_thread()
366-
367-
while self.need_rejoin() or self._rejoin_incomplete():
368-
self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())
369-
370-
# call on_join_prepare if needed. We set a flag
371-
# to make sure that we do not call it a second
372-
# time if the client is woken up before a pending
373-
# rebalance completes. This must be called on each
374-
# iteration of the loop because an event requiring
375-
# a rebalance (such as a metadata refresh which
376-
# changes the matched subscription set) can occur
377-
# while another rebalance is still in progress.
378-
if not self.rejoining:
379-
self._on_join_prepare(self._generation.generation_id,
380-
self._generation.member_id)
381-
self.rejoining = True
382-
383-
# ensure that there are no pending requests to the coordinator.
384-
# This is important in particular to avoid resending a pending
385-
# JoinGroup request.
386-
while not self.coordinator_unknown():
387-
if not self._client.in_flight_request_count(self.coordinator_id):
388-
break
389-
self._client.poll(timeout_ms=inner_timeout_ms(200))
390-
else:
391-
continue
392-
393-
# we store the join future in case we are woken up by the user
394-
# after beginning the rebalance in the call to poll below.
395-
# This ensures that we do not mistakenly attempt to rejoin
396-
# before the pending rebalance has completed.
397-
if self.join_future is None:
398-
# Fence off the heartbeat thread explicitly so that it cannot
399-
# interfere with the join group. Note that this must come after
400-
# the call to _on_join_prepare since we must be able to continue
401-
# sending heartbeats if that callback takes some time.
402-
self._heartbeat_thread.disable()
403-
404-
self.state = MemberState.REBALANCING
405-
future = self._send_join_group_request()
406-
407-
self.join_future = future # this should happen before adding callbacks
408-
409-
# handle join completion in the callback so that the
410-
# callback will be invoked even if the consumer is woken up
411-
# before finishing the rebalance
412-
future.add_callback(self._handle_join_success)
413-
414-
# we handle failures below after the request finishes.
415-
# If the join completes after having been woken up, the
416-
# exception is ignored and we will rejoin
417-
future.add_errback(self._handle_join_failure)
418-
419-
else:
420-
future = self.join_future
421-
422-
self._client.poll(future=future, timeout_ms=inner_timeout_ms())
423-
424-
if not future.is_done:
425-
raise Errors.KafkaTimeoutError()
395+
self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())
396+
self._start_heartbeat_thread()
397+
self.join_group(timeout_ms=inner_timeout_ms())
426398

427-
if future.succeeded():
428-
self._on_join_complete(self._generation.generation_id,
429-
self._generation.member_id,
430-
self._generation.protocol,
431-
future.value)
432-
self.join_future = None
433-
self.rejoining = False
434-
self.rejoin_needed = False
399+
def join_group(self, timeout_ms=None):
400+
inner_timeout_ms = timeout_ms_fn(timeout_ms, 'Timeout attempting to join consumer group')
401+
while self.need_rejoin():
402+
self.ensure_coordinator_ready(timeout_ms=inner_timeout_ms())
403+
404+
# call on_join_prepare if needed. We set a flag
405+
# to make sure that we do not call it a second
406+
# time if the client is woken up before a pending
407+
# rebalance completes. This must be called on each
408+
# iteration of the loop because an event requiring
409+
# a rebalance (such as a metadata refresh which
410+
# changes the matched subscription set) can occur
411+
# while another rebalance is still in progress.
412+
if not self.rejoining:
413+
self._on_join_prepare(self._generation.generation_id,
414+
self._generation.member_id)
415+
self.rejoining = True
416+
417+
# fence off the heartbeat thread explicitly so that it cannot
418+
# interfere with the join group. # Note that this must come after
419+
# the call to onJoinPrepare since we must be able to continue
420+
# sending heartbeats if that callback takes some time.
421+
self._disable_heartbeat_thread()
422+
423+
# ensure that there are no pending requests to the coordinator.
424+
# This is important in particular to avoid resending a pending
425+
# JoinGroup request.
426+
while not self.coordinator_unknown():
427+
if not self._client.in_flight_request_count(self.coordinator_id):
428+
break
429+
self._client.poll(timeout_ms=inner_timeout_ms(200))
430+
else:
431+
continue
435432

436-
else:
437-
self.join_future = None
438-
exception = future.exception
439-
if isinstance(exception, (Errors.UnknownMemberIdError,
440-
Errors.RebalanceInProgressError,
441-
Errors.IllegalGenerationError)):
442-
continue
443-
elif not future.retriable():
444-
raise exception # pylint: disable-msg=raising-bad-type
445-
time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)
446-
447-
def _rejoin_incomplete(self):
448-
return self.join_future is not None
433+
future = self._initiate_join_group()
434+
self._client.poll(future=future, timeout_ms=inner_timeout_ms())
435+
if future.is_done:
436+
self._reset_join_group_future()
437+
else:
438+
raise Errors.KafkaTimeoutError()
439+
440+
if future.succeeded():
441+
self.rejoining = False
442+
self.rejoin_needed = False
443+
self._on_join_complete(self._generation.generation_id,
444+
self._generation.member_id,
445+
self._generation.protocol,
446+
future.value)
447+
else:
448+
exception = future.exception
449+
if isinstance(exception, (Errors.UnknownMemberIdError,
450+
Errors.RebalanceInProgressError,
451+
Errors.IllegalGenerationError)):
452+
continue
453+
elif not future.retriable():
454+
raise exception # pylint: disable-msg=raising-bad-type
455+
time.sleep(inner_timeout_ms(self.config['retry_backoff_ms']) / 1000)
449456

450457
def _send_join_group_request(self):
451458
"""Join the group and return the assignment for the next generation.
@@ -751,23 +758,31 @@ def request_rejoin(self):
751758
self.rejoin_needed = True
752759

753760
def _start_heartbeat_thread(self):
754-
if self._heartbeat_thread is None:
755-
log.info('Starting new heartbeat thread')
756-
self._heartbeat_thread = HeartbeatThread(weakref.proxy(self))
757-
self._heartbeat_thread.daemon = True
758-
self._heartbeat_thread.start()
761+
with self._lock:
762+
if self._heartbeat_thread is None:
763+
log.info('Starting new heartbeat thread')
764+
self._heartbeat_thread = HeartbeatThread(weakref.proxy(self))
765+
self._heartbeat_thread.daemon = True
766+
self._heartbeat_thread.start()
767+
768+
def _disable_heartbeat_thread(self):
769+
with self._lock:
770+
if self._heartbeat_thread is not None:
771+
self._heartbeat_thread.disable()
759772

760773
def _close_heartbeat_thread(self):
761-
if hasattr(self, '_heartbeat_thread') and self._heartbeat_thread is not None:
762-
log.info('Stopping heartbeat thread')
763-
try:
764-
self._heartbeat_thread.close()
765-
except ReferenceError:
766-
pass
767-
self._heartbeat_thread = None
774+
with self._lock:
775+
if self._heartbeat_thread is not None:
776+
log.info('Stopping heartbeat thread')
777+
try:
778+
self._heartbeat_thread.close()
779+
except ReferenceError:
780+
pass
781+
self._heartbeat_thread = None
768782

769783
def __del__(self):
770-
self._close_heartbeat_thread()
784+
if hasattr(self, '_heartbeat_thread'):
785+
self._close_heartbeat_thread()
771786

772787
def close(self):
773788
"""Close the coordinator, leave the current group,
@@ -926,12 +941,15 @@ def __init__(self, coordinator):
926941

927942
def enable(self):
928943
with self.coordinator._lock:
944+
log.debug('Enabling heartbeat thread')
929945
self.enabled = True
930946
self.coordinator.heartbeat.reset_timeouts()
931947
self.coordinator._lock.notify()
932948

933949
def disable(self):
934-
self.enabled = False
950+
with self.coordinator._lock:
951+
log.debug('Disabling heartbeat thread')
952+
self.enabled = False
935953

936954
def close(self):
937955
if self.closed:

0 commit comments

Comments
 (0)