Skip to content

Commit 96f0ee3

Browse files
authored
Server-side parent step computation (#3762)
1 parent ad0b3cc commit 96f0ee3

File tree

6 files changed

+36
-35
lines changed

6 files changed

+36
-35
lines changed

src/zenml/models/v2/core/step_run.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -129,6 +129,7 @@ class StepRunRequest(ProjectScopedRequest):
129129
parent_step_ids: List[UUID] = Field(
130130
title="The IDs of the parent steps of this step run.",
131131
default_factory=list,
132+
deprecated=True,
132133
)
133134
inputs: Dict[str, List[UUID]] = Field(
134135
title="The IDs of the input artifact versions of the step run.",

src/zenml/orchestrators/input_utils.py

Lines changed: 5 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,7 @@
1414
"""Utilities for inputs."""
1515

1616
import json
17-
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
18-
from uuid import UUID
17+
from typing import TYPE_CHECKING, Dict, Optional
1918

2019
from zenml.client import Client
2120
from zenml.config.step_configurations import Step
@@ -32,7 +31,7 @@ def resolve_step_inputs(
3231
step: "Step",
3332
pipeline_run: "PipelineRunResponse",
3433
step_runs: Optional[Dict[str, "StepRunResponse"]] = None,
35-
) -> Tuple[Dict[str, "StepRunInputResponse"], List[UUID]]:
34+
) -> Dict[str, "StepRunInputResponse"]:
3635
"""Resolves inputs for the current step.
3736
3837
Args:
@@ -49,16 +48,14 @@ def resolve_step_inputs(
4948
resolved in runtime due to missing object.
5049
5150
Returns:
52-
The IDs of the input artifact versions and the IDs of parent steps of
53-
the current step.
51+
The input artifact versions.
5452
"""
5553
from zenml.models import ArtifactVersionResponse
5654
from zenml.models.v2.core.step_run import StepRunInputResponse
5755

5856
step_runs = step_runs or {}
5957

60-
steps_to_fetch = set(step.spec.upstream_steps)
61-
steps_to_fetch.update(
58+
steps_to_fetch = set(
6259
input_.step_name for input_ in step.spec.inputs.values()
6360
)
6461
# Remove all the step runs that we've already fetched.
@@ -205,9 +202,4 @@ def resolve_step_inputs(
205202
else:
206203
step.config.parameters[name] = value_
207204

208-
parent_step_ids = [
209-
step_runs[upstream_step].id
210-
for upstream_step in step.spec.upstream_steps
211-
]
212-
213-
return input_artifacts, parent_step_ids
205+
return input_artifacts

src/zenml/orchestrators/step_run_utils.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,7 @@ def populate_request(
9595
"""
9696
step = self.deployment.step_configurations[request.name]
9797

98-
input_artifacts, parent_step_ids = input_utils.resolve_step_inputs(
98+
input_artifacts = input_utils.resolve_step_inputs(
9999
step=step,
100100
pipeline_run=self.pipeline_run,
101101
step_runs=step_runs,
@@ -108,7 +108,6 @@ def populate_request(
108108
request.inputs = {
109109
name: [artifact.id] for name, artifact in input_artifacts.items()
110110
}
111-
request.parent_step_ids = parent_step_ids
112111

113112
cache_key = cache_utils.generate_cache_key(
114113
step=step,

src/zenml/step_operators/step_operator_entrypoint_configuration.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,7 @@ def _run_step(
9090
)
9191

9292
stack = Client().active_stack
93-
input_artifacts, _ = input_utils.resolve_step_inputs(
93+
input_artifacts = input_utils.resolve_step_inputs(
9494
step=step, pipeline_run=pipeline_run
9595
)
9696
output_artifact_uris = output_utils.prepare_output_artifact_uris(

src/zenml/zen_stores/sql_zen_store.py

Lines changed: 27 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -8779,19 +8779,18 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse:
87798779
session.commit()
87808780
session.refresh(step_schema)
87818781

8782-
# Save parent step IDs into the database.
8783-
for parent_step_id in step_run.parent_step_ids:
8784-
self._set_run_step_parent_step(
8785-
child_step_run=step_schema,
8786-
parent_id=parent_step_id,
8787-
session=session,
8788-
)
8789-
87908782
session.commit()
87918783
session.refresh(step_schema)
87928784

87938785
step_model = step_schema.to_model(include_metadata=True)
87948786

8787+
for upstream_step in step_model.spec.upstream_steps:
8788+
self._set_run_step_parent_step(
8789+
child_step_run=step_schema,
8790+
parent_step_name=upstream_step,
8791+
session=session,
8792+
)
8793+
87958794
# Save input artifact IDs into the database.
87968795
for input_name, artifact_version_ids in step_run.inputs.items():
87978796
for artifact_version_id in artifact_version_ids:
@@ -9047,22 +9046,33 @@ def _get_step_run_input_type_from_config(
90479046
return StepRunInputArtifactType.MANUAL
90489047

90499048
def _set_run_step_parent_step(
9050-
self, child_step_run: StepRunSchema, parent_id: UUID, session: Session
9049+
self,
9050+
child_step_run: StepRunSchema,
9051+
parent_step_name: str,
9052+
session: Session,
90519053
) -> None:
90529054
"""Sets the parent step run for a step run.
90539055
90549056
Args:
90559057
child_step_run: The child step run to set the parent for.
9056-
parent_id: The ID of the parent step run to set a child for.
9058+
parent_step_name: The name of the parent step run to set a child for.
90579059
session: The database session to use.
9060+
9061+
Raises:
9062+
RuntimeError: If the parent step run is not found.
90589063
"""
9059-
parent_step_run = self._get_reference_schema_by_id(
9060-
resource=child_step_run,
9061-
reference_schema=StepRunSchema,
9062-
reference_id=parent_id,
9063-
session=session,
9064-
reference_type="parent step",
9065-
)
9064+
parent_step_run = session.exec(
9065+
select(StepRunSchema)
9066+
.where(StepRunSchema.name == parent_step_name)
9067+
.where(
9068+
StepRunSchema.pipeline_run_id == child_step_run.pipeline_run_id
9069+
)
9070+
).first()
9071+
if parent_step_run is None:
9072+
raise RuntimeError(
9073+
f"Parent step run `{parent_step_name}` not found for step run "
9074+
f"`{child_step_run.name}`."
9075+
)
90669076

90679077
# Check if the parent step is already set.
90689078
assignment = session.exec(

tests/unit/orchestrators/test_input_utils.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ def test_input_resolution(
5959
}
6060
)
6161

62-
input_artifacts, parent_ids = input_utils.resolve_step_inputs(
62+
input_artifacts = input_utils.resolve_step_inputs(
6363
step=step, pipeline_run=sample_pipeline_run
6464
)
6565
assert input_artifacts == {
@@ -68,7 +68,6 @@ def test_input_resolution(
6868
**sample_artifact_version_model.model_dump(),
6969
)
7070
}
71-
assert parent_ids == [step_run.id]
7271

7372

7473
def test_input_resolution_with_missing_step_run(mocker, sample_pipeline_run):

0 commit comments

Comments
 (0)