Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 71 additions & 45 deletions src/zenml/artifacts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,28 @@
import contextlib
import os
import tempfile
import time
import zipfile
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, cast
from uuid import UUID
from uuid import UUID, uuid4

from zenml.client import Client
from zenml.constants import MODEL_METADATA_YAML_FILE_NAME
from zenml.constants import (
MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION,
MODEL_METADATA_YAML_FILE_NAME,
)
from zenml.enums import (
ExecutionStatus,
MetadataResourceTypes,
StackComponentType,
VisualizationType,
)
from zenml.exceptions import DoesNotExistException, StepContextError
from zenml.exceptions import (
DoesNotExistException,
EntityExistsError,
StepContextError,
)
from zenml.io import fileio
from zenml.logger import get_logger
from zenml.models import (
Expand Down Expand Up @@ -107,26 +115,40 @@ def save_artifact(

Raises:
RuntimeError: If artifact URI already exists.
EntityExistsError: If artifact version already exists.
"""
from zenml.materializers.materializer_registry import (
materializer_registry,
)
from zenml.utils import source_utils

# TODO: Can we handle this server side? If we leave it empty in the request,
# it's an auto-increase?
# TODO: This can probably lead to issues when multiple steps request a new
# artifact version at the same time?
# Get new artifact version if not specified
version = version or _get_new_artifact_version(name)
client = Client()

# Get or create the artifact
try:
artifact = client.list_artifacts(name=name)[0]
if artifact.has_custom_name != has_custom_name:
client.update_artifact(
name_id_or_prefix=artifact.id, has_custom_name=has_custom_name
)
except IndexError:
try:
artifact = client.zen_store.create_artifact(
ArtifactRequest(
name=name,
has_custom_name=has_custom_name,
tags=tags,
)
)
except EntityExistsError:
artifact = client.list_artifacts(name=name)[0]

# Get the current artifact store
client = Client()
artifact_store = client.active_stack.artifact_store

# Build and check the artifact URI
if not uri:
uri = os.path.join("custom_artifacts", name, str(version))
uri = os.path.join("custom_artifacts", name, str(uuid4()))
if not uri.startswith(artifact_store.path):
uri = os.path.join(artifact_store.path, uri)

Expand All @@ -136,7 +158,7 @@ def save_artifact(
other_artifacts = client.list_artifact_versions(uri=uri, size=1)
if other_artifacts and (other_artifact := other_artifacts[0]):
raise RuntimeError(
f"Cannot save artifact {name} (version {version}) to URI "
f"Cannot save new artifact {name} version to URI "
f"{uri} because the URI is already used by artifact "
f"{other_artifact.name} (version {other_artifact.version})."
)
Expand Down Expand Up @@ -189,42 +211,46 @@ def save_artifact(
f"Failed to extract metadata for output artifact '{name}': {e}"
)

# Get or create the artifact
try:
artifact = client.list_artifacts(name=name)[0]
if artifact.has_custom_name != has_custom_name:
client.update_artifact(
name_id_or_prefix=artifact.id, has_custom_name=has_custom_name
)
except IndexError:
artifact = client.zen_store.create_artifact(
ArtifactRequest(
name=name,
has_custom_name=has_custom_name,
tags=tags,
# Create the artifact version
def _create_version() -> Optional[ArtifactVersionResponse]:
artifact_version = ArtifactVersionRequest(
artifact_id=artifact.id,
version=version,
tags=tags,
type=materializer_object.ASSOCIATED_ARTIFACT_TYPE,
uri=materializer_object.uri,
materializer=source_utils.resolve(materializer_object.__class__),
data_type=source_utils.resolve(data_type),
user=Client().active_user.id,
workspace=Client().active_workspace.id,
artifact_store_id=artifact_store.id,
visualizations=visualizations,
has_custom_name=has_custom_name,
)
try:
return client.zen_store.create_artifact_version(
artifact_version=artifact_version
)
except EntityExistsError:
return None

response = None
if not version:
for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION):
# Get new artifact version
version = _get_new_artifact_version(name)
if response := _create_version():
break
time.sleep(0.2 * i)
else:
response = _create_version()
if not response:
raise EntityExistsError(
f"Failed to create artifact version `{version}` for artifact "
f"`{name}`, given version already exists."
)

# Create the artifact version
artifact_version = ArtifactVersionRequest(
artifact_id=artifact.id,
version=version,
tags=tags,
type=materializer_object.ASSOCIATED_ARTIFACT_TYPE,
uri=materializer_object.uri,
materializer=source_utils.resolve(materializer_object.__class__),
data_type=source_utils.resolve(data_type),
user=Client().active_user.id,
workspace=Client().active_workspace.id,
artifact_store_id=artifact_store.id,
visualizations=visualizations,
has_custom_name=has_custom_name,
)
response = Client().zen_store.create_artifact_version(
artifact_version=artifact_version
)
if artifact_metadata:
Client().create_run_metadata(
client.create_run_metadata(
metadata=artifact_metadata,
resource_id=response.id,
resource_type=MetadataResourceTypes.ARTIFACT_VERSION,
Expand Down
5 changes: 5 additions & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,8 @@ def handle_int_env_var(var: str, default: int = 0) -> int:

# Service connector constants
SERVICE_CONNECTOR_SKEW_TOLERANCE_SECONDS = 60 * 5 # 5 minutes

# Versioned entities
MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION = (
10 # empirical value to pass heavy parallelized tests
)
55 changes: 19 additions & 36 deletions src/zenml/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# permissions and limitations under the License.
"""Model user facing interface to pass into pipeline or step."""

import time
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -25,6 +26,7 @@

from pydantic import BaseModel, PrivateAttr, root_validator

from zenml.constants import MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION
from zenml.enums import MetadataResourceTypes, ModelStages
from zenml.exceptions import EntityExistsError
from zenml.logger import get_logger
Expand Down Expand Up @@ -521,39 +523,6 @@ def _get_or_create_model(self) -> "ModelResponse":
model = zenml_client.zen_store.get_model(
model_name_or_id=self.name
)

difference: Dict[str, Any] = {}
for key in (
"license",
"audience",
"use_cases",
"limitations",
"trade_offs",
"ethics",
"save_models_to_registry",
):
if self_attr := getattr(self, key, None):
if self_attr != getattr(model, key):
difference[key] = {
"config": getattr(self, key),
"db": getattr(model, key),
}
if self.tags:
configured_tags = set(self.tags)
db_tags = {t.name for t in model.tags}
if db_tags != configured_tags:
difference["tags added"] = list(configured_tags - db_tags)
difference["tags removed"] = list(
db_tags - configured_tags
)
if difference:
logger.warning(
"Provided model configuration does not match "
f"existing model `{self.name}` with the "
f"following changes: {difference}. If you want to "
"update the model configuration, please use the "
"`zenml model update` command."
)
except KeyError:
model_request = ModelRequest(
name=self.name,
Expand Down Expand Up @@ -646,6 +615,7 @@ def _get_or_create_model_version(

Raises:
RuntimeError: if the model version needs to be created, but provided name is reserved
RuntimeError: if the model version cannot be created
"""
from zenml.client import Client
from zenml.models import ModelVersionRequest
Expand Down Expand Up @@ -723,9 +693,22 @@ def _get_or_create_model_version(
" as an example. You can explore model versions using "
f"`zenml model version list {self.name}` CLI command."
)
model_version = zenml_client.zen_store.create_model_version(
model_version=mv_request
)
for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION):
try:
model_version = (
zenml_client.zen_store.create_model_version(
model_version=mv_request
)
)
break
except EntityExistsError as e:
if i == MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION - 1:
raise RuntimeError(
f"Failed to create model version "
f"`{self.version if self.version else 'new'}` "
f"in model `{self.name}`."
) from e
time.sleep(0.2 * i)
self.version = model_version.name
self.was_created_in_this_run = True
logger.info(f"New model version `{self.version}` was created.")
Expand Down
86 changes: 50 additions & 36 deletions src/zenml/new/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from zenml.config.schedule import Schedule
from zenml.config.step_configurations import StepConfigurationUpdate
from zenml.enums import ExecutionStatus, StackComponentType
from zenml.exceptions import EntityExistsError
from zenml.hooks.hook_validators import resolve_and_validate_hook
from zenml.logger import get_logger
from zenml.models import (
Expand Down Expand Up @@ -757,9 +758,11 @@ def _run(
user=Client().active_user.id,
workspace=deployment_model.workspace.id,
deployment=deployment_model.id,
pipeline=deployment_model.pipeline.id
if deployment_model.pipeline
else None,
pipeline=(
deployment_model.pipeline.id
if deployment_model.pipeline
else None
),
status=ExecutionStatus.INITIALIZING,
)
run = Client().zen_store.create_run(run_request)
Expand Down Expand Up @@ -1174,46 +1177,57 @@ def _register(self, pipeline_spec: "PipelineSpec") -> "PipelineResponse":
Returns:
The registered pipeline model.
"""

def _get(version_hash: str) -> PipelineResponse:
client = Client()

matching_pipelines = client.list_pipelines(
name=self.name,
version_hash=version_hash,
size=1,
sort_by="desc:created",
)
if matching_pipelines.total:
registered_pipeline = matching_pipelines.items[0]
logger.info(
"Reusing registered pipeline version: `(version: %s)`.",
registered_pipeline.version,
)
return registered_pipeline
raise RuntimeError("No matching pipelines found.")

version_hash = self._compute_unique_identifier(
pipeline_spec=pipeline_spec
)

client = Client()
matching_pipelines = client.list_pipelines(
name=self.name,
version_hash=version_hash,
size=1,
sort_by="desc:created",
)
if matching_pipelines.total:
registered_pipeline = matching_pipelines.items[0]
logger.info(
"Reusing registered pipeline version: `(version: %s)`.",
registered_pipeline.version,
try:
return _get(version_hash)
except RuntimeError:
latest_version = self._get_latest_version() or 0
version = str(latest_version + 1)

request = PipelineRequest(
workspace=client.active_workspace.id,
user=client.active_user.id,
name=self.name,
version=version,
version_hash=version_hash,
spec=pipeline_spec,
docstring=self.__doc__,
)
return registered_pipeline

latest_version = self._get_latest_version() or 0
version = str(latest_version + 1)

request = PipelineRequest(
workspace=client.active_workspace.id,
user=client.active_user.id,
name=self.name,
version=version,
version_hash=version_hash,
spec=pipeline_spec,
docstring=self.__doc__,
)

registered_pipeline = client.zen_store.create_pipeline(
pipeline=request
)
logger.info(
"Registered new version: `(version %s)`.",
registered_pipeline.version,
)
return registered_pipeline
try:
registered_pipeline = client.zen_store.create_pipeline(
pipeline=request
)
logger.info(
"Registered new version: `(version %s)`.",
registered_pipeline.version,
)
return registered_pipeline
except EntityExistsError:
return _get(version_hash)

def _compute_unique_identifier(self, pipeline_spec: PipelineSpec) -> str:
"""Computes a unique identifier from the pipeline spec and steps.
Expand Down
Loading