@@ -100,13 +100,12 @@ def test_size_with_attributes(self, mock_channel):
100
100
mock_sqs .get_queue_attributes .return_value = {
101
101
'Attributes' : {'ApproximateNumberOfMessages' : '5' }
102
102
}
103
-
103
+
104
104
with patch .object (mock_channel , '_new_queue' , return_value = 'queue-url' ), \
105
- patch .object (mock_channel , 'canonical_queue_name' , return_value = queue ), \
106
- patch .object (mock_channel , 'sqs' , return_value = mock_sqs ):
107
-
105
+ patch .object (mock_channel , 'canonical_queue_name' , return_value = queue ), \
106
+ patch .object (mock_channel , 'sqs' , return_value = mock_sqs ):
108
107
result = mock_channel ._size (queue )
109
-
108
+
110
109
assert result == 5
111
110
mock_sqs .get_queue_attributes .assert_called_once_with (
112
111
QueueUrl = 'queue-url' , AttributeNames = ['ApproximateNumberOfMessages' ]
@@ -117,13 +116,89 @@ def test_size_without_attributes(self, mock_channel):
117
116
queue = 'test-queue'
118
117
mock_sqs = MagicMock ()
119
118
mock_sqs .get_queue_attributes .return_value = {}
120
-
119
+
121
120
with patch .object (mock_channel , '_new_queue' , return_value = 'queue-url' ), \
122
- patch .object (mock_channel , 'canonical_queue_name' , return_value = queue ), \
123
- patch .object (mock_channel , 'sqs' , return_value = mock_sqs ), \
124
- patch ('mwaa.celery.sqs_broker.logger' ) as mock_logger :
125
-
121
+ patch .object (mock_channel , 'canonical_queue_name' , return_value = queue ), \
122
+ patch .object (mock_channel , 'sqs' , return_value = mock_sqs ), \
123
+ patch ('mwaa.celery.sqs_broker.logger' ) as mock_logger :
126
124
with pytest .raises (KeyError ):
127
125
mock_channel ._size (queue )
128
-
129
- mock_logger .error .assert_called_once ()
126
+
127
+ mock_logger .error .assert_called_once ()
128
+
129
+ def test_celery_worker_task_limit_constant_parsing (self ):
130
+ """Test CELERY_WORKER_TASK_LIMIT correctly parses environment variable."""
131
+ with patch ('mwaa.celery.sqs_broker.os.environ.get' ) as mock_env_get :
132
+ mock_env_get .return_value = '20,10'
133
+
134
+ # Re-import to trigger constant re-evaluation
135
+ import importlib
136
+ import mwaa .celery .sqs_broker
137
+ importlib .reload (mwaa .celery .sqs_broker )
138
+
139
+ from mwaa .celery .sqs_broker import CELERY_WORKER_TASK_LIMIT
140
+ assert CELERY_WORKER_TASK_LIMIT == 20
141
+
142
+ @pytest .mark .parametrize ("monitoring_enabled,expected_tasks" , [
143
+ (True , 15 ), # When monitoring enabled, use actual task count
144
+ (False , 20 ), # When monitoring disabled, use CELERY_WORKER_TASK_LIMIT
145
+ ])
146
+ @patch ('mwaa.celery.sqs_broker.Stats' )
147
+ def test_worker_heartbeat_active_tasks_calculation (self , mock_stats , mock_channel , monitoring_enabled ,
148
+ expected_tasks ):
149
+ """Test num_active_tasks calculation logic."""
150
+ message = {'test' : 'message' }
151
+ exchange = 'celeryev'
152
+ routing_key = 'worker.heartbeat'
153
+
154
+ mock_channel ._get_tasks_from_state = MagicMock (return_value = [{'task' : 'test' }] * 15 )
155
+ mock_channel ._is_task_consumption_paused = MagicMock (return_value = False )
156
+ mock_channel .idle_worker_monitoring_enabled = monitoring_enabled
157
+
158
+ with patch ('mwaa.celery.sqs_broker.os.environ.get' ) as mock_env_get :
159
+ mock_env_get .return_value = 'true'
160
+
161
+ mock_channel .basic_publish (message , exchange , routing_key )
162
+
163
+ if expected_tasks >= 20 : # CELERY_WORKER_TASK_LIMIT
164
+ mock_stats .gauge .assert_any_call ("mwaa.celery.at_max_concurrency" , expected_tasks )
165
+ else :
166
+ assert not any (call [0 ][0 ] == "mwaa.celery.at_max_concurrency"
167
+ for call in mock_stats .gauge .call_args_list )
168
+
169
+ @pytest .mark .parametrize ("num_tasks,is_paused,expect_max_concurrency,expect_consumption_paused" , [
170
+ (25 , False , True , False ), # Above limit, not paused
171
+ (15 , True , False , True ), # Below limit, paused
172
+ (20 , False , True , False ), # At limit, not paused
173
+ (10 , False , False , False ), # Below limit, not paused
174
+ ])
175
+ @patch ('mwaa.celery.sqs_broker.Stats' )
176
+ def test_worker_heartbeat_conditional_metrics (self , mock_stats , mock_channel , num_tasks , is_paused ,
177
+ expect_max_concurrency , expect_consumption_paused ):
178
+ """Test conditional metric emission logic."""
179
+ message = {'test' : 'message' }
180
+ exchange = 'celeryev'
181
+ routing_key = 'worker.heartbeat'
182
+
183
+ mock_channel ._get_tasks_from_state = MagicMock (return_value = [{'task' : 'test' }] * num_tasks )
184
+ mock_channel ._is_task_consumption_paused = MagicMock (return_value = is_paused )
185
+ mock_channel .idle_worker_monitoring_enabled = True
186
+
187
+ with patch ('mwaa.celery.sqs_broker.os.environ.get' ) as mock_env_get :
188
+ mock_env_get .return_value = 'true'
189
+
190
+ mock_channel .basic_publish (message , exchange , routing_key )
191
+
192
+ mock_stats .gauge .assert_any_call ("mwaa.celery.process.heartbeat" , 1 )
193
+
194
+ if expect_max_concurrency :
195
+ mock_stats .gauge .assert_any_call ("mwaa.celery.at_max_concurrency" , num_tasks )
196
+ else :
197
+ assert not any (call [0 ][0 ] == "mwaa.celery.at_max_concurrency"
198
+ for call in mock_stats .gauge .call_args_list )
199
+
200
+ if expect_consumption_paused :
201
+ mock_stats .gauge .assert_any_call ("mwaa.celery.sqs.consumption_paused" , 1 )
202
+ else :
203
+ assert not any (call [0 ][0 ] == "mwaa.celery.sqs.consumption_paused"
204
+ for call in mock_stats .gauge .call_args_list )
0 commit comments