Commit f7a3e91

Split run submission and monitoring (#3723)
* POC
* Improve error handling and add warning
* Linting
* Catch more errors
* More linting
* Docs
* Update more orchestrators
* More orchestrators
* Remaining orchestrators
* Don't fail when metadata publishing failed
* More linting
* More linting
* Docstrings
* Lightning orchestrator cleanup
* Hyperai orchestrator cleanup
* More cleanup
* Tests
* Remove more unused code
1 parent d6a7eff commit f7a3e91

File tree

24 files changed: +574 −534 lines

docs/book/component-guide/orchestrators/custom.md

Lines changed: 21 additions & 14 deletions
@@ -28,17 +28,14 @@ class BaseOrchestratorConfig(StackComponentConfig):
 class BaseOrchestrator(StackComponent, ABC):
     """Base class for all ZenML orchestrators"""
 
-    @abstractmethod
-    def prepare_or_run_pipeline(
+    def submit_pipeline(
         self,
-        deployment: PipelineDeploymentResponseModel,
-        stack: Stack,
+        deployment: "PipelineDeploymentResponse",
+        stack: "Stack",
         environment: Dict[str, str],
-        placeholder_run: Optional[PipelineRunResponse] = None,
-    ) -> Any:
-        """Prepares and runs the pipeline outright or returns an intermediate
-        pipeline representation that gets deployed.
-        """
+        placeholder_run: Optional["PipelineRunResponse"] = None,
+    ) -> Optional[SubmissionResult]:
+        """Submits a pipeline to the orchestrator."""
 
     @abstractmethod
     def get_orchestrator_run_id(self) -> str:
@@ -84,7 +81,7 @@ This is a slimmed-down version of the base implementation which aims to highligh
 
 If you want to create your own custom flavor for an orchestrator, you can follow the following steps:
 
-1. Create a class that inherits from the `BaseOrchestrator` class and implement the abstract `prepare_or_run_pipeline(...)` and `get_orchestrator_run_id()` methods.
+1. Create a class that inherits from the `BaseOrchestrator` class and implement the abstract `submit_pipeline(...)` and `get_orchestrator_run_id()` methods.
 2. If you need to provide any configuration, create a class that inherits from the `BaseOrchestratorConfig` class and add your configuration parameters.
 3. Bring both the implementation and the configuration together by inheriting from the `BaseOrchestratorFlavor` class. Make sure that you give a `name` to the flavor through its abstract property.

@@ -125,12 +122,15 @@ The design behind this interaction lets us separate the configuration of the fla
 ## Implementation guide
 
 1. **Create your orchestrator class:** This class should either inherit from `BaseOrchestrator`, or more commonly from `ContainerizedOrchestrator`. If your orchestrator uses container images to run code, you should inherit from `ContainerizedOrchestrator` which handles building all Docker images for the pipeline to be executed. If your orchestator does not use container images, you'll be responsible that the execution environment contains all the necessary requirements and code files to run the pipeline.
-2. **Implement the `prepare_or_run_pipeline(...)` method:** This method is responsible for running or scheduling the pipeline. In most cases, this means converting the pipeline into a format that your orchestration tool understands and running it. To do so, you should:
+2. **Implement the `submit_pipeline(...)` method:** This method is responsible for submitting the pipeline run or schedule. In most cases, this means converting the pipeline into a format that your orchestration backend understands and submitting it. To do so, you should:
 
    * Loop over all steps of the pipeline and configure your orchestration tool to run the correct command and arguments in the correct Docker image
    * Make sure the passed environment variables are set when the container is run
    * Make sure the containers are running in the correct order
 
+   * If you want to store any metadata for the run or schedule, return it as part of the `SubmissionResult`.
+   * If your orchestrator is configured to run synchronous, make sure to return a `wait_for_completion` closure in the `SubmissionResult`.
+
    Check out the [code sample](custom.md#code-sample) below for more details on how to fetch the Docker image, command, arguments and step order.
 3. **Implement the `get_orchestrator_run_id()` method:** This must return a ID that is different for each pipeline run, but identical if called from within Docker containers running different steps of the same pipeline run. If your orchestrator is based on an external tool like Kubeflow or Airflow, it is usually best to use an unique ID provided by this tool.

@@ -152,7 +152,7 @@ from typing import Dict
 
 from zenml.entrypoints import StepEntrypointConfiguration
 from zenml.models import PipelineDeploymentResponseModel, PipelineRunResponse
-from zenml.orchestrators import ContainerizedOrchestrator
+from zenml.orchestrators import ContainerizedOrchestrator, SubmissionResult
 from zenml.stack import Stack
 
@@ -165,13 +165,13 @@ class MyOrchestrator(ContainerizedOrchestrator):
         # can usually use the run ID of that tool here.
         ...
 
-    def prepare_or_run_pipeline(
+    def submit_pipeline(
         self,
         deployment: "PipelineDeploymentResponseModel",
         stack: "Stack",
         environment: Dict[str, str],
         placeholder_run: Optional["PipelineRunResponse"] = None,
-    ) -> None:
+    ) -> Optional[SubmissionResult]:
         # If your orchestrator supports scheduling, you should handle the schedule
         # configured by the user. Otherwise you might raise an exception or log a warning
         # that the orchestrator doesn't support scheduling
@@ -209,6 +209,13 @@ class MyOrchestrator(ContainerizedOrchestrator):
             # specific resources were specified for this step:
             if self.requires_resources_in_orchestration_environment(step):
                 resources = step.config.resource_settings
+
+        if self.config.synchronous:
+            def _wait_for_completion() -> None:
+                # Query your orchestrator backend to wait until the run has finished.
+                # If possible, you can also stream the logs of the pipeline run here.
+
+            return SubmissionResult(wait_for_completion=_wait_for_completion)
 ```
 
 {% hint style="info" %}
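The docs diff above changes the orchestrator contract from a single do-everything `prepare_or_run_pipeline(...)` to a `submit_pipeline(...)` that returns an optional `SubmissionResult`. The following is a minimal sketch of how that split plays out at the call site; the `SubmissionResult` dataclass, `DummyOrchestrator`, and `run()` below are simplified stand-ins for illustration, not ZenML's actual classes:

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional


# Simplified stand-in (assumption) for zenml.orchestrators.SubmissionResult:
# a submission may carry run metadata and an optional closure that blocks
# until the run finishes.
@dataclass
class SubmissionResult:
    metadata: Dict[str, str] = field(default_factory=dict)
    wait_for_completion: Optional[Callable[[], None]] = None


class DummyOrchestrator:
    """Toy orchestrator illustrating the submit/monitor split."""

    def __init__(self, synchronous: bool) -> None:
        self.synchronous = synchronous
        self.log: List[str] = []

    def submit_pipeline(self) -> Optional[SubmissionResult]:
        # Submit only -- never block here. Monitoring happens through the
        # returned closure, so submission errors and monitoring errors can
        # be handled separately by the caller.
        self.log.append("submitted")

        wait = None
        if self.synchronous:
            def wait() -> None:
                self.log.append("waited")

        return SubmissionResult(
            metadata={"run_id": "run-123"},  # hypothetical metadata
            wait_for_completion=wait,
        )


def run(orchestrator: DummyOrchestrator) -> Dict[str, str]:
    # Caller-side logic: store metadata if any, then monitor only when the
    # orchestrator handed back a wait closure.
    result = orchestrator.submit_pipeline()
    metadata = result.metadata if result else {}
    if result and result.wait_for_completion:
        result.wait_for_completion()
    return metadata


sync = DummyOrchestrator(synchronous=True)
print(run(sync), sync.log)      # {'run_id': 'run-123'} ['submitted', 'waited']
async_ = DummyOrchestrator(synchronous=False)
print(run(async_), async_.log)  # {'run_id': 'run-123'} ['submitted']
```

An asynchronous orchestrator simply leaves `wait_for_completion` as `None`, and the caller skips the monitoring phase entirely.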

src/zenml/constants.py

Lines changed: 0 additions & 1 deletion
@@ -173,7 +173,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
 ENV_ZENML_DISABLE_STEP_NAMES_IN_LOGS = "ZENML_DISABLE_STEP_NAMES_IN_LOGS"
 ENV_ZENML_IGNORE_FAILURE_HOOK = "ZENML_IGNORE_FAILURE_HOOK"
 ENV_ZENML_CUSTOM_SOURCE_ROOT = "ZENML_CUSTOM_SOURCE_ROOT"
-ENV_ZENML_WHEEL_PACKAGE_NAME = "ZENML_WHEEL_PACKAGE_NAME"
 ENV_ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION = (
     "ZENML_PIPELINE_API_TOKEN_EXPIRATION"
 )

src/zenml/exceptions.py

Lines changed: 16 additions & 0 deletions
@@ -220,3 +220,19 @@ class CustomFlavorImportError(ImportError):
 
 class MaxConcurrentTasksError(ZenMLBaseException):
     """Raised when the maximum number of concurrent tasks is reached."""
+
+
+class RunMonitoringError(ZenMLBaseException):
+    """Raised when an error occurs while monitoring a pipeline run."""
+
+    def __init__(
+        self,
+        original_exception: BaseException,
+    ) -> None:
+        """Initializes the error.
+
+        Args:
+            original_exception: The original exception that occurred while
+                monitoring the pipeline run.
+        """
+        self.original_exception = original_exception
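The new `RunMonitoringError` exists so that a failure while *watching* a run is distinguishable from a failure while *submitting* it. A sketch of the intended pattern follows; the local `ZenMLBaseException` stand-in and the `submit_and_monitor`/`flaky_wait` helpers are assumptions for illustration, not ZenML's actual call sites:

```python
from typing import Callable


class ZenMLBaseException(Exception):
    """Stand-in for zenml.exceptions.ZenMLBaseException (assumption)."""


class RunMonitoringError(ZenMLBaseException):
    """Raised when an error occurs while monitoring a pipeline run."""

    def __init__(self, original_exception: BaseException) -> None:
        self.original_exception = original_exception


def submit_and_monitor(wait_for_completion: Callable[[], None]) -> str:
    # Submission already succeeded by the time we get here; a crash while
    # waiting should not look like a failed submission, so the original
    # error is wrapped and kept inspectable for the caller.
    try:
        wait_for_completion()
    except BaseException as e:
        raise RunMonitoringError(e) from e
    return "finished"


def flaky_wait() -> None:
    raise TimeoutError("backend poll timed out")  # hypothetical failure


try:
    submit_and_monitor(flaky_wait)
except RunMonitoringError as e:
    print(type(e.original_exception).__name__)  # TimeoutError
```

The caller can then report the run as submitted-but-unmonitored instead of failed, which matches the commit message's "Don't fail when metadata publishing failed" intent.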

src/zenml/integrations/airflow/orchestrators/airflow_orchestrator.py

Lines changed: 15 additions & 6 deletions
@@ -38,7 +38,7 @@
 )
 from zenml.io import fileio
 from zenml.logger import get_logger
-from zenml.orchestrators import ContainerizedOrchestrator
+from zenml.orchestrators import ContainerizedOrchestrator, SubmissionResult
 from zenml.orchestrators.utils import get_orchestrator_run_name
 from zenml.stack import StackValidator
 from zenml.utils import io_utils
@@ -191,21 +191,29 @@ def prepare_pipeline_deployment(
         if self.config.local:
             stack.check_local_paths()
 
-    def prepare_or_run_pipeline(
+    def submit_pipeline(
         self,
         deployment: "PipelineDeploymentResponse",
         stack: "Stack",
         environment: Dict[str, str],
         placeholder_run: Optional["PipelineRunResponse"] = None,
-    ) -> Any:
-        """Creates and writes an Airflow DAG zip file.
+    ) -> Optional[SubmissionResult]:
+        """Submits a pipeline to the orchestrator.
+
+        This method should only submit the pipeline and not wait for it to
+        complete. If the orchestrator is configured to wait for the pipeline run
+        to complete, a function that waits for the pipeline run to complete can
+        be passed as part of the submission result.
 
         Args:
-            deployment: The pipeline deployment to prepare or run.
+            deployment: The pipeline deployment to submit.
             stack: The stack the pipeline will run on.
             environment: Environment variables to set in the orchestration
-                environment.
+                environment. These don't need to be set if running locally.
             placeholder_run: An optional placeholder run for the deployment.
+
+        Returns:
+            Optional submission result.
         """
         pipeline_settings = cast(
             AirflowOrchestratorSettings, self.get_settings(deployment)
@@ -277,6 +285,7 @@ def prepare_or_run_pipeline(
             dag_generator_values=dag_generator_values,
             output_dir=pipeline_settings.dag_output_dir or self.dags_directory,
         )
+        return None
 
     def _apply_resource_settings(
         self,
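The Airflow orchestrator only writes a DAG archive for Airflow's scheduler to pick up later; it cannot monitor the resulting run, so its `submit_pipeline` returns `None`. A minimal fire-and-forget sketch of that shape, with assumed names (`DagFileOrchestrator`, the archive filename, and the placeholder payload are all illustrative):

```python
import os
import tempfile
from typing import Optional


class DagFileOrchestrator:
    """Fire-and-forget sketch: writes a DAG file, cannot monitor the run."""

    def __init__(self, output_dir: str) -> None:
        self.output_dir = output_dir

    def submit_pipeline(self) -> Optional[object]:
        # "Submission" here just means materializing the DAG archive; there
        # is nothing to wait for and no backend run to query, so the method
        # returns None instead of a SubmissionResult.
        path = os.path.join(self.output_dir, "pipeline_dag.zip")
        with open(path, "wb") as f:
            f.write(b"dag-archive")  # placeholder payload
        return None


with tempfile.TemporaryDirectory() as tmp:
    orchestrator = DagFileOrchestrator(tmp)
    result = orchestrator.submit_pipeline()
    # Caller: no SubmissionResult means no metadata to store and no waiting.
    assert result is None
    assert os.path.exists(os.path.join(tmp, "pipeline_dag.zip"))
```

Returning `None` is the contract's way of saying "submitted, but there is nothing further for ZenML to observe".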

src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py

Lines changed: 54 additions & 58 deletions
@@ -19,7 +19,6 @@
     TYPE_CHECKING,
     Any,
     Dict,
-    Iterator,
     List,
     Optional,
     Tuple,
@@ -60,7 +59,6 @@
 )
 from zenml.enums import (
     ExecutionStatus,
-    MetadataResourceTypes,
     StackComponentType,
 )
 from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
@@ -73,7 +71,7 @@
 )
 from zenml.logger import get_logger
 from zenml.metadata.metadata_types import MetadataType, Uri
-from zenml.orchestrators import ContainerizedOrchestrator
+from zenml.orchestrators import ContainerizedOrchestrator, SubmissionResult
 from zenml.orchestrators.utils import get_orchestrator_run_name
 from zenml.stack import StackValidator
 from zenml.utils.env_utils import split_environment_variables
@@ -273,20 +271,25 @@ def _get_sagemaker_session(self) -> Session:
             boto_session=boto_session, default_bucket=self.config.bucket
         )
 
-    def prepare_or_run_pipeline(
+    def submit_pipeline(
         self,
         deployment: "PipelineDeploymentResponse",
         stack: "Stack",
         environment: Dict[str, str],
         placeholder_run: Optional["PipelineRunResponse"] = None,
-    ) -> Iterator[Dict[str, MetadataType]]:
-        """Prepares or runs a pipeline on Sagemaker.
+    ) -> Optional[SubmissionResult]:
+        """Submits a pipeline to the orchestrator.
+
+        This method should only submit the pipeline and not wait for it to
+        complete. If the orchestrator is configured to wait for the pipeline run
+        to complete, a function that waits for the pipeline run to complete can
+        be passed as part of the submission result.
 
         Args:
-            deployment: The deployment to prepare or run.
-            stack: The stack to run on.
+            deployment: The pipeline deployment to submit.
+            stack: The stack the pipeline will run on.
             environment: Environment variables to set in the orchestration
-                environment.
+                environment. These don't need to be set if running locally.
             placeholder_run: An optional placeholder run for the deployment.
 
         Raises:
@@ -296,8 +299,8 @@ def prepare_or_run_pipeline(
                 AWS SageMaker NetworkConfig class.
             ValueError: If the schedule is not valid.
 
-        Yields:
-            A dictionary of metadata related to the pipeline run.
+        Returns:
+            Optional submission result.
         """
         # sagemaker requires pipelineName to use alphanum and hyphens only
         unsanitized_orchestrator_run_name = get_orchestrator_run_name(
@@ -705,26 +708,14 @@ def prepare_or_run_pipeline(
             )
             logger.info(f"The schedule ARN is: {triggers[0]}")
 
+            schedule_metadata = {}
             try:
-                from zenml.models import RunMetadataResource
-
                 schedule_metadata = self.generate_schedule_metadata(
                     schedule_arn=triggers[0]
                 )
-
-                Client().create_run_metadata(
-                    metadata=schedule_metadata,  # type: ignore[arg-type]
-                    resources=[
-                        RunMetadataResource(
-                            id=deployment.schedule.id,
-                            type=MetadataResourceTypes.SCHEDULE,
-                        )
-                    ],
-                )
             except Exception as e:
                 logger.debug(
-                    "There was an error attaching metadata to the "
-                    f"schedule: {e}"
+                    "There was an error generating schedule metadata: %s", e
                 )
 
             logger.info(
@@ -749,6 +740,7 @@ def prepare_or_run_pipeline(
             logger.info(
                 f"`aws scheduler delete-schedule --name {schedule_name}`"
             )
+            return SubmissionResult(metadata=schedule_metadata)
         else:
             # Execute the pipeline immediately if no schedule is specified
             execution = pipeline.start()
@@ -757,33 +749,40 @@ def prepare_or_run_pipeline(
                 "when using the Sagemaker Orchestrator."
             )
 
-            # Yield metadata based on the generated execution object
-            yield from self.compute_metadata(
+            run_metadata = self.compute_metadata(
                 execution_arn=execution.arn, settings=settings
             )
 
-            # mainly for testing purposes, we wait for the pipeline to finish
+            _wait_for_completion = None
             if settings.synchronous:
-                logger.info(
-                    "Executing synchronously. Waiting for pipeline to "
-                    "finish... \n"
-                    "At this point you can `Ctrl-C` out without cancelling the "
-                    "execution."
-                )
-                try:
-                    execution.wait(
-                        delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
-                    )
-                    logger.info("Pipeline completed successfully.")
-                except WaiterError:
-                    raise RuntimeError(
-                        "Timed out while waiting for pipeline execution to "
-                        "finish. For long-running pipelines we recommend "
-                        "configuring your orchestrator for asynchronous "
-                        "execution. The following command does this for you: \n"
-                        f"`zenml orchestrator update {self.name} "
-                        f"--synchronous=False`"
-                    )
+
+                def _wait_for_completion() -> None:
+                    logger.info(
+                        "Executing synchronously. Waiting for pipeline to "
+                        "finish... \n"
+                        "At this point you can `Ctrl-C` out without cancelling the "
+                        "execution."
+                    )
+                    try:
+                        execution.wait(
+                            delay=POLLING_DELAY,
+                            max_attempts=MAX_POLLING_ATTEMPTS,
+                        )
+                        logger.info("Pipeline completed successfully.")
+                    except WaiterError:
+                        raise RuntimeError(
+                            "Timed out while waiting for pipeline execution to "
+                            "finish. For long-running pipelines we recommend "
+                            "configuring your orchestrator for asynchronous "
+                            "execution. The following command does this for you: \n"
+                            f"`zenml orchestrator update {self.name} "
+                            f"--synchronous=False`"
+                        )
+
+            return SubmissionResult(
+                wait_for_completion=_wait_for_completion,
+                metadata=run_metadata,
+            )
 
     def get_pipeline_run_metadata(
         self, run_id: UUID
@@ -798,20 +797,15 @@ def get_pipeline_run_metadata(
         """
         execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
 
-        run_metadata: Dict[str, "MetadataType"] = {}
-
         settings = cast(
             SagemakerOrchestratorSettings,
             self.get_settings(Client().get_pipeline_run(run_id)),
         )
 
-        for metadata in self.compute_metadata(
+        return self.compute_metadata(
             execution_arn=execution_arn,
             settings=settings,
-        ):
-            run_metadata.update(metadata)
-
-        return run_metadata
+        )
 
     def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
         """Refreshes the status of a specific pipeline run.
@@ -873,14 +867,14 @@ def compute_metadata(
         self,
         execution_arn: str,
         settings: SagemakerOrchestratorSettings,
-    ) -> Iterator[Dict[str, MetadataType]]:
+    ) -> Dict[str, MetadataType]:
         """Generate run metadata based on the generated Sagemaker Execution.
 
         Args:
             execution_arn: The ARN of the pipeline execution.
             settings: The Sagemaker orchestrator settings.
 
-        Yields:
+        Returns:
             A dictionary of metadata related to the pipeline run.
         """
         # Orchestrator Run ID
@@ -901,7 +895,7 @@ def compute_metadata(
         ):
             metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url)
 
-        yield metadata
+        return metadata
 
     def _compute_orchestrator_url(
         self,
@@ -979,7 +973,9 @@ def _compute_orchestrator_logs_url(
         return None
 
     @staticmethod
-    def generate_schedule_metadata(schedule_arn: str) -> Dict[str, str]:
+    def generate_schedule_metadata(
+        schedule_arn: str,
+    ) -> Dict[str, MetadataType]:
         """Attaches metadata to the ZenML Schedules.
 
         Args:
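A recurring change in the SageMaker file is turning generator-style metadata functions (documented with `Yields:` but only ever yielding one dict) into plain functions that return the dict, which removes the loop-and-merge boilerplate at every call site. A minimal before/after sketch; the function and key names below are illustrative, not the actual SageMaker orchestrator code:

```python
from typing import Dict, Iterator


# Before: a generator that yields a single metadata dict, forcing callers
# to iterate and merge even though there is only ever one item.
def compute_metadata_old(run_id: str) -> Iterator[Dict[str, str]]:
    metadata = {"orchestrator_run_id": run_id}
    yield metadata


# After: just return the dict directly.
def compute_metadata_new(run_id: str) -> Dict[str, str]:
    return {"orchestrator_run_id": run_id}


# Old call site: loop + update boilerplate.
merged: Dict[str, str] = {}
for chunk in compute_metadata_old("run-42"):
    merged.update(chunk)

# New call site: a single expression.
direct = compute_metadata_new("run-42")

print(merged == direct)  # True
```

Since the metadata now travels inside the returned `SubmissionResult` rather than being yielded back to the caller, the generator indirection no longer served a purpose.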
