Skip to content

Commit 44d15de

Browse files
SameerMesiah97dominikhei
authored andcommitted
Add drift detection and optional recreation to ComputeEngineInsertInstanceOperator (apache#61830)
1 parent 2c163ef commit 44d15de

File tree

3 files changed

+320
-26
lines changed

3 files changed

+320
-26
lines changed

providers/google/src/airflow/providers/google/cloud/operators/compute.py

Lines changed: 100 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ class ComputeEngineInsertInstanceOperator(ComputeEngineBaseOperator):
124124
:param timeout: The amount of time, in seconds, to wait for the request to complete.
125125
Note that if `retry` is specified, the timeout applies to each individual attempt.
126126
:param metadata: Additional metadata that is provided to the method.
127+
:param recreate_if_machine_type_different: When True, delete and recreate the instance if
128+
the existing machine type differs from the requested body. Defaults to
129+
False, in which case differences are only logged.
127130
"""
128131

129132
operator_extra_links = (ComputeInstanceDetailsLink(),)
@@ -156,6 +159,7 @@ def __init__(
156159
api_version: str = "v1",
157160
validate_body: bool = True,
158161
impersonation_chain: str | Sequence[str] | None = None,
162+
recreate_if_machine_type_different: bool = False,
159163
**kwargs,
160164
) -> None:
161165
self.body = body
@@ -167,6 +171,7 @@ def __init__(
167171
self.retry = retry
168172
self.timeout = timeout
169173
self.metadata = metadata
174+
self.recreate_if_machine_type_different = recreate_if_machine_type_different
170175

171176
if validate_body:
172177
self._field_validator = GcpBodyFieldValidator(
@@ -206,54 +211,123 @@ def _validate_all_body_fields(self) -> None:
206211
if self._field_validator:
207212
self._field_validator.validate(self.body)
208213

214+
def _extract_machine_type(self, value: str | None) -> str | None:
215+
if not value:
216+
return None
217+
return value.split("/")[-1]
218+
219+
def _detect_instance_drift(self, existing: Instance) -> dict[str, Any]:
220+
"""Detect machine type differences between the existing instance and the requested body."""
221+
diffs = {}
222+
223+
# Compare machine_type.
224+
requested_machine_type = self.body.get("machine_type")
225+
existing_machine_type = getattr(existing, "machine_type", None)
226+
227+
requested_name = self._extract_machine_type(requested_machine_type)
228+
existing_name = self._extract_machine_type(existing_machine_type)
229+
230+
if requested_name and existing_name and requested_name != existing_name:
231+
diffs["machine_type"] = {
232+
"existing": existing_name,
233+
"requested": requested_name,
234+
}
235+
236+
return diffs
237+
238+
def _create_instance(self, hook: ComputeEngineHook, context: Context) -> dict:
239+
"""Create the instance using the current body and return the created instance as dict."""
240+
self._field_sanitizer.sanitize(self.body)
241+
242+
self.log.info("Creating Instance with specified body: %s", self.body)
243+
244+
hook.insert_instance(
245+
body=self.body,
246+
request_id=self.request_id,
247+
project_id=self.project_id,
248+
zone=self.zone,
249+
)
250+
251+
self.log.info("The specified Instance has been created SUCCESSFULLY")
252+
253+
new_instance = hook.get_instance(
254+
resource_id=self.resource_id,
255+
project_id=self.project_id,
256+
zone=self.zone,
257+
)
258+
259+
ComputeInstanceDetailsLink.persist(
260+
context=context,
261+
project_id=self.project_id or hook.project_id,
262+
)
263+
264+
return Instance.to_dict(new_instance)
265+
209266
def execute(self, context: Context) -> dict:
267+
"""
268+
Ensure that a Compute Engine instance with the given name exists.
269+
270+
If the instance does not exist, it is created. If it already exists,
271+
presence is treated as success (presence-based idempotence).
272+
273+
If machine type drift is detected and ``recreate_if_machine_type_different=True``,
274+
the existing instance is deleted and recreated using the requested body.
275+
"""
210276
hook = ComputeEngineHook(
211277
gcp_conn_id=self.gcp_conn_id,
212278
api_version=self.api_version,
213279
impersonation_chain=self.impersonation_chain,
214280
)
215281
self._validate_all_body_fields()
216282
self.check_body_fields()
283+
217284
try:
218-
# Idempotence check (sort of) - we want to check if the new Instance
219-
# is already created and if is, then we assume it was created previously - we do
220-
# not check if content of the Instance is as expected.
221-
# We assume success if the Instance is simply present.
222285
existing_instance = hook.get_instance(
223286
resource_id=self.resource_id,
224287
project_id=self.project_id,
225288
zone=self.zone,
226289
)
227290
except exceptions.NotFound as e:
228-
# We actually expect to get 404 / Not Found here as the should not yet exist
291+
# We expect a 404 here if the instance does not yet exist.
229292
if e.code != 404:
230293
raise e
231-
else:
232-
self.log.info("The %s Instance already exists", self.resource_id)
233-
ComputeInstanceDetailsLink.persist(
234-
context=context,
235-
project_id=self.project_id or hook.project_id,
294+
295+
# Create instance if it does not exist.
296+
return self._create_instance(hook, context)
297+
298+
# Instance already exists.
299+
self.log.info("The %s Instance already exists", self.resource_id)
300+
301+
# Detect drift.
302+
diffs = self._detect_instance_drift(existing_instance)
303+
if diffs:
304+
self.log.warning(
305+
"Existing instance '%s' differs from requested configuration: %s",
306+
self.resource_id,
307+
diffs,
236308
)
237-
return Instance.to_dict(existing_instance)
238-
self._field_sanitizer.sanitize(self.body)
239-
self.log.info("Creating Instance with specified body: %s", self.body)
240-
hook.insert_instance(
241-
body=self.body,
242-
request_id=self.request_id,
243-
project_id=self.project_id,
244-
zone=self.zone,
245-
)
246-
self.log.info("The specified Instance has been created SUCCESSFULLY")
247-
new_instance = hook.get_instance(
248-
resource_id=self.resource_id,
249-
project_id=self.project_id,
250-
zone=self.zone,
251-
)
309+
310+
if self.recreate_if_machine_type_different:
311+
self.log.info(
312+
"Recreating instance '%s' because recreate_if_machine_type_different=True",
313+
self.resource_id,
314+
)
315+
316+
hook.delete_instance(
317+
resource_id=self.resource_id,
318+
project_id=self.project_id,
319+
request_id=self.request_id,
320+
zone=self.zone,
321+
)
322+
323+
return self._create_instance(hook, context)
324+
252325
ComputeInstanceDetailsLink.persist(
253326
context=context,
254327
project_id=self.project_id or hook.project_id,
255328
)
256-
return Instance.to_dict(new_instance)
329+
330+
return Instance.to_dict(existing_instance)
257331

258332

259333
class ComputeEngineInsertInstanceFromTemplateOperator(ComputeEngineBaseOperator):
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
19+
"""
20+
System test for ComputeEngineInsertInstanceOperator
21+
verifying recreate_if_machine_type_different=True recreates the
22+
correct machine_type instance when machine_type drifts.
23+
"""
24+
25+
from __future__ import annotations
26+
27+
import os
28+
from datetime import datetime
29+
30+
from airflow.models.dag import DAG
31+
from airflow.operators.python import PythonOperator
32+
from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook
33+
from airflow.providers.google.cloud.operators.compute import (
34+
ComputeEngineDeleteInstanceOperator,
35+
ComputeEngineInsertInstanceOperator,
36+
)
37+
38+
try:
39+
from airflow.sdk import TriggerRule
40+
except ImportError:
41+
from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef]
42+
43+
from system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
44+
45+
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
46+
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
47+
48+
DAG_ID = "cloud_compute_insert_recreate_if_different"
49+
LOCATION = "us-central1-a"
50+
51+
INSTANCE_NAME = f"airflow-drift-test-{ENV_ID}"
52+
MACHINE_TYPE_A = "n1-standard-1"
53+
MACHINE_TYPE_B = "n1-standard-2"
54+
55+
BASE_BODY = {
56+
"name": INSTANCE_NAME,
57+
"disks": [
58+
{
59+
"boot": True,
60+
"auto_delete": True,
61+
"initialize_params": {
62+
"disk_size_gb": "10",
63+
"source_image": "projects/debian-cloud/global/images/family/debian-12",
64+
},
65+
}
66+
],
67+
"network_interfaces": [{"network": "global/networks/default"}],
68+
}
69+
70+
71+
def assert_machine_type():
72+
hook = ComputeEngineHook()
73+
instance = hook.get_instance(
74+
project_id=PROJECT_ID,
75+
zone=LOCATION,
76+
resource_id=INSTANCE_NAME,
77+
)
78+
79+
machine_type = instance.machine_type.split("/")[-1]
80+
81+
assert machine_type == MACHINE_TYPE_B, f"Expected machine type {MACHINE_TYPE_B}, got {machine_type}"
82+
83+
84+
with DAG(
85+
DAG_ID,
86+
schedule="@once",
87+
start_date=datetime(2021, 1, 1),
88+
catchup=False,
89+
tags=["example", "compute"],
90+
) as dag:
91+
# Step 1: Create with machine type A.
92+
create_instance = ComputeEngineInsertInstanceOperator(
93+
task_id="create_instance",
94+
project_id=PROJECT_ID,
95+
zone=LOCATION,
96+
body={
97+
**BASE_BODY,
98+
"machine_type": f"zones/{LOCATION}/machineTypes/{MACHINE_TYPE_A}",
99+
},
100+
)
101+
102+
# Step 2: Re-run with different machine type and recreate recreate_if_machine_type_different=True.
103+
recreate_instance = ComputeEngineInsertInstanceOperator(
104+
task_id="recreate_instance",
105+
project_id=PROJECT_ID,
106+
zone=LOCATION,
107+
body={
108+
**BASE_BODY,
109+
"machine_type": f"zones/{LOCATION}/machineTypes/{MACHINE_TYPE_B}",
110+
},
111+
recreate_if_machine_type_different=True,
112+
)
113+
114+
# Step 3: Validate new machine type.
115+
validate_machine_type = PythonOperator(
116+
task_id="validate_machine_type",
117+
python_callable=assert_machine_type,
118+
)
119+
120+
# Step 4: Cleanup.
121+
delete_instance = ComputeEngineDeleteInstanceOperator(
122+
task_id="delete_instance",
123+
project_id=PROJECT_ID,
124+
zone=LOCATION,
125+
resource_id=INSTANCE_NAME,
126+
trigger_rule=TriggerRule.ALL_DONE,
127+
)
128+
129+
create_instance >> recreate_instance >> validate_machine_type >> delete_instance
130+
131+
# Everything below this line is required for system tests.
132+
from tests_common.test_utils.watcher import watcher
133+
134+
list(dag.tasks) >> watcher()
135+
136+
137+
from tests_common.test_utils.system_tests import get_test_run # noqa: E402
138+
139+
test_run = get_test_run(dag)

0 commit comments

Comments
 (0)