Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/zenml/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,12 @@ def link_artifact_config_to_model(
)

if model:
model._get_or_create_model_version()
model_version_response = model._get_model_version()
request = ModelVersionArtifactRequest(
user=client.active_user.id,
workspace=client.active_workspace.id,
artifact_version=artifact_version_id,
model=model_version_response.model.id,
model_version=model_version_response.id,
model=model.model_id,
model_version=model.id,
is_model_artifact=artifact_config.is_model_artifact,
is_deployment_artifact=artifact_config.is_deployment_artifact,
)
Expand Down
19 changes: 11 additions & 8 deletions src/zenml/orchestrators/step_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,18 +389,21 @@ def _link_cached_artifacts_to_model(
step_instance.entrypoint
)
for output_name_, output_id in step_run.outputs.items():
artifact_config_ = None
if output_name_ in output_annotations:
annotation = output_annotations.get(output_name_, None)
if annotation and annotation.artifact_config is not None:
artifact_config_ = annotation.artifact_config.copy()
else:
artifact_config_ = ArtifactConfig(name=output_name_)

link_artifact_config_to_model(
artifact_config=artifact_config_,
model=model_from_context,
artifact_version_id=output_id,
)
# no artifact config found or artifact was produced by `save_artifact`
# inside the step body, so was never in annotations
if artifact_config_ is None:
artifact_config_ = ArtifactConfig(name=output_name_)

link_artifact_config_to_model(
artifact_config=artifact_config_,
model=model_from_context,
artifact_version_id=output_id,
)

def _run_step(
self,
Expand Down
50 changes: 25 additions & 25 deletions tests/integration/functional/model/test_model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def artifact_linker(
model: Optional[Model] = None,
is_model_artifact: bool = False,
is_deployment_artifact: bool = False,
do_link: bool = True,
) -> None:
"""Step linking an artifact to a model via function or implicit."""

Expand All @@ -96,7 +95,7 @@ def artifact_linker(
is_deployment_artifact=is_deployment_artifact,
)

if do_link:
if model:
link_artifact_to_model(
artifact_version_id=artifact.id,
model=model,
Expand Down Expand Up @@ -502,9 +501,7 @@ def my_pipeline(is_consume: bool):
def test_link_artifact_via_function(self, clean_client: "Client"):
"""Test that user can link artifacts via function to a model version."""

@pipeline(
enable_cache=False,
)
@pipeline
def _inner_pipeline(
model: Model = None,
is_model_artifact: bool = False,
Expand All @@ -520,46 +517,48 @@ def _inner_pipeline(
name=MODEL_NAME,
)

# no context, no model
with pytest.raises(RuntimeError):
_inner_pipeline()
# no context, no model, artifact produced but not linked
_inner_pipeline()
artifact = clean_client.get_artifact_version("manual_artifact")
assert int(artifact.version) == 1

# use context
# pipeline will run in a model context and will use cached step version
_inner_pipeline.with_options(model=mv_in_pipe)()

mv = Model(name=MODEL_NAME, version="latest")
assert mv.number == 1
assert mv.get_artifact("manual_artifact").load() == "Hello, World!"
artifact = mv.get_artifact("manual_artifact")
assert artifact.load() == "Hello, World!"
assert int(artifact.version) == 1

# use custom model version
# use custom model version (cache invalidated)
_inner_pipeline(model=Model(name="custom_model_version"))

mv_custom = Model(name="custom_model_version", version="latest")
assert mv_custom.number == 1
assert (
mv_custom.get_artifact("manual_artifact").load() == "Hello, World!"
)
artifact = mv_custom.get_artifact("manual_artifact")
assert artifact.load() == "Hello, World!"
assert int(artifact.version) == 2

# use context + model
# use context + model (cache invalidated)
_inner_pipeline.with_options(model=mv_in_pipe)(is_model_artifact=True)

mv = Model(name=MODEL_NAME, version="latest")
assert mv.number == 2
assert (
mv.get_model_artifact("manual_artifact").load() == "Hello, World!"
)
artifact = mv.get_artifact("manual_artifact")
assert artifact.load() == "Hello, World!"
assert int(artifact.version) == 3

# use context + deployment
# use context + deployment (cache invalidated)
_inner_pipeline.with_options(model=mv_in_pipe)(
is_deployment_artifact=True
)

mv = Model(name=MODEL_NAME, version="latest")
assert mv.number == 3
assert (
mv.get_deployment_artifact("manual_artifact").load()
== "Hello, World!"
)
artifact = mv.get_deployment_artifact("manual_artifact")
assert artifact.load() == "Hello, World!"
assert int(artifact.version) == 4

# link outside of a step
artifact = save_artifact(data="Hello, World!", name="manual_artifact")
Expand All @@ -570,7 +569,9 @@ def _inner_pipeline(

mv = Model(name=MODEL_NAME, version="latest")
assert mv.number == 4
assert mv.get_artifact("manual_artifact").load() == "Hello, World!"
artifact = mv.get_artifact("manual_artifact")
assert artifact.load() == "Hello, World!"
assert int(artifact.version) == 5

def test_link_artifact_via_save_artifact(self, clean_client: "Client"):
"""Test that artifacts are auto-linked to a model version on call of `save_artifact`."""
Expand All @@ -585,7 +586,6 @@ def _inner_pipeline(
artifact_linker(
is_model_artifact=is_model_artifact,
is_deployment_artifact=is_deployment_artifact,
do_link=False,
)

mv_in_pipe = Model(
Expand Down