Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 87 additions & 45 deletions src/zenml/artifacts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,28 @@
import contextlib
import os
import tempfile
import time
import zipfile
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, cast
from uuid import UUID
from uuid import UUID, uuid4

from zenml.client import Client
from zenml.constants import MODEL_METADATA_YAML_FILE_NAME
from zenml.constants import (
MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION,
MODEL_METADATA_YAML_FILE_NAME,
)
from zenml.enums import (
ExecutionStatus,
MetadataResourceTypes,
StackComponentType,
VisualizationType,
)
from zenml.exceptions import DoesNotExistException, StepContextError
from zenml.exceptions import (
DoesNotExistException,
EntityExistsError,
StepContextError,
)
from zenml.io import fileio
from zenml.logger import get_logger
from zenml.models import (
Expand Down Expand Up @@ -107,26 +115,40 @@ def save_artifact(

Raises:
RuntimeError: If artifact URI already exists.
EntityExistsError: If artifact version already exists.
"""
from zenml.materializers.materializer_registry import (
materializer_registry,
)
from zenml.utils import source_utils

# TODO: Can we handle this server side? If we leave it empty in the request,
# it's an auto-increase?
# TODO: This can probably lead to issues when multiple steps request a new
# artifact version at the same time?
# Get new artifact version if not specified
version = version or _get_new_artifact_version(name)
client = Client()

# Get or create the artifact
try:
artifact = client.list_artifacts(name=name)[0]
if artifact.has_custom_name != has_custom_name:
client.update_artifact(
name_id_or_prefix=artifact.id, has_custom_name=has_custom_name
)
except IndexError:
try:
artifact = client.zen_store.create_artifact(
ArtifactRequest(
name=name,
has_custom_name=has_custom_name,
tags=tags,
)
)
except EntityExistsError:
artifact = client.list_artifacts(name=name)[0]

# Get the current artifact store
client = Client()
artifact_store = client.active_stack.artifact_store

# Build and check the artifact URI
if not uri:
uri = os.path.join("custom_artifacts", name, str(version))
uri = os.path.join("custom_artifacts", name, str(uuid4()))
if not uri.startswith(artifact_store.path):
uri = os.path.join(artifact_store.path, uri)

Expand All @@ -136,7 +158,7 @@ def save_artifact(
other_artifacts = client.list_artifact_versions(uri=uri, size=1)
if other_artifacts and (other_artifact := other_artifacts[0]):
raise RuntimeError(
f"Cannot save artifact {name} (version {version}) to URI "
f"Cannot save new artifact {name} version to URI "
f"{uri} because the URI is already used by artifact "
f"{other_artifact.name} (version {other_artifact.version})."
)
Expand Down Expand Up @@ -189,42 +211,62 @@ def save_artifact(
f"Failed to extract metadata for output artifact '{name}': {e}"
)

# Get or create the artifact
try:
artifact = client.list_artifacts(name=name)[0]
if artifact.has_custom_name != has_custom_name:
client.update_artifact(
name_id_or_prefix=artifact.id, has_custom_name=has_custom_name
# Create the artifact version
def _create_version() -> Optional[ArtifactVersionResponse]:
artifact_version = ArtifactVersionRequest(
artifact_id=artifact.id,
version=version,
tags=tags,
type=materializer_object.ASSOCIATED_ARTIFACT_TYPE,
uri=materializer_object.uri,
materializer=source_utils.resolve(materializer_object.__class__),
data_type=source_utils.resolve(data_type),
user=Client().active_user.id,
workspace=Client().active_workspace.id,
artifact_store_id=artifact_store.id,
visualizations=visualizations,
has_custom_name=has_custom_name,
)
try:
return client.zen_store.create_artifact_version(
artifact_version=artifact_version
)
except IndexError:
artifact = client.zen_store.create_artifact(
ArtifactRequest(
name=name,
has_custom_name=has_custom_name,
tags=tags,
except EntityExistsError:
return None

response = None
if not version:
retries_made = 0
for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION):
# Get new artifact version
version = _get_new_artifact_version(name)
if response := _create_version():
break
# smoothed exponential back-off, it will go as 0.2, 0.3,
# 0.45, 0.68, 1.01, 1.52, 2.28, 3.42, 5.13, 7.69, ...
sleep = 0.2 * 1.5**i
logger.debug(
f"Failed to create artifact version `{version}` for "
f"artifact `{name}`. Retrying in {sleep}..."
)
time.sleep(sleep)
retries_made += 1
if not response:
raise EntityExistsError(
f"Failed to create new artifact version for artifact "
f"`{name}`. Retried {retries_made} times. "
"This could be driven by exceptionally high concurrency of "
"pipeline runs. Please, reach out to us on ZenML Slack for support."
)
else:
response = _create_version()
if not response:
raise EntityExistsError(
f"Failed to create artifact version `{version}` for artifact "
f"`{name}`. Given version already exists."
)
)

# Create the artifact version
artifact_version = ArtifactVersionRequest(
artifact_id=artifact.id,
version=version,
tags=tags,
type=materializer_object.ASSOCIATED_ARTIFACT_TYPE,
uri=materializer_object.uri,
materializer=source_utils.resolve(materializer_object.__class__),
data_type=source_utils.resolve(data_type),
user=Client().active_user.id,
workspace=Client().active_workspace.id,
artifact_store_id=artifact_store.id,
visualizations=visualizations,
has_custom_name=has_custom_name,
)
response = Client().zen_store.create_artifact_version(
artifact_version=artifact_version
)
if artifact_metadata:
Client().create_run_metadata(
client.create_run_metadata(
metadata=artifact_metadata,
resource_id=response.id,
resource_type=MetadataResourceTypes.ARTIFACT_VERSION,
Expand Down
5 changes: 5 additions & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,8 @@ def handle_int_env_var(var: str, default: int = 0) -> int:

# Service connector constants
SERVICE_CONNECTOR_SKEW_TOLERANCE_SECONDS = 60 * 5 # 5 minutes

# Versioned entities
MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION = (
10 # empirical value to pass heavy parallelized tests
)
66 changes: 30 additions & 36 deletions src/zenml/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# permissions and limitations under the License.
"""Model user facing interface to pass into pipeline or step."""

import time
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -25,6 +26,7 @@

from pydantic import BaseModel, PrivateAttr, root_validator

from zenml.constants import MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION
from zenml.enums import MetadataResourceTypes, ModelStages
from zenml.exceptions import EntityExistsError
from zenml.logger import get_logger
Expand Down Expand Up @@ -521,39 +523,6 @@ def _get_or_create_model(self) -> "ModelResponse":
model = zenml_client.zen_store.get_model(
model_name_or_id=self.name
)

difference: Dict[str, Any] = {}
for key in (
"license",
"audience",
"use_cases",
"limitations",
"trade_offs",
"ethics",
"save_models_to_registry",
):
if self_attr := getattr(self, key, None):
if self_attr != getattr(model, key):
difference[key] = {
"config": getattr(self, key),
"db": getattr(model, key),
}
if self.tags:
configured_tags = set(self.tags)
db_tags = {t.name for t in model.tags}
if db_tags != configured_tags:
difference["tags added"] = list(configured_tags - db_tags)
difference["tags removed"] = list(
db_tags - configured_tags
)
if difference:
logger.warning(
"Provided model configuration does not match "
f"existing model `{self.name}` with the "
f"following changes: {difference}. If you want to "
"update the model configuration, please use the "
"`zenml model update` command."
)
except KeyError:
model_request = ModelRequest(
name=self.name,
Expand Down Expand Up @@ -646,6 +615,7 @@ def _get_or_create_model_version(

Raises:
RuntimeError: if the model version needs to be created, but provided name is reserved
RuntimeError: if the model version cannot be created
"""
from zenml.client import Client
from zenml.models import ModelVersionRequest
Expand Down Expand Up @@ -723,9 +693,33 @@ def _get_or_create_model_version(
" as an example. You can explore model versions using "
f"`zenml model version list {self.name}` CLI command."
)
model_version = zenml_client.zen_store.create_model_version(
model_version=mv_request
)
retries_made = 0
for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION):
try:
model_version = (
zenml_client.zen_store.create_model_version(
model_version=mv_request
)
)
break
except EntityExistsError as e:
if i == MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION - 1:
raise RuntimeError(
f"Failed to create model version "
f"`{self.version if self.version else 'new'}` "
f"in model `{self.name}`. Retried {retries_made} times. "
"This could be driven by exceptionally high concurrency of "
"pipeline runs. Please, reach out to us on ZenML Slack for support."
) from e
# smoothed exponential back-off, it will go as 0.2, 0.3,
# 0.45, 0.68, 1.01, 1.52, 2.28, 3.42, 5.13, 7.69, ...
sleep = 0.2 * 1.5**i
logger.debug(
f"Failed to create new model version for "
f"model `{self.name}`. Retrying in {sleep}..."
)
time.sleep(sleep)
retries_made += 1
self.version = model_version.name
self.was_created_in_this_run = True
logger.info(f"New model version `{self.version}` was created.")
Expand Down
Loading