Skip to content

Commit ba879d2

Browse files
dheerajturaga authored and dominikhei committed
Centralized runtime control of Edge Worker concurrency in distributed deployments (apache#62896)
1 parent c6e3482 commit ba879d2

File tree

13 files changed

+367
-6
lines changed

13 files changed

+367
-6
lines changed

providers/edge3/src/airflow/providers/edge3/cli/definition.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@
5959
help="Comma delimited list of queues to add or remove.",
6060
required=True,
6161
)
62+
def _positive_int(value: str) -> int:
    """
    Parse a command line value as a strictly positive integer.

    Used as an argparse ``type`` callable so that zero, negative or
    non-numeric input is rejected while the command line is parsed.

    :param value: Raw string received from the command line.
    :return: The parsed integer, always greater than zero.
    :raises argparse.ArgumentTypeError: If ``value`` is not an integer or is not positive.
    """
    # Imported locally so the helper does not rely on a module-level import.
    import argparse

    try:
        parsed = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid positive int value: {value!r}") from None
    if parsed <= 0:
        raise argparse.ArgumentTypeError(f"invalid positive int value: {value!r}")
    return parsed


# Required concurrency argument. The help text promises a positive integer,
# so the parser enforces it instead of accepting any int.
ARG_CONCURRENCY_REQUIRED = Arg(
    ("-c", "--concurrency"),
    type=_positive_int,
    help="The number of worker processes. Must be a positive integer.",
    required=True,
)
6268
ARG_WAIT_MAINT = Arg(
6369
("-w", "--wait"),
6470
default=False,
@@ -229,6 +235,15 @@
229235
func=lazy_load_command("airflow.providers.edge3.cli.edge_command.shutdown_all_workers"),
230236
args=(ARG_YES,),
231237
),
238+
ActionCommand(
239+
name="set-worker-concurrency",
240+
help="Set the concurrency of a remote edge worker.",
241+
func=lazy_load_command("airflow.providers.edge3.cli.edge_command.set_remote_worker_concurrency"),
242+
args=(
243+
ARG_REQUIRED_EDGE_HOSTNAME,
244+
ARG_CONCURRENCY_REQUIRED,
245+
),
246+
),
232247
]
233248

234249

providers/edge3/src/airflow/providers/edge3/cli/edge_command.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,3 +427,27 @@ def remove_worker_queues(args) -> None:
427427
except TypeError as e:
428428
logger.error(str(e))
429429
raise SystemExit
430+
431+
432+
@cli_utils.action_cli(check_db=False)
@providers_configuration_loaded
def set_remote_worker_concurrency(args) -> None:
    """
    Set the concurrency of a remote edge worker.

    :param args: Parsed CLI namespace; ``edge_hostname`` and ``concurrency`` are read.
    :raises SystemExit: With a non-zero status if the requested concurrency is not
        positive, or if the worker cannot be updated.
    """
    # Validate the argument first so bad input is rejected before any database access.
    if args.concurrency <= 0:
        raise SystemExit("Error: Concurrency must be a positive integer.")

    _check_valid_db_connection()
    _check_if_registered_edge_host(hostname=args.edge_hostname)
    from airflow.providers.edge3.models.edge_worker import set_worker_concurrency

    try:
        set_worker_concurrency(args.edge_hostname, args.concurrency)
    except (TypeError, ValueError) as e:
        # The model function raises TypeError when the worker state does not allow
        # the change and ValueError when the worker is not found; both are reported
        # as a CLI error rather than a traceback.
        logger.error(str(e))
        # An explicit status is needed: a bare SystemExit exits with status 0.
        raise SystemExit(1)

    logger.info(
        "Concurrency set to %d for Edge Worker host %s by %s.",
        args.concurrency,
        args.edge_hostname,
        getuser(),
    )

providers/edge3/src/airflow/providers/edge3/cli/worker.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,13 @@ async def heartbeat(self, new_maintenance_comments: str | None = None) -> bool:
401401
new_maintenance_comments,
402402
)
403403
self.queues = worker_info.queues
404+
if worker_info.concurrency is not None and worker_info.concurrency != self.concurrency:
405+
logger.info(
406+
"Concurrency updated from %d to %d by remote request.",
407+
self.concurrency,
408+
worker_info.concurrency,
409+
)
410+
self.concurrency = worker_info.concurrency
404411
if worker_info.state == EdgeWorkerState.MAINTENANCE_REQUEST:
405412
logger.info("Maintenance mode requested!")
406413
self.maintenance_mode = True
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
19+
"""
20+
Add concurrency column to edge_worker table.
21+
22+
Revision ID: b3c4d5e6f7a8
23+
Revises: 9d34dfc2de06
24+
Create Date: 2026-03-04 00:00:00.000000
25+
"""
26+
27+
from __future__ import annotations
28+
29+
import sqlalchemy as sa
30+
from alembic import op
31+
32+
# revision identifiers, used by Alembic.
33+
revision = "b3c4d5e6f7a8"
34+
down_revision = "9d34dfc2de06"
35+
branch_labels = None
36+
depends_on = None
37+
edge3_version = "3.2.0"
38+
39+
40+
def upgrade() -> None:
    """Add the nullable integer ``concurrency`` column to ``edge_worker`` unless it already exists."""
    column_names = {column["name"] for column in sa.inspect(op.get_bind()).get_columns("edge_worker")}
    if "concurrency" in column_names:
        # Column already present (presumably the table was created directly
        # from the ORM models - TODO confirm); nothing to add.
        return
    op.add_column("edge_worker", sa.Column("concurrency", sa.Integer(), nullable=True))
46+
47+
48+
def downgrade() -> None:
    """Drop the ``concurrency`` column from ``edge_worker`` if it is present."""
    inspector = sa.inspect(op.get_bind())
    existing_columns = {col["name"] for col in inspector.get_columns("edge_worker")}
    # Mirror the existence check done on upgrade so the downgrade is also
    # safe to run against a table that does not have the column.
    if "concurrency" in existing_columns:
        op.drop_column("edge_worker", "concurrency")

providers/edge3/src/airflow/providers/edge3/models/db.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
_REVISION_HEADS_MAP: dict[str, str] = {
3333
"3.0.0": "9d34dfc2de06",
34+
"3.2.0": "b3c4d5e6f7a8",
3435
}
3536

3637

@@ -45,6 +46,33 @@ class EdgeDBManager(BaseDBManager):
4546
supports_table_dropping = True
4647
revision_heads_map = _REVISION_HEADS_MAP
4748

49+
def initdb(self):
    """
    Bring the edge3 schema up to date, covering installations that pre-date alembic.

    Three situations are handled:

    * A current revision is recorded: run the normal upgrade.
    * No revision is recorded and none of this manager's tables exist:
      create the schema from the ORM.
    * No revision is recorded but at least one of this manager's tables exists
      (e.g. created via create_all before the migration chain was introduced):
      stamp the database with the first revision and upgrade from there, so
      every migration is applied rather than jumping straight to head.
    """
    if self.get_current_revision():
        self.upgradedb()
        return

    from airflow import settings

    tables_in_db = set(inspect(settings.engine).get_table_names())
    if tables_in_db.isdisjoint(self.metadata.tables):
        # Fresh database: no edge3 table is present yet.
        self.create_db_from_orm()
        return

    script_dir = self.get_script_object()
    # The first migration is the one without a parent revision.
    root_revision = next(rev.revision for rev in script_dir.walk_revisions() if rev.down_revision is None)
    alembic_cfg = self.get_alembic_config()
    from alembic import command

    command.stamp(alembic_cfg, root_revision)
    self.upgradedb()
75+
4876
def drop_tables(self, connection):
4977
"""Drop only edge3 tables in reverse dependency order."""
5078
if not self.supports_table_dropping:

providers/edge3/src/airflow/providers/edge3/models/edge_worker.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class EdgeWorkerModel(Base, LoggingMixin):
103103
jobs_success: Mapped[int] = mapped_column(Integer, default=0)
104104
jobs_failed: Mapped[int] = mapped_column(Integer, default=0)
105105
sysinfo: Mapped[str | None] = mapped_column(String(256))
106+
concurrency: Mapped[int | None] = mapped_column(Integer, nullable=True)
106107

107108
def __init__(
108109
self,
@@ -392,3 +393,23 @@ def remove_worker_queues(worker_name: str, queues: list[str], session: Session =
392393
logger.error(error_message)
393394
raise TypeError(error_message)
394395
worker.remove_queues(queues)
396+
397+
398+
@provide_session
def set_worker_concurrency(worker_name: str, concurrency: int, session: Session = NEW_SESSION) -> None:
    """
    Set the concurrency of an edge worker.

    :param worker_name: Name of the worker whose concurrency is changed.
    :param concurrency: New concurrency value; must be greater than zero.
    :param session: Database session used to look up and update the worker.
    :raises ValueError: If ``concurrency`` is not positive, or if no worker with
        ``worker_name`` is registered.
    :raises TypeError: If the worker is in OFFLINE, OFFLINE_MAINTENANCE or UNKNOWN state.
    """
    # Reject invalid values here as well, so callers other than the CLI
    # cannot persist a zero or negative concurrency.
    if concurrency <= 0:
        raise ValueError(f"Concurrency must be a positive integer, got {concurrency}")
    query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
    worker: EdgeWorkerModel | None = session.scalar(query)
    if not worker:
        raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers")
    if worker.state in (
        EdgeWorkerState.OFFLINE,
        EdgeWorkerState.OFFLINE_MAINTENANCE,
        EdgeWorkerState.UNKNOWN,
    ):
        error_message = (
            f"Cannot set concurrency for edge worker {worker_name} as it is in {worker.state} state!"
        )
        logger.error(error_message)
        raise TypeError(error_message)
    worker.concurrency = concurrency

providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,10 @@ class WorkerSetStateReturn(BaseModel):
200200
str | None,
201201
Field(description="Comments about the maintenance state of the worker."),
202202
] = None
203+
concurrency: Annotated[
204+
int | None,
205+
Field(
206+
description="Desired concurrency for the worker set by an administrator. "
207+
"None means no remote override; the worker uses its startup value.",
208+
),
209+
] = None

providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,10 @@ def set_state(
238238
)
239239
_assert_version(body.sysinfo) # Exception only after worker state is in the DB
240240
return WorkerSetStateReturn(
241-
state=worker.state, queues=worker.queues, maintenance_comments=worker.maintenance_comment
241+
state=worker.state,
242+
queues=worker.queues,
243+
maintenance_comments=worker.maintenance_comment,
244+
concurrency=worker.concurrency,
242245
)
243246

244247

providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,6 +1413,13 @@ components:
14131413
- type: 'null'
14141414
title: Maintenance Comments
14151415
description: Comments about the maintenance state of the worker.
1416+
concurrency:
1417+
anyOf:
1418+
- type: integer
1419+
- type: 'null'
1420+
title: Concurrency
1421+
description: Desired concurrency for the worker set by an administrator.
1422+
None means no remote override; the worker uses its startup value.
14161423
type: object
14171424
required:
14181425
- state

providers/edge3/tests/unit/edge3/cli/test_definition.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ def test_edge_cli_commands_count(self):
5353
assert len(commands) == 1
5454

5555
def test_edge_commands_count(self):
56-
"""Test that EDGE_COMMANDS contains all 13 subcommands."""
57-
assert len(EDGE_COMMANDS) == 13
56+
"""Test that EDGE_COMMANDS contains all 14 subcommands."""
57+
assert len(EDGE_COMMANDS) == 14
5858

5959
@pytest.mark.parametrize(
6060
"command",
@@ -234,3 +234,17 @@ def test_shutdown_all_workers_args(self):
234234
params = ["edge", "shutdown-all-workers", "--yes"]
235235
args = self.arg_parser.parse_args(params)
236236
assert args.yes is True
237+
238+
def test_set_worker_concurrency_args(self):
    """Parse ``edge set-worker-concurrency`` with both required options and check the values."""
    argv = ["edge", "set-worker-concurrency"]
    argv += ["--edge-hostname", "remote-worker-1"]
    argv += ["--concurrency", "16"]
    parsed = self.arg_parser.parse_args(argv)
    assert parsed.edge_hostname == "remote-worker-1"
    # The concurrency option is converted to an integer by the parser.
    assert parsed.concurrency == 16

0 commit comments

Comments
 (0)