Skip to content

Commit a382218

Browse files
Fix cached artifacts produced via save_artifact inside steps linkage to MCP (#2619)
* also link artifacts from `save_artifact` in cached steps * extend tests to cover this case * fix test signature --------- Co-authored-by: Safoine El Khabich <[email protected]>
1 parent 469f92f commit a382218

File tree

3 files changed

+38
-37
lines changed

3 files changed

+38
-37
lines changed

src/zenml/model/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,12 @@ def link_artifact_config_to_model(
101101
)
102102

103103
if model:
104-
model._get_or_create_model_version()
105-
model_version_response = model._get_model_version()
106104
request = ModelVersionArtifactRequest(
107105
user=client.active_user.id,
108106
workspace=client.active_workspace.id,
109107
artifact_version=artifact_version_id,
110-
model=model_version_response.model.id,
111-
model_version=model_version_response.id,
108+
model=model.model_id,
109+
model_version=model.id,
112110
is_model_artifact=artifact_config.is_model_artifact,
113111
is_deployment_artifact=artifact_config.is_deployment_artifact,
114112
)

src/zenml/orchestrators/step_launcher.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -389,18 +389,21 @@ def _link_cached_artifacts_to_model(
389389
step_instance.entrypoint
390390
)
391391
for output_name_, output_id in step_run.outputs.items():
392+
artifact_config_ = None
392393
if output_name_ in output_annotations:
393394
annotation = output_annotations.get(output_name_, None)
394395
if annotation and annotation.artifact_config is not None:
395396
artifact_config_ = annotation.artifact_config.copy()
396-
else:
397-
artifact_config_ = ArtifactConfig(name=output_name_)
398-
399-
link_artifact_config_to_model(
400-
artifact_config=artifact_config_,
401-
model=model_from_context,
402-
artifact_version_id=output_id,
403-
)
397+
# no artifact config found or artifact was produced by `save_artifact`
398+
# inside the step body, so was never in annotations
399+
if artifact_config_ is None:
400+
artifact_config_ = ArtifactConfig(name=output_name_)
401+
402+
link_artifact_config_to_model(
403+
artifact_config=artifact_config_,
404+
model=model_from_context,
405+
artifact_version_id=output_id,
406+
)
404407

405408
def _run_step(
406409
self,

tests/integration/functional/model/test_model_version.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def artifact_linker(
8585
model: Optional[Model] = None,
8686
is_model_artifact: bool = False,
8787
is_deployment_artifact: bool = False,
88-
do_link: bool = True,
8988
) -> None:
9089
"""Step linking an artifact to a model via function or implicit."""
9190

@@ -96,7 +95,7 @@ def artifact_linker(
9695
is_deployment_artifact=is_deployment_artifact,
9796
)
9897

99-
if do_link:
98+
if model:
10099
link_artifact_to_model(
101100
artifact_version_id=artifact.id,
102101
model=model,
@@ -502,9 +501,7 @@ def my_pipeline(is_consume: bool):
502501
def test_link_artifact_via_function(self, clean_client: "Client"):
503502
"""Test that user can link artifacts via function to a model version."""
504503

505-
@pipeline(
506-
enable_cache=False,
507-
)
504+
@pipeline
508505
def _inner_pipeline(
509506
model: Model = None,
510507
is_model_artifact: bool = False,
@@ -520,46 +517,48 @@ def _inner_pipeline(
520517
name=MODEL_NAME,
521518
)
522519

523-
# no context, no model
524-
with pytest.raises(RuntimeError):
525-
_inner_pipeline()
520+
# no context, no model, artifact produced but not linked
521+
_inner_pipeline()
522+
artifact = clean_client.get_artifact_version("manual_artifact")
523+
assert int(artifact.version) == 1
526524

527-
# use context
525+
# pipeline will run in a model context and will use cached step version
528526
_inner_pipeline.with_options(model=mv_in_pipe)()
529527

530528
mv = Model(name=MODEL_NAME, version="latest")
531529
assert mv.number == 1
532-
assert mv.get_artifact("manual_artifact").load() == "Hello, World!"
530+
artifact = mv.get_artifact("manual_artifact")
531+
assert artifact.load() == "Hello, World!"
532+
assert int(artifact.version) == 1
533533

534-
# use custom model version
534+
# use custom model version (cache invalidated)
535535
_inner_pipeline(model=Model(name="custom_model_version"))
536536

537537
mv_custom = Model(name="custom_model_version", version="latest")
538538
assert mv_custom.number == 1
539-
assert (
540-
mv_custom.get_artifact("manual_artifact").load() == "Hello, World!"
541-
)
539+
artifact = mv_custom.get_artifact("manual_artifact")
540+
assert artifact.load() == "Hello, World!"
541+
assert int(artifact.version) == 2
542542

543-
# use context + model
543+
# use context + model (cache invalidated)
544544
_inner_pipeline.with_options(model=mv_in_pipe)(is_model_artifact=True)
545545

546546
mv = Model(name=MODEL_NAME, version="latest")
547547
assert mv.number == 2
548-
assert (
549-
mv.get_model_artifact("manual_artifact").load() == "Hello, World!"
550-
)
548+
artifact = mv.get_artifact("manual_artifact")
549+
assert artifact.load() == "Hello, World!"
550+
assert int(artifact.version) == 3
551551

552-
# use context + deployment
552+
# use context + deployment (cache invalidated)
553553
_inner_pipeline.with_options(model=mv_in_pipe)(
554554
is_deployment_artifact=True
555555
)
556556

557557
mv = Model(name=MODEL_NAME, version="latest")
558558
assert mv.number == 3
559-
assert (
560-
mv.get_deployment_artifact("manual_artifact").load()
561-
== "Hello, World!"
562-
)
559+
artifact = mv.get_deployment_artifact("manual_artifact")
560+
assert artifact.load() == "Hello, World!"
561+
assert int(artifact.version) == 4
563562

564563
# link outside of a step
565564
artifact = save_artifact(data="Hello, World!", name="manual_artifact")
@@ -570,7 +569,9 @@ def _inner_pipeline(
570569

571570
mv = Model(name=MODEL_NAME, version="latest")
572571
assert mv.number == 4
573-
assert mv.get_artifact("manual_artifact").load() == "Hello, World!"
572+
artifact = mv.get_artifact("manual_artifact")
573+
assert artifact.load() == "Hello, World!"
574+
assert int(artifact.version) == 5
574575

575576
def test_link_artifact_via_save_artifact(self, clean_client: "Client"):
576577
"""Test that artifacts are auto-linked to a model version on call of `save_artifact`."""
@@ -585,7 +586,6 @@ def _inner_pipeline(
585586
artifact_linker(
586587
is_model_artifact=is_model_artifact,
587588
is_deployment_artifact=is_deployment_artifact,
588-
do_link=False,
589589
)
590590

591591
mv_in_pipe = Model(

0 commit comments

Comments
 (0)