-
Notifications
You must be signed in to change notification settings - Fork 537
Add dynamic fan-out/fan-in with run templates #3826
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
2063c42
6f8a430
17f3cc5
93ba23f
09a0f52
e4081db
fd02427
c6b0a1f
3fc5729
38d6400
5d6c813
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -303,6 +303,185 @@ The fan-in, fan-out method has the following limitations: | |
2. The number of steps need to be known ahead-of-time, and ZenML does not yet support the ability to dynamically create steps on the fly. | ||
{% endhint %} | ||
|
||
### Dynamic Fan-out/Fan-in with Run Templates | ||
|
||
For scenarios where you need to determine the number of parallel operations at runtime (e.g., based on database queries or dynamic data), you can use [run templates](https://docs.zenml.io/user-guides/tutorial/trigger-pipelines-from-external-systems) to create a more flexible fan-out/fan-in pattern. This approach allows you to trigger multiple pipeline runs dynamically and then aggregate their results. | ||
|
||
```python | ||
from typing import List, Optional | ||
from uuid import UUID | ||
import time | ||
|
||
from zenml import step, pipeline | ||
from zenml.client import Client | ||
|
||
|
||
@step | ||
def load_relevant_chunks() -> List[str]: | ||
"""Load chunk identifiers from database or other dynamic source.""" | ||
# Example: Query database for chunk IDs | ||
# In practice, this could be a database query, API call, etc. | ||
return ["chunk_1", "chunk_2", "chunk_3", "chunk_4"] | ||
|
||
|
||
@step | ||
def trigger_chunk_processing( | ||
chunks: List[str], | ||
template_id: Optional[UUID] = None | ||
) -> List[UUID]: | ||
"""Trigger multiple pipeline runs for each chunk and wait for completion.""" | ||
client = Client() | ||
|
||
# Use template ID if provided, otherwise use pipeline name | ||
pipeline_name = None if template_id else "chunk_processing_pipeline" | ||
|
||
# Trigger all chunk processing runs | ||
run_ids = [] | ||
for chunk_id in chunks: | ||
run_config = { | ||
"steps": { | ||
"process_chunk": { | ||
"parameters": { | ||
"chunk_id": chunk_id | ||
} | ||
} | ||
} | ||
} | ||
|
||
run = client.trigger_pipeline( | ||
template_id=template_id, | ||
pipeline_name_or_id=pipeline_name, | ||
run_configuration=run_config, | ||
synchronous=False # Run asynchronously | ||
) | ||
run_ids.append(run.id) | ||
|
||
# Wait for all runs to complete | ||
print(f"Waiting for {len(run_ids)} chunk processing runs to complete...") | ||
completed_runs = set() # Cache completed runs to avoid re-fetching | ||
while True: | ||
# Only check runs that haven't completed yet | ||
pending_runs = [run_id for run_id in run_ids if run_id not in completed_runs] | ||
|
||
for run_id in pending_runs: | ||
run = client.get_pipeline_run(run_id) | ||
if run.status.is_finished: | ||
completed_runs.add(run_id) | ||
|
||
if len(completed_runs) == len(run_ids): | ||
print("All chunk processing runs completed!") | ||
break | ||
|
||
print(f"Completed: {len(completed_runs)}/{len(run_ids)} runs") | ||
time.sleep(10) # Wait 10 seconds before checking again | ||
|
||
return run_ids | ||
|
||
|
||
@step | ||
def aggregate_results(run_ids: List[UUID]) -> dict: | ||
"""Aggregate results from all chunk processing runs.""" | ||
client = Client() | ||
aggregated_results = {} | ||
failed_runs = [] | ||
|
||
for run_id in run_ids: | ||
run = client.get_pipeline_run(run_id) | ||
|
||
# Check if run succeeded | ||
if run.status.value == "failed": | ||
failed_runs.append({ | ||
"run_id": str(run_id), | ||
"status": run.status.value, | ||
}) | ||
print(f"WARNING: Run {run_id} failed with status {run.status.value}") | ||
continue | ||
|
||
# Extract results from successful runs only | ||
if "process_chunk" in run.steps: | ||
step_run = run.steps["process_chunk"] | ||
# Simple assumption: process_chunk step has one output that we can load | ||
chunk_result = step_run.output.load() | ||
aggregated_results[str(run_id)] = chunk_result | ||
|
||
|
||
# Log summary of results | ||
total_runs = len(run_ids) | ||
successful_runs = len(aggregated_results) | ||
failed_count = len(failed_runs) | ||
|
||
print(f"Aggregation complete: {successful_runs}/{total_runs} runs successful") | ||
|
||
return { | ||
"successful_results": aggregated_results, | ||
"failed_runs": failed_runs, | ||
"summary": { | ||
"total_runs": total_runs, | ||
"successful_runs": successful_runs, | ||
"failed_runs": failed_count | ||
} | ||
} | ||
|
||
|
||
@pipeline(enable_cache=False) | ||
def fan_out_fan_in_pipeline(template_id: Optional[UUID] = None): | ||
"""Fan-out/fan-in pipeline that orchestrates dynamic chunk processing.""" | ||
# Load chunks dynamically at runtime | ||
chunks = load_relevant_chunks() | ||
|
||
# Trigger chunk processing runs and wait for completion | ||
run_ids = trigger_chunk_processing(chunks, template_id) | ||
|
||
# Aggregate results from all runs | ||
results = aggregate_results(run_ids) | ||
|
||
return results | ||
|
||
|
||
# Define the chunk processing pipeline that will be triggered | ||
@step | ||
def process_chunk(chunk_id: Optional[str] = None) -> dict: | ||
"""Process a single chunk of data.""" | ||
# Simulate chunk processing | ||
print(f"Processing chunk: {chunk_id}") | ||
return { | ||
"chunk_id": chunk_id, | ||
"processed_items": 100, | ||
"status": "completed" | ||
} | ||
|
||
|
||
@pipeline | ||
def chunk_processing_pipeline(): | ||
"""Pipeline that processes a single chunk.""" | ||
result = process_chunk() | ||
return result | ||
|
||
|
||
# Usage example | ||
if __name__ == "__main__": | ||
# First, make sure you run the chunk_processing_pipeline once | ||
# on a remote orchestrator: | ||
# chunk_processing_pipeline() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of doing this and later calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Didnt know that! Thats great! |
||
|
||
# Second, create a run template for the chunk processing pipeline | ||
# This would typically be done once during setup. | ||
# You can also do this on the dashboard. | ||
pipe = Client().get_pipeline("chunk_processing_pipeline") | ||
run = pipe.runs[0] # We assume latest run | ||
template = Client().create_run_template( | ||
name="chunk_processing_template", | ||
deployment_id=run.deployment_id, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't exist anymore, and is also not needed in this call
htahir1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
description="Template for processing individual chunks" | ||
) | ||
|
||
# Run the fan-out/fan-in pipeline with the template | ||
# You can also get the template ID from the dashboard | ||
fan_out_fan_in_pipeline(template_id=template.id) | ||
``` | ||
|
||
This pattern enables dynamic scaling, true parallelism, and database-driven workflows. Key advantages include fault tolerance and separate monitoring for each chunk. Consider resource management and proper error handling when implementing. | ||
|
||
### Custom Step Invocation IDs | ||
|
||
When calling a ZenML step as part of your pipeline, it gets assigned a unique **invocation ID** that you can use to reference this step invocation when defining the execution order of your pipeline steps or use it to fetch information about the invocation after the pipeline has finished running. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you think it's worth explaining what happens here? Right now you're saying "use pipeline name", but that doesn't really mean anything to anyone. What this does is fetch the latest template for the pipeline with that name
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed