Skip to content

Commit 43c531a

Browse files
avishniakovadtygan
authored andcommitted
Parallel pipelines can create entities in DB (zenml-io#2446)
* fix parallel artifacts registration * remove excessive warnings * parallel safe model versions * increase cool down a bit * coderabbitai * coderabbitai * update test signature * PR suggestions from Alex * kudos to windows * give some more retries for docker CIs * try to fix test case * fix parallel tests
1 parent 8492b0d commit 43c531a

File tree

10 files changed

+411
-179
lines changed

10 files changed

+411
-179
lines changed

src/zenml/artifacts/utils.py

Lines changed: 87 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,28 @@
1717
import contextlib
1818
import os
1919
import tempfile
20+
import time
2021
import zipfile
2122
from pathlib import Path
2223
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, cast
23-
from uuid import UUID
24+
from uuid import UUID, uuid4
2425

2526
from zenml.client import Client
26-
from zenml.constants import MODEL_METADATA_YAML_FILE_NAME
27+
from zenml.constants import (
28+
MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION,
29+
MODEL_METADATA_YAML_FILE_NAME,
30+
)
2731
from zenml.enums import (
2832
ExecutionStatus,
2933
MetadataResourceTypes,
3034
StackComponentType,
3135
VisualizationType,
3236
)
33-
from zenml.exceptions import DoesNotExistException, StepContextError
37+
from zenml.exceptions import (
38+
DoesNotExistException,
39+
EntityExistsError,
40+
StepContextError,
41+
)
3442
from zenml.io import fileio
3543
from zenml.logger import get_logger
3644
from zenml.models import (
@@ -107,26 +115,40 @@ def save_artifact(
107115
108116
Raises:
109117
RuntimeError: If artifact URI already exists.
118+
EntityExistsError: If artifact version already exists.
110119
"""
111120
from zenml.materializers.materializer_registry import (
112121
materializer_registry,
113122
)
114123
from zenml.utils import source_utils
115124

116-
# TODO: Can we handle this server side? If we leave it empty in the request,
117-
# it's an auto-increase?
118-
# TODO: This can probably lead to issues when multiple steps request a new
119-
# artifact version at the same time?
120-
# Get new artifact version if not specified
121-
version = version or _get_new_artifact_version(name)
125+
client = Client()
126+
127+
# Get or create the artifact
128+
try:
129+
artifact = client.list_artifacts(name=name)[0]
130+
if artifact.has_custom_name != has_custom_name:
131+
client.update_artifact(
132+
name_id_or_prefix=artifact.id, has_custom_name=has_custom_name
133+
)
134+
except IndexError:
135+
try:
136+
artifact = client.zen_store.create_artifact(
137+
ArtifactRequest(
138+
name=name,
139+
has_custom_name=has_custom_name,
140+
tags=tags,
141+
)
142+
)
143+
except EntityExistsError:
144+
artifact = client.list_artifacts(name=name)[0]
122145

123146
# Get the current artifact store
124-
client = Client()
125147
artifact_store = client.active_stack.artifact_store
126148

127149
# Build and check the artifact URI
128150
if not uri:
129-
uri = os.path.join("custom_artifacts", name, str(version))
151+
uri = os.path.join("custom_artifacts", name, str(uuid4()))
130152
if not uri.startswith(artifact_store.path):
131153
uri = os.path.join(artifact_store.path, uri)
132154

@@ -136,7 +158,7 @@ def save_artifact(
136158
other_artifacts = client.list_artifact_versions(uri=uri, size=1)
137159
if other_artifacts and (other_artifact := other_artifacts[0]):
138160
raise RuntimeError(
139-
f"Cannot save artifact {name} (version {version}) to URI "
161+
f"Cannot save new artifact {name} version to URI "
140162
f"{uri} because the URI is already used by artifact "
141163
f"{other_artifact.name} (version {other_artifact.version})."
142164
)
@@ -189,42 +211,62 @@ def save_artifact(
189211
f"Failed to extract metadata for output artifact '{name}': {e}"
190212
)
191213

192-
# Get or create the artifact
193-
try:
194-
artifact = client.list_artifacts(name=name)[0]
195-
if artifact.has_custom_name != has_custom_name:
196-
client.update_artifact(
197-
name_id_or_prefix=artifact.id, has_custom_name=has_custom_name
214+
# Create the artifact version
215+
def _create_version() -> Optional[ArtifactVersionResponse]:
216+
artifact_version = ArtifactVersionRequest(
217+
artifact_id=artifact.id,
218+
version=version,
219+
tags=tags,
220+
type=materializer_object.ASSOCIATED_ARTIFACT_TYPE,
221+
uri=materializer_object.uri,
222+
materializer=source_utils.resolve(materializer_object.__class__),
223+
data_type=source_utils.resolve(data_type),
224+
user=Client().active_user.id,
225+
workspace=Client().active_workspace.id,
226+
artifact_store_id=artifact_store.id,
227+
visualizations=visualizations,
228+
has_custom_name=has_custom_name,
229+
)
230+
try:
231+
return client.zen_store.create_artifact_version(
232+
artifact_version=artifact_version
198233
)
199-
except IndexError:
200-
artifact = client.zen_store.create_artifact(
201-
ArtifactRequest(
202-
name=name,
203-
has_custom_name=has_custom_name,
204-
tags=tags,
234+
except EntityExistsError:
235+
return None
236+
237+
response = None
238+
if not version:
239+
retries_made = 0
240+
for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION):
241+
# Get new artifact version
242+
version = _get_new_artifact_version(name)
243+
if response := _create_version():
244+
break
245+
# smoothed exponential back-off, it will go as 0.2, 0.3,
246+
# 0.45, 0.68, 1.01, 1.52, 2.28, 3.42, 5.13, 7.69, ...
247+
sleep = 0.2 * 1.5**i
248+
logger.debug(
249+
f"Failed to create artifact version `{version}` for "
250+
f"artifact `{name}`. Retrying in {sleep}..."
251+
)
252+
time.sleep(sleep)
253+
retries_made += 1
254+
if not response:
255+
raise EntityExistsError(
256+
f"Failed to create new artifact version for artifact "
257+
f"`{name}`. Retried {retries_made} times. "
258+
"This could be driven by exceptionally high concurrency of "
259+
"pipeline runs. Please, reach out to us on ZenML Slack for support."
260+
)
261+
else:
262+
response = _create_version()
263+
if not response:
264+
raise EntityExistsError(
265+
f"Failed to create artifact version `{version}` for artifact "
266+
f"`{name}`. Given version already exists."
205267
)
206-
)
207-
208-
# Create the artifact version
209-
artifact_version = ArtifactVersionRequest(
210-
artifact_id=artifact.id,
211-
version=version,
212-
tags=tags,
213-
type=materializer_object.ASSOCIATED_ARTIFACT_TYPE,
214-
uri=materializer_object.uri,
215-
materializer=source_utils.resolve(materializer_object.__class__),
216-
data_type=source_utils.resolve(data_type),
217-
user=Client().active_user.id,
218-
workspace=Client().active_workspace.id,
219-
artifact_store_id=artifact_store.id,
220-
visualizations=visualizations,
221-
has_custom_name=has_custom_name,
222-
)
223-
response = Client().zen_store.create_artifact_version(
224-
artifact_version=artifact_version
225-
)
226268
if artifact_metadata:
227-
Client().create_run_metadata(
269+
client.create_run_metadata(
228270
metadata=artifact_metadata,
229271
resource_id=response.id,
230272
resource_type=MetadataResourceTypes.ARTIFACT_VERSION,

src/zenml/constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,8 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
313313

314314
# Service connector constants
315315
SERVICE_CONNECTOR_SKEW_TOLERANCE_SECONDS = 60 * 5 # 5 minutes
316+
317+
# Versioned entities
318+
MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION = (
319+
10 # empirical value to pass heavy parallelized tests
320+
)

src/zenml/model/model.py

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# permissions and limitations under the License.
1414
"""Model user facing interface to pass into pipeline or step."""
1515

16+
import time
1617
from typing import (
1718
TYPE_CHECKING,
1819
Any,
@@ -25,6 +26,7 @@
2526

2627
from pydantic import BaseModel, PrivateAttr, root_validator
2728

29+
from zenml.constants import MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION
2830
from zenml.enums import MetadataResourceTypes, ModelStages
2931
from zenml.exceptions import EntityExistsError
3032
from zenml.logger import get_logger
@@ -521,39 +523,6 @@ def _get_or_create_model(self) -> "ModelResponse":
521523
model = zenml_client.zen_store.get_model(
522524
model_name_or_id=self.name
523525
)
524-
525-
difference: Dict[str, Any] = {}
526-
for key in (
527-
"license",
528-
"audience",
529-
"use_cases",
530-
"limitations",
531-
"trade_offs",
532-
"ethics",
533-
"save_models_to_registry",
534-
):
535-
if self_attr := getattr(self, key, None):
536-
if self_attr != getattr(model, key):
537-
difference[key] = {
538-
"config": getattr(self, key),
539-
"db": getattr(model, key),
540-
}
541-
if self.tags:
542-
configured_tags = set(self.tags)
543-
db_tags = {t.name for t in model.tags}
544-
if db_tags != configured_tags:
545-
difference["tags added"] = list(configured_tags - db_tags)
546-
difference["tags removed"] = list(
547-
db_tags - configured_tags
548-
)
549-
if difference:
550-
logger.warning(
551-
"Provided model configuration does not match "
552-
f"existing model `{self.name}` with the "
553-
f"following changes: {difference}. If you want to "
554-
"update the model configuration, please use the "
555-
"`zenml model update` command."
556-
)
557526
except KeyError:
558527
model_request = ModelRequest(
559528
name=self.name,
@@ -646,6 +615,7 @@ def _get_or_create_model_version(
646615
647616
Raises:
648617
RuntimeError: if the model version needs to be created, but provided name is reserved
618+
RuntimeError: if the model version cannot be created
649619
"""
650620
from zenml.client import Client
651621
from zenml.models import ModelVersionRequest
@@ -723,9 +693,33 @@ def _get_or_create_model_version(
723693
" as an example. You can explore model versions using "
724694
f"`zenml model version list {self.name}` CLI command."
725695
)
726-
model_version = zenml_client.zen_store.create_model_version(
727-
model_version=mv_request
728-
)
696+
retries_made = 0
697+
for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION):
698+
try:
699+
model_version = (
700+
zenml_client.zen_store.create_model_version(
701+
model_version=mv_request
702+
)
703+
)
704+
break
705+
except EntityExistsError as e:
706+
if i == MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION - 1:
707+
raise RuntimeError(
708+
f"Failed to create model version "
709+
f"`{self.version if self.version else 'new'}` "
710+
f"in model `{self.name}`. Retried {retries_made} times. "
711+
"This could be driven by exceptionally high concurrency of "
712+
"pipeline runs. Please, reach out to us on ZenML Slack for support."
713+
) from e
714+
# smoothed exponential back-off, it will go as 0.2, 0.3,
715+
# 0.45, 0.68, 1.01, 1.52, 2.28, 3.42, 5.13, 7.69, ...
716+
sleep = 0.2 * 1.5**i
717+
logger.debug(
718+
f"Failed to create new model version for "
719+
f"model `{self.name}`. Retrying in {sleep}..."
720+
)
721+
time.sleep(sleep)
722+
retries_made += 1
729723
self.version = model_version.name
730724
self.was_created_in_this_run = True
731725
logger.info(f"New model version `{self.version}` was created.")

0 commit comments

Comments
 (0)