Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion images/airflow/2.10.1/python/mwaa/celery/sqs_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,11 @@ def _size(self, queue):
resp = c.get_queue_attributes(
QueueUrl=url, AttributeNames=["ApproximateNumberOfMessages"]
)
return int(resp["Attributes"]["ApproximateNumberOfMessages"])
try:
return int(resp["Attributes"]["ApproximateNumberOfMessages"])
except Exception:
logger.error("Unexpected response from SQS get_queue_attributes: %s", resp)
raise

def _purge(self, queue):
"""Delete all current messages in a queue."""
Expand Down
6 changes: 5 additions & 1 deletion images/airflow/2.9.2/python/mwaa/celery/sqs_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,11 @@ def _size(self, queue):
resp = c.get_queue_attributes(
QueueUrl=url, AttributeNames=["ApproximateNumberOfMessages"]
)
return int(resp["Attributes"]["ApproximateNumberOfMessages"])
try:
return int(resp["Attributes"]["ApproximateNumberOfMessages"])
except Exception:
logger.error("Unexpected response from SQS get_queue_attributes: %s", resp)
raise

def _purge(self, queue):
"""Delete all current messages in a queue."""
Expand Down
33 changes: 33 additions & 0 deletions tests/images/airflow/2.10.1/celery/test_sqs_broker_2_10_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,36 @@ def test_worker_heartbeat_conditional_metrics(self, mock_stats, mock_channel, nu
else:
assert not any(call[0][0] == "mwaa.celery.sqs.consumption_paused"
for call in mock_stats.gauge.call_args_list)

def test_size_with_attributes(self, mock_channel):
"""Test _size returns message count when Attributes exist."""
queue = 'test-queue'
mock_sqs = MagicMock()
mock_sqs.get_queue_attributes.return_value = {
'Attributes': {'ApproximateNumberOfMessages': '5'}
}

with patch.object(mock_channel, '_new_queue', return_value='queue-url'), \
patch.object(mock_channel, 'canonical_queue_name', return_value=queue), \
patch.object(mock_channel, 'sqs', return_value=mock_sqs):
result = mock_channel._size(queue)

assert result == 5
mock_sqs.get_queue_attributes.assert_called_once_with(
QueueUrl='queue-url', AttributeNames=['ApproximateNumberOfMessages']
)

def test_size_without_attributes(self, mock_channel):
"""Test _size raises exception when Attributes missing."""
queue = 'test-queue'
mock_sqs = MagicMock()
mock_sqs.get_queue_attributes.return_value = {}

with patch.object(mock_channel, '_new_queue', return_value='queue-url'), \
patch.object(mock_channel, 'canonical_queue_name', return_value=queue), \
patch.object(mock_channel, 'sqs', return_value=mock_sqs), \
patch('mwaa.celery.sqs_broker.logger') as mock_logger:
with pytest.raises(KeyError):
mock_channel._size(queue)

mock_logger.error.assert_called_once()
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,8 @@ def test_size_with_attributes(self, mock_channel):
}

with patch.object(mock_channel, '_new_queue', return_value='queue-url'), \
patch.object(mock_channel, 'canonical_queue_name', return_value=queue), \
patch.object(mock_channel, 'sqs', return_value=mock_sqs):

patch.object(mock_channel, 'canonical_queue_name', return_value=queue), \
patch.object(mock_channel, 'sqs', return_value=mock_sqs):
result = mock_channel._size(queue)

assert result == 5
Expand All @@ -136,10 +135,9 @@ def test_size_without_attributes(self, mock_channel):
mock_sqs.get_queue_attributes.return_value = {}

with patch.object(mock_channel, '_new_queue', return_value='queue-url'), \
patch.object(mock_channel, 'canonical_queue_name', return_value=queue), \
patch.object(mock_channel, 'sqs', return_value=mock_sqs), \
patch('mwaa.celery.sqs_broker.logger') as mock_logger:

patch.object(mock_channel, 'canonical_queue_name', return_value=queue), \
patch.object(mock_channel, 'sqs', return_value=mock_sqs), \
patch('mwaa.celery.sqs_broker.logger') as mock_logger:
with pytest.raises(KeyError):
mock_channel._size(queue)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,12 @@ def test_size_with_attributes(self, mock_channel):
mock_sqs.get_queue_attributes.return_value = {
'Attributes': {'ApproximateNumberOfMessages': '5'}
}

with patch.object(mock_channel, '_new_queue', return_value='queue-url'), \
patch.object(mock_channel, 'canonical_queue_name', return_value=queue), \
patch.object(mock_channel, 'sqs', return_value=mock_sqs):

patch.object(mock_channel, 'canonical_queue_name', return_value=queue), \
patch.object(mock_channel, 'sqs', return_value=mock_sqs):
result = mock_channel._size(queue)

assert result == 5
mock_sqs.get_queue_attributes.assert_called_once_with(
QueueUrl='queue-url', AttributeNames=['ApproximateNumberOfMessages']
Expand All @@ -117,13 +116,89 @@ def test_size_without_attributes(self, mock_channel):
queue = 'test-queue'
mock_sqs = MagicMock()
mock_sqs.get_queue_attributes.return_value = {}

with patch.object(mock_channel, '_new_queue', return_value='queue-url'), \
patch.object(mock_channel, 'canonical_queue_name', return_value=queue), \
patch.object(mock_channel, 'sqs', return_value=mock_sqs), \
patch('mwaa.celery.sqs_broker.logger') as mock_logger:

patch.object(mock_channel, 'canonical_queue_name', return_value=queue), \
patch.object(mock_channel, 'sqs', return_value=mock_sqs), \
patch('mwaa.celery.sqs_broker.logger') as mock_logger:
with pytest.raises(KeyError):
mock_channel._size(queue)

mock_logger.error.assert_called_once()

mock_logger.error.assert_called_once()

def test_celery_worker_task_limit_constant_parsing(self):
"""Test CELERY_WORKER_TASK_LIMIT correctly parses environment variable."""
with patch('mwaa.celery.sqs_broker.os.environ.get') as mock_env_get:
mock_env_get.return_value = '20,10'

# Re-import to trigger constant re-evaluation
import importlib
import mwaa.celery.sqs_broker
importlib.reload(mwaa.celery.sqs_broker)

from mwaa.celery.sqs_broker import CELERY_WORKER_TASK_LIMIT
assert CELERY_WORKER_TASK_LIMIT == 20

@pytest.mark.parametrize("monitoring_enabled,expected_tasks", [
(True, 15), # When monitoring enabled, use actual task count
(False, 20), # When monitoring disabled, use CELERY_WORKER_TASK_LIMIT
])
@patch('mwaa.celery.sqs_broker.Stats')
def test_worker_heartbeat_active_tasks_calculation(self, mock_stats, mock_channel, monitoring_enabled,
expected_tasks):
"""Test num_active_tasks calculation logic."""
message = {'test': 'message'}
exchange = 'celeryev'
routing_key = 'worker.heartbeat'

mock_channel._get_tasks_from_state = MagicMock(return_value=[{'task': 'test'}] * 15)
mock_channel._is_task_consumption_paused = MagicMock(return_value=False)
mock_channel.idle_worker_monitoring_enabled = monitoring_enabled

with patch('mwaa.celery.sqs_broker.os.environ.get') as mock_env_get:
mock_env_get.return_value = 'true'

mock_channel.basic_publish(message, exchange, routing_key)

if expected_tasks >= 20: # CELERY_WORKER_TASK_LIMIT
mock_stats.gauge.assert_any_call("mwaa.celery.at_max_concurrency", expected_tasks)
else:
assert not any(call[0][0] == "mwaa.celery.at_max_concurrency"
for call in mock_stats.gauge.call_args_list)

@pytest.mark.parametrize("num_tasks,is_paused,expect_max_concurrency,expect_consumption_paused", [
(25, False, True, False), # Above limit, not paused
(15, True, False, True), # Below limit, paused
(20, False, True, False), # At limit, not paused
(10, False, False, False), # Below limit, not paused
])
@patch('mwaa.celery.sqs_broker.Stats')
def test_worker_heartbeat_conditional_metrics(self, mock_stats, mock_channel, num_tasks, is_paused,
expect_max_concurrency, expect_consumption_paused):
"""Test conditional metric emission logic."""
message = {'test': 'message'}
exchange = 'celeryev'
routing_key = 'worker.heartbeat'

mock_channel._get_tasks_from_state = MagicMock(return_value=[{'task': 'test'}] * num_tasks)
mock_channel._is_task_consumption_paused = MagicMock(return_value=is_paused)
mock_channel.idle_worker_monitoring_enabled = True

with patch('mwaa.celery.sqs_broker.os.environ.get') as mock_env_get:
mock_env_get.return_value = 'true'

mock_channel.basic_publish(message, exchange, routing_key)

mock_stats.gauge.assert_any_call("mwaa.celery.process.heartbeat", 1)

if expect_max_concurrency:
mock_stats.gauge.assert_any_call("mwaa.celery.at_max_concurrency", num_tasks)
else:
assert not any(call[0][0] == "mwaa.celery.at_max_concurrency"
for call in mock_stats.gauge.call_args_list)

if expect_consumption_paused:
mock_stats.gauge.assert_any_call("mwaa.celery.sqs.consumption_paused", 1)
else:
assert not any(call[0][0] == "mwaa.celery.sqs.consumption_paused"
for call in mock_stats.gauge.call_args_list)
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,36 @@ def test_worker_heartbeat_conditional_metrics(self, mock_stats, mock_channel, nu
else:
assert not any(call[0][0] == "mwaa.celery.sqs.consumption_paused"
for call in mock_stats.gauge.call_args_list)

def test_size_with_attributes(self, mock_channel):
"""Test _size returns message count when Attributes exist."""
queue = 'test-queue'
mock_sqs = MagicMock()
mock_sqs.get_queue_attributes.return_value = {
'Attributes': {'ApproximateNumberOfMessages': '5'}
}

with patch.object(mock_channel, '_new_queue', return_value='queue-url'), \
patch.object(mock_channel, 'canonical_queue_name', return_value=queue), \
patch.object(mock_channel, 'sqs', return_value=mock_sqs):
result = mock_channel._size(queue)

assert result == 5
mock_sqs.get_queue_attributes.assert_called_once_with(
QueueUrl='queue-url', AttributeNames=['ApproximateNumberOfMessages']
)

def test_size_without_attributes(self, mock_channel):
"""Test _size raises exception when Attributes missing."""
queue = 'test-queue'
mock_sqs = MagicMock()
mock_sqs.get_queue_attributes.return_value = {}

with patch.object(mock_channel, '_new_queue', return_value='queue-url'), \
patch.object(mock_channel, 'canonical_queue_name', return_value=queue), \
patch.object(mock_channel, 'sqs', return_value=mock_sqs), \
patch('mwaa.celery.sqs_broker.logger') as mock_logger:
with pytest.raises(KeyError):
mock_channel._size(queue)

mock_logger.error.assert_called_once()