diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index 7dae26fcac..63e6e63ea2 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -37,3 +37,4 @@ nbformat>=5.9,<6 accelerate>=0.24.1,<=0.27.0 schema==0.7.5 tensorflow>=2.1,<=2.16 +mlflow>=2.12.2,<2.13 diff --git a/src/sagemaker/mlflow/__init__.py b/src/sagemaker/mlflow/__init__.py new file mode 100644 index 0000000000..6549052177 --- /dev/null +++ b/src/sagemaker/mlflow/__init__.py @@ -0,0 +1,12 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. diff --git a/src/sagemaker/mlflow/tracking_server.py b/src/sagemaker/mlflow/tracking_server.py new file mode 100644 index 0000000000..0baa0f457b --- /dev/null +++ b/src/sagemaker/mlflow/tracking_server.py @@ -0,0 +1,50 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + + +"""This module contains code related to the Mlflow Tracking Server.""" + +from __future__ import absolute_import +from typing import Optional, TYPE_CHECKING +from sagemaker.apiutils import _utils + +if TYPE_CHECKING: + from sagemaker import Session + + +def generate_mlflow_presigned_url( + name: str, + expires_in_seconds: Optional[int] = None, + session_expiration_duration_in_seconds: Optional[int] = None, + sagemaker_session: Optional["Session"] = None, +) -> str: + """Generate a presigned url to acess the Mlflow UI. + + Args: + name (str): Name of the Mlflow Tracking Server + expires_in_seconds (int): Expiration time of the first usage + of the presigned url in seconds. + session_expiration_duration_in_seconds (int): Session duration of the presigned url in + seconds after the first use. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + Returns: + (str): Authorized Url to acess the Mlflow UI. + """ + session = sagemaker_session or _utils.default_session() + api_response = session.create_presigned_mlflow_tracking_server_url( + name, expires_in_seconds, session_expiration_duration_in_seconds + ) + return api_response["AuthorizedUrl"] diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 44bc46b00b..02e4fe81dd 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -12,11 +12,14 @@ # language governing permissions and limitations under the License. """Holds the ModelBuilder class and the ModelServer enum.""" from __future__ import absolute_import + +import importlib.util import uuid from typing import Any, Type, List, Dict, Optional, Union from dataclasses import dataclass, field import logging import os +import re from pathlib import Path @@ -43,12 +46,15 @@ from sagemaker.predictor import Predictor from sagemaker.serve.model_format.mlflow.constants import ( MLFLOW_MODEL_PATH, + MLFLOW_TRACKING_ARN, + MLFLOW_RUN_ID_REGEX, + MLFLOW_REGISTRY_PATH_REGEX, + MODEL_PACKAGE_ARN_REGEX, MLFLOW_METADATA_FILE, MLFLOW_PIP_DEPENDENCY_FILE, ) from sagemaker.serve.model_format.mlflow.utils import ( _get_default_model_server_for_mlflow, - _mlflow_input_is_local_path, _download_s3_artifacts, _select_container_for_mlflow_model, _generate_mlflow_artifact_path, @@ -276,8 +282,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, default=None, metadata={ "help": "Define the model metadata to override, currently supports `HF_TASK`, " - "`MLFLOW_MODEL_PATH`. HF_TASK should be set for new models without task metadata in " - "the Hub, Adding unsupported task types will throw an exception" + "`MLFLOW_MODEL_PATH`, and `MLFLOW_TRACKING_ARN`. HF_TASK should be set for new " + "models without task metadata in the Hub, Adding unsupported task types will " + "throw an exception" }, ) @@ -502,6 +509,7 @@ def _model_builder_register_wrapper(self, *args, **kwargs): mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH], s3_upload_path=self.s3_upload_path, sagemaker_session=self.sagemaker_session, + tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN), ) return new_model_package @@ -572,6 +580,7 @@ def _model_builder_deploy_wrapper( mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH], s3_upload_path=self.s3_upload_path, sagemaker_session=self.sagemaker_session, + tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN), ) return predictor @@ -625,11 +634,30 @@ def wrapper(*args, **kwargs): return wrapper - def _check_if_input_is_mlflow_model(self) -> bool: - """Checks whether an MLmodel file exists in the given directory. + def _handle_mlflow_input(self): + """Check whether an MLflow model is present and handle accordingly""" + self._is_mlflow_model = self._has_mlflow_arguments() + if not self._is_mlflow_model: + return + + mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH) + artifact_path = self._get_artifact_path(mlflow_model_path) + if not self._mlflow_metadata_exists(artifact_path): + logger.info( + "MLflow model metadata not detected in %s. ModelBuilder is not " + "handling MLflow model input", + mlflow_model_path, + ) + return + + self._initialize_for_mlflow(artifact_path) + _validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR")) + + def _has_mlflow_arguments(self) -> bool: + """Check whether MLflow model arguments are present Returns: - bool: True if the MLmodel file exists, False otherwise. + bool: True if MLflow arguments are present, False otherwise. """ if self.inference_spec or self.model: logger.info( @@ -644,8 +672,8 @@ def _check_if_input_is_mlflow_model(self) -> bool: ) return False - path = self.model_metadata.get(MLFLOW_MODEL_PATH) - if not path: + mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH) + if not mlflow_model_path: logger.info( "%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model " "input", @@ -653,7 +681,73 @@ def _check_if_input_is_mlflow_model(self) -> bool: ) return False - # Check for S3 path + return True + + def _get_artifact_path(self, mlflow_model_path: str) -> str: + """Retrieves the model artifact location given the Mlflow model input. + + Args: + mlflow_model_path (str): The MLflow model path input. + + Returns: + str: The path to the model artifact. + """ + if (is_run_id_type := re.match(MLFLOW_RUN_ID_REGEX, mlflow_model_path)) or re.match( + MLFLOW_REGISTRY_PATH_REGEX, mlflow_model_path + ): + mlflow_tracking_arn = self.model_metadata.get(MLFLOW_TRACKING_ARN) + if not mlflow_tracking_arn: + raise ValueError( + "%s is not provided in ModelMetadata or through set_tracking_arn " + "but MLflow model path was provided." % MLFLOW_TRACKING_ARN, + ) + + if not importlib.util.find_spec("sagemaker_mlflow"): + raise ImportError( + "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed" + ) + + import mlflow + + mlflow.set_tracking_uri(mlflow_tracking_arn) + if is_run_id_type: + _, run_id, model_path = mlflow_model_path.split("/", 2) + artifact_uri = mlflow.get_run(run_id).info.artifact_uri + if not artifact_uri.endswith("/"): + artifact_uri += "/" + return artifact_uri + model_path + + mlflow_client = mlflow.MlflowClient() + if not mlflow_model_path.endswith("/"): + mlflow_model_path += "/" + + if "@" in mlflow_model_path: + _, model_name_and_alias, artifact_uri = mlflow_model_path.split("/", 2) + model_name, model_alias = model_name_and_alias.split("@") + model_metadata = mlflow_client.get_model_version_by_alias(model_name, model_alias) + else: + _, model_name, model_version, artifact_uri = mlflow_model_path.split("/", 3) + model_metadata = mlflow_client.get_model_version(model_name, model_version) + + source = model_metadata.source + if not source.endswith("/"): + source += "/" + return source + artifact_uri + + if re.match(MODEL_PACKAGE_ARN_REGEX, mlflow_model_path): + model_package = self.sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=mlflow_model_path + ) + return model_package["SourceUri"] + + return mlflow_model_path + + def _mlflow_metadata_exists(self, path: str) -> bool: + """Checks whether an MLmodel file exists in the given directory. + + Returns: + bool: True if the MLmodel file exists, False otherwise. + """ if path.startswith("s3://"): s3_downloader = S3Downloader() if not path.endswith("/"): @@ -665,17 +759,18 @@ def _check_if_input_is_mlflow_model(self) -> bool: file_path = os.path.join(path, MLFLOW_METADATA_FILE) return os.path.isfile(file_path) - def _initialize_for_mlflow(self) -> None: - """Initialize mlflow model artifacts, image uri and model server.""" - mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH) - if not _mlflow_input_is_local_path(mlflow_path): - # TODO: extend to package arn, run id and etc. - logger.info( - "Start downloading model artifacts from %s to %s", mlflow_path, self.model_path - ) - _download_s3_artifacts(mlflow_path, self.model_path, self.sagemaker_session) + def _initialize_for_mlflow(self, artifact_path: str) -> None: + """Initialize mlflow model artifacts, image uri and model server. + + Args: + artifact_path (str): The path to the artifact store. + """ + if artifact_path.startswith("s3://"): + _download_s3_artifacts(artifact_path, self.model_path, self.sagemaker_session) + elif os.path.exists(artifact_path): + _copy_directory_contents(artifact_path, self.model_path) else: - _copy_directory_contents(mlflow_path, self.model_path) + raise ValueError("Invalid path: %s" % artifact_path) mlflow_model_metadata_path = _generate_mlflow_artifact_path( self.model_path, MLFLOW_METADATA_FILE ) @@ -728,6 +823,8 @@ def build( # pylint: disable=R0911 self.role_arn = role_arn self.sagemaker_session = sagemaker_session or Session() + self.sagemaker_session.settings._local_download_dir = self.model_path + # https://github.com/boto/botocore/blob/develop/botocore/useragent.py#L258 # decorate to_string() due to # https://github.com/boto/botocore/blob/develop/botocore/client.py#L1014-L1015 @@ -739,14 +836,8 @@ def build( # pylint: disable=R0911 self.serve_settings = self._get_serve_setting() self._is_custom_image_uri = self.image_uri is not None - self._is_mlflow_model = self._check_if_input_is_mlflow_model() - if self._is_mlflow_model: - logger.warning( - "Support of MLflow format models is experimental and is not intended" - " for production at this moment." - ) - self._initialize_for_mlflow() - _validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR")) + + self._handle_mlflow_input() if isinstance(self.model, str): model_task = None @@ -836,6 +927,19 @@ def validate(self, model_dir: str) -> Type[bool]: return get_metadata(model_dir) + def set_tracking_arn(self, arn: str): + """Set tracking server ARN""" + # TODO: support native MLflow URIs + if importlib.util.find_spec("sagemaker_mlflow"): + import mlflow + + mlflow.set_tracking_uri(arn) + self.model_metadata[MLFLOW_TRACKING_ARN] = arn + else: + raise ImportError( + "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed" + ) + def _hf_schema_builder_init(self, model_task: str): """Initialize the schema builder for the given HF_TASK diff --git a/src/sagemaker/serve/model_format/mlflow/constants.py b/src/sagemaker/serve/model_format/mlflow/constants.py index 28a3cbdc8d..d7ddcd9ef0 100644 --- a/src/sagemaker/serve/model_format/mlflow/constants.py +++ b/src/sagemaker/serve/model_format/mlflow/constants.py @@ -22,9 +22,10 @@ MODEL_PACKAGE_ARN_REGEX = ( r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/(.*?)(?:/(\d+))?$" ) -MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$" -MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$" +MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9\-_\.]*)+$" +MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+[@/]?[a-zA-Z0-9\-_\.][/a-zA-Z0-9\-_\.]*$" S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+(?:\/[a-zA-Z0-9\-_\/\.]*)?$" +MLFLOW_TRACKING_ARN = "MLFLOW_TRACKING_ARN" MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH" MLFLOW_METADATA_FILE = "MLmodel" MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt" diff --git a/src/sagemaker/serve/model_format/mlflow/utils.py b/src/sagemaker/serve/model_format/mlflow/utils.py index c92a6a8a27..0d41cf4e33 100644 --- a/src/sagemaker/serve/model_format/mlflow/utils.py +++ b/src/sagemaker/serve/model_format/mlflow/utils.py @@ -227,28 +227,6 @@ def _get_python_version_from_parsed_mlflow_model_file( raise ValueError(f"{MLFLOW_PYFUNC} cannot be found in MLmodel file.") -def _mlflow_input_is_local_path(model_path: str) -> bool: - """Checks if the given model_path is a local filesystem path. - - Args: - - model_path (str): The model path to check. - - Returns: - - bool: True if model_path is a local path, False otherwise. - """ - if model_path.startswith("s3://"): - return False - - if "/runs/" in model_path or model_path.startswith("runs:"): - return False - - # Check if it's not a local file path - if not os.path.exists(model_path): - return False - - return True - - def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> None: """Downloads all artifacts from a specified S3 path to a local destination path. diff --git a/src/sagemaker/serve/utils/lineage_constants.py b/src/sagemaker/serve/utils/lineage_constants.py index 51be20739f..dce4a41139 100644 --- a/src/sagemaker/serve/utils/lineage_constants.py +++ b/src/sagemaker/serve/utils/lineage_constants.py @@ -16,6 +16,8 @@ LINEAGE_POLLER_INTERVAL_SECS = 15 LINEAGE_POLLER_MAX_TIMEOUT_SECS = 120 +TRACKING_SERVER_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):mlflow-tracking-server/(.*?)$" +TRACKING_SERVER_CREATION_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE = "ModelBuilderInputModelData" MLFLOW_S3_PATH = "S3" MLFLOW_MODEL_PACKAGE_PATH = "ModelPackage" diff --git a/src/sagemaker/serve/utils/lineage_utils.py b/src/sagemaker/serve/utils/lineage_utils.py index 3435e138c9..7278dd8a3c 100644 --- a/src/sagemaker/serve/utils/lineage_utils.py +++ b/src/sagemaker/serve/utils/lineage_utils.py @@ -17,7 +17,7 @@ import time import re import logging -from typing import Optional, Union +from typing import List, Optional, Union from botocore.exceptions import ClientError @@ -35,6 +35,8 @@ from sagemaker.serve.utils.lineage_constants import ( LINEAGE_POLLER_MAX_TIMEOUT_SECS, LINEAGE_POLLER_INTERVAL_SECS, + TRACKING_SERVER_ARN_REGEX, + TRACKING_SERVER_CREATION_TIME_FORMAT, MLFLOW_S3_PATH, MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, MLFLOW_LOCAL_PATH, @@ -51,24 +53,41 @@ def _load_artifact_by_source_uri( - source_uri: str, artifact_type: str, sagemaker_session: Session + source_uri: str, + sagemaker_session: Session, + source_types_to_match: Optional[List[str]] = None, + artifact_type: Optional[str] = None, ) -> Optional[ArtifactSummary]: """Load lineage artifact by source uri Arguments: source_uri (str): The s3 uri used for uploading transfomred model artifacts. - artifact_type (str): The type of the lineage artifact. sagemaker_session (Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the function creates one using the default AWS configuration chain. + source_types_to_match (Optional[List[str]]): A list of source type values to match against + the artifact's source types. If provided, the artifact's source types must match this + list. + artifact_type (Optional[str]): The type of the lineage artifact. Returns: ArtifactSummary: The Artifact Summary for the provided S3 URI. """ artifacts = Artifact.list(source_uri=source_uri, sagemaker_session=sagemaker_session) for artifact_summary in artifacts: - if artifact_summary.artifact_type == artifact_type: - return artifact_summary + if artifact_type is None or artifact_summary.artifact_type == artifact_type: + if source_types_to_match: + if artifact_summary.source.source_types is not None: + artifact_source_types = [ + source_type["Value"] for source_type in artifact_summary.source.source_types + ] + if set(artifact_source_types) == set(source_types_to_match): + return artifact_summary + else: + return None + else: + return artifact_summary + return None @@ -90,7 +109,9 @@ def _poll_lineage_artifact( logger.info("Polling lineage artifact for model data in %s", s3_uri) start_time = time.time() while time.time() - start_time < LINEAGE_POLLER_MAX_TIMEOUT_SECS: - result = _load_artifact_by_source_uri(s3_uri, artifact_type, sagemaker_session) + result = _load_artifact_by_source_uri( + s3_uri, sagemaker_session, artifact_type=artifact_type + ) if result is not None: return result time.sleep(LINEAGE_POLLER_INTERVAL_SECS) @@ -105,12 +126,12 @@ def _get_mlflow_model_path_type(mlflow_model_path: str) -> str: Returns: str: Description of what the input string is identified as. """ - mlflow_rub_id_pattern = MLFLOW_RUN_ID_REGEX + mlflow_run_id_pattern = MLFLOW_RUN_ID_REGEX mlflow_registry_id_pattern = MLFLOW_REGISTRY_PATH_REGEX sagemaker_arn_pattern = MODEL_PACKAGE_ARN_REGEX s3_pattern = S3_PATH_REGEX - if re.match(mlflow_rub_id_pattern, mlflow_model_path): + if re.match(mlflow_run_id_pattern, mlflow_model_path): return MLFLOW_RUN_ID if re.match(mlflow_registry_id_pattern, mlflow_model_path): return MLFLOW_REGISTRY_PATH @@ -127,12 +148,14 @@ def _get_mlflow_model_path_type(mlflow_model_path: str) -> str: def _create_mlflow_model_path_lineage_artifact( mlflow_model_path: str, sagemaker_session: Session, + source_types_to_match: Optional[List[str]] = None, ) -> Optional[Artifact]: """Creates a lineage artifact for the given MLflow model path. Args: mlflow_model_path (str): The path to the MLflow model. sagemaker_session (Session): The SageMaker session object. + source_types_to_match (Optional[List[str]]): Artifact source types. Returns: Optional[Artifact]: The created lineage artifact, or None if an error occurred. @@ -142,8 +165,17 @@ def _create_mlflow_model_path_lineage_artifact( model_builder_input_model_data_type=_artifact_name, ) try: + source_types = [dict(SourceIdType="Custom", Value="ModelBuilderInputModelData")] + if source_types_to_match: + source_types += [ + dict(SourceIdType="Custom", Value=source_type) + for source_type in source_types_to_match + if source_type != "ModelBuilderInputModelData" + ] + return Artifact.create( source_uri=mlflow_model_path, + source_types=source_types, artifact_type=MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, artifact_name=_artifact_name, properties=properties, @@ -160,6 +192,7 @@ def _create_mlflow_model_path_lineage_artifact( def _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( mlflow_model_path: str, sagemaker_session: Session, + tracking_server_arn: Optional[str] = None, ) -> Optional[Union[Artifact, ArtifactSummary]]: """Retrieves an existing artifact for the given MLflow model path or @@ -170,20 +203,35 @@ def _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( sagemaker_session (Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the function creates one using the default AWS configuration chain. - + tracking_server_arn (Optional[str]): The MLflow tracking server ARN. Returns: Optional[Union[Artifact, ArtifactSummary]]: The existing or newly created artifact, or None if an error occurred. """ + source_types_to_match = ["ModelBuilderInputModelData"] + input_type = _get_mlflow_model_path_type(mlflow_model_path) + if tracking_server_arn and input_type in [MLFLOW_RUN_ID, MLFLOW_REGISTRY_PATH]: + match = re.match(TRACKING_SERVER_ARN_REGEX, tracking_server_arn) + mlflow_tracking_server_name = match.group(4) + describe_result = sagemaker_session.sagemaker_client.describe_mlflow_tracking_server( + TrackingServerName=mlflow_tracking_server_name + ) + tracking_server_creation_time = describe_result["CreationTime"].strftime( + TRACKING_SERVER_CREATION_TIME_FORMAT + ) + source_types_to_match += [tracking_server_arn, tracking_server_creation_time] _loaded_artifact = _load_artifact_by_source_uri( - mlflow_model_path, MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, sagemaker_session + mlflow_model_path, + sagemaker_session, + source_types_to_match, ) if _loaded_artifact is not None: return _loaded_artifact return _create_mlflow_model_path_lineage_artifact( mlflow_model_path, sagemaker_session, + source_types_to_match, ) @@ -229,6 +277,7 @@ def _maintain_lineage_tracking_for_mlflow_model( mlflow_model_path: str, s3_upload_path: str, sagemaker_session: Session, + tracking_server_arn: Optional[str] = None, ) -> None: """Maintains lineage tracking for an MLflow model by creating or retrieving artifacts. @@ -238,6 +287,7 @@ def _maintain_lineage_tracking_for_mlflow_model( sagemaker_session (Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the function creates one using the default AWS configuration chain. + tracking_server_arn (Optional[str]): The MLflow tracking server ARN. """ artifact_for_transformed_model_data = _poll_lineage_artifact( s3_uri=s3_upload_path, @@ -249,6 +299,7 @@ def _maintain_lineage_tracking_for_mlflow_model( _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( mlflow_model_path=mlflow_model_path, sagemaker_session=sagemaker_session, + tracking_server_arn=tracking_server_arn, ) ) if mlflow_model_artifact: diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 7b1aaef447..205b14a3e6 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -6744,6 +6744,36 @@ def wait_for_inference_recommendations_job( _check_job_status(job_name, desc, "Status") return desc + def create_presigned_mlflow_tracking_server_url( + self, + tracking_server_name: str, + expires_in_seconds: int = None, + session_expiration_duration_in_seconds: int = None, + ) -> Dict[str, Any]: + """Creates a Presigned Url to acess the Mlflow UI. + + Args: + tracking_server_name (str): Name of the Mlflow Tracking Server. + expires_in_seconds (int): Expiration duration of the URL. + session_expiration_duration_in_seconds (int): Session duration of the URL. + Returns: + (dict): Return value from the ``CreatePresignedMlflowTrackingServerUrl`` API. + + """ + + create_presigned_url_args = {"TrackingServerName": tracking_server_name} + if expires_in_seconds is not None: + create_presigned_url_args["ExpiresInSeconds"] = expires_in_seconds + + if session_expiration_duration_in_seconds is not None: + create_presigned_url_args["SessionExpirationDurationInSeconds"] = ( + session_expiration_duration_in_seconds + ) + + return self.sagemaker_client.create_presigned_mlflow_tracking_server_url( + **create_presigned_url_args + ) + def get_model_package_args( content_types=None, diff --git a/tests/unit/sagemaker/mlflow/__init__.py b/tests/unit/sagemaker/mlflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/mlflow/test_tracking_server.py b/tests/unit/sagemaker/mlflow/test_tracking_server.py new file mode 100644 index 0000000000..1fc4943f16 --- /dev/null +++ b/tests/unit/sagemaker/mlflow/test_tracking_server.py @@ -0,0 +1,42 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from __future__ import absolute_import +from sagemaker.mlflow.tracking_server import generate_mlflow_presigned_url + + +def test_generate_presigned_url(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_presigned_mlflow_tracking_server_url.return_value = { + "AuthorizedUrl": "https://t-wo.example.com", + } + url = generate_mlflow_presigned_url( + "w", + expires_in_seconds=10, + session_expiration_duration_in_seconds=5, + sagemaker_session=sagemaker_session, + ) + client.create_presigned_mlflow_tracking_server_url.assert_called_with( + TrackingServerName="w", ExpiresInSeconds=10, SessionExpirationDurationInSeconds=5 + ) + assert url == "https://t-wo.example.com" + + +def test_generate_presigned_url_minimal(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_presigned_mlflow_tracking_server_url.return_value = { + "AuthorizedUrl": "https://t-wo.example.com", + } + url = generate_mlflow_presigned_url("w", sagemaker_session=sagemaker_session) + client.create_presigned_mlflow_tracking_server_url.assert_called_with(TrackingServerName="w") + assert url == "https://t-wo.example.com" diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 0c06b5ae8e..cecaf6450e 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -19,6 +19,7 @@ from sagemaker.serve.builder.model_builder import ModelBuilder from sagemaker.serve.mode.function_pointers import Mode +from sagemaker.serve.model_format.mlflow.constants import MLFLOW_TRACKING_ARN from sagemaker.serve.utils import task from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.predictors import TensorflowServingLocalPredictor @@ -2257,3 +2258,187 @@ def test_build_tensorflow_serving_non_mlflow_case( mock_role_arn, mock_session, ) + + def test_handle_mlflow_input_without_mlflow_model_path(self): + builder = ModelBuilder(model_metadata={}) + assert not builder._has_mlflow_arguments() + + @patch("importlib.util.find_spec") + @patch("mlflow.set_tracking_uri") + @patch("mlflow.get_run") + @patch.object(ModelBuilder, "_mlflow_metadata_exists", autospec=True) + @patch.object(ModelBuilder, "_initialize_for_mlflow", autospec=True) + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow") + def test_handle_mlflow_input_run_id( + self, + mock_validate, + mock_s3_downloader, + mock_initialize, + mock_check_mlflow_model, + mock_get_run, + mock_set_tracking_uri, + mock_find_spec, + ): + mock_find_spec.return_value = True + mock_run_info = Mock() + mock_run_info.info.artifact_uri = "s3://bucket/path" + mock_get_run.return_value = mock_run_info + mock_check_mlflow_model.return_value = True + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "runs:/runid/mlflow-path", + "MLFLOW_TRACKING_ARN": "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test", + } + ) + builder._handle_mlflow_input() + mock_initialize.assert_called_once_with(builder, "s3://bucket/path/mlflow-path") + + @patch("importlib.util.find_spec") + @patch("mlflow.set_tracking_uri") + @patch("mlflow.MlflowClient.get_model_version") + @patch.object(ModelBuilder, "_mlflow_metadata_exists", autospec=True) + @patch.object(ModelBuilder, "_initialize_for_mlflow", autospec=True) + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow") + def test_handle_mlflow_input_registry_path_with_model_version( + self, + mock_validate, + mock_s3_downloader, + mock_initialize, + mock_check_mlflow_model, + mock_get_model_version, + mock_set_tracking_uri, + mock_find_spec, + ): + mock_find_spec.return_value = True + mock_registry_path = Mock() + mock_registry_path.source = "s3://bucket/path/" + mock_get_model_version.return_value = mock_registry_path + mock_check_mlflow_model.return_value = True + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "models:/model-name/1", + "MLFLOW_TRACKING_ARN": "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test", + } + ) + builder._handle_mlflow_input() + mock_initialize.assert_called_once_with(builder, "s3://bucket/path/") + + @patch("importlib.util.find_spec") + @patch("mlflow.set_tracking_uri") + @patch("mlflow.MlflowClient.get_model_version_by_alias") + @patch.object(ModelBuilder, "_mlflow_metadata_exists", autospec=True) + @patch.object(ModelBuilder, "_initialize_for_mlflow", autospec=True) + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow") + def test_handle_mlflow_input_registry_path_with_model_alias( + self, + mock_validate, + mock_s3_downloader, + mock_initialize, + mock_check_mlflow_model, + mock_get_model_version_by_alias, + mock_set_tracking_uri, + mock_find_spec, + ): + mock_find_spec.return_value = True + mock_registry_path = Mock() + mock_registry_path.source = "s3://bucket/path" + mock_get_model_version_by_alias.return_value = mock_registry_path + mock_check_mlflow_model.return_value = True + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "models:/model-name@production", + "MLFLOW_TRACKING_ARN": "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test", + } + ) + builder._handle_mlflow_input() + mock_initialize.assert_called_once_with(builder, "s3://bucket/path/") + + @patch("mlflow.MlflowClient.get_model_version") + @patch.object(ModelBuilder, "_mlflow_metadata_exists", autospec=True) + @patch.object(ModelBuilder, "_initialize_for_mlflow", autospec=True) + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow") + def test_handle_mlflow_input_registry_path_missing_tracking_server_arn( + self, + mock_validate, + mock_s3_downloader, + mock_initialize, + mock_check_mlflow_model, + mock_get_model_version, + ): + mock_registry_path = Mock() + mock_registry_path.source = "s3://bucket/path" + mock_get_model_version.return_value = mock_registry_path + mock_check_mlflow_model.return_value = True + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "models:/model-name/1", + } + ) + self.assertRaisesRegex( + Exception, + "%s is not provided in ModelMetadata or through set_tracking_arn " + "but MLflow model path was provided." % MLFLOW_TRACKING_ARN, + builder._handle_mlflow_input, + ) + + @patch.object(ModelBuilder, "_mlflow_metadata_exists", autospec=True) + @patch.object(ModelBuilder, "_initialize_for_mlflow", autospec=True) + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow") + def test_handle_mlflow_input_model_package_arn( + self, mock_validate, mock_s3_downloader, mock_initialize, mock_check_mlflow_model + ): + mock_check_mlflow_model.return_value = True + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + mock_model_package = {"SourceUri": "s3://bucket/path"} + mock_session.sagemaker_client.describe_model_package.return_value = mock_model_package + + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "arn:aws:sagemaker:us-west-2:000000000000:model-package/test", + "MLFLOW_TRACKING_ARN": "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test", + }, + sagemaker_session=mock_session, + ) + builder._handle_mlflow_input() + mock_initialize.assert_called_once_with(builder, "s3://bucket/path") + + @patch("importlib.util.find_spec", Mock(return_value=True)) + @patch("mlflow.set_tracking_uri") + def test_set_tracking_arn_success(self, mock_set_tracking_uri): + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "arn:aws:sagemaker:us-west-2:000000000000:model-package/test", + } + ) + tracking_arn = "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + builder.set_tracking_arn(tracking_arn) + mock_set_tracking_uri.assert_called_once_with(tracking_arn) + assert builder.model_metadata[MLFLOW_TRACKING_ARN] == tracking_arn + + @patch("importlib.util.find_spec", Mock(return_value=False)) + def test_set_tracking_arn_mlflow_not_installed(self): + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "arn:aws:sagemaker:us-west-2:000000000000:model-package/test", + } + ) + tracking_arn = "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + self.assertRaisesRegex( + ImportError, + "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed", + builder.set_tracking_arn, + tracking_arn, + ) diff --git a/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py b/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py index 23d1315647..819800ba46 100644 --- a/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py +++ b/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py @@ -32,7 +32,6 @@ _get_framework_version_from_requirements, _get_deployment_flavor, _get_python_version_from_parsed_mlflow_model_file, - _mlflow_input_is_local_path, _download_s3_artifacts, _select_container_for_mlflow_model, _validate_input_for_mlflow, @@ -197,17 +196,6 @@ def test_get_python_version_from_parsed_mlflow_model_file(): _get_python_version_from_parsed_mlflow_model_file({}) -@patch("os.path.exists") -def test_mlflow_input_is_local_path(mock_path_exists): - valid_path = "/path/to/mlflow_model" - mock_path_exists.side_effect = lambda path: path == valid_path - - assert not _mlflow_input_is_local_path("s3://my_bucket/path/to/model") - assert not _mlflow_input_is_local_path("runs:/run-id/run/relative/path/to/model") - assert not _mlflow_input_is_local_path("/invalid/path") - assert _mlflow_input_is_local_path(valid_path) - - def test_download_s3_artifacts(): pass diff --git a/tests/unit/sagemaker/serve/utils/test_lineage_utils.py b/tests/unit/sagemaker/serve/utils/test_lineage_utils.py index 25e4fe246e..99da766031 100644 --- a/tests/unit/sagemaker/serve/utils/test_lineage_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_lineage_utils.py @@ -14,6 +14,7 @@ from unittest.mock import call +import datetime import pytest from botocore.exceptions import ClientError from mock import Mock, patch @@ -22,6 +23,7 @@ from sagemaker.lineage.query import LineageSourceEnum from sagemaker.serve.utils.lineage_constants import ( + TRACKING_SERVER_CREATION_TIME_FORMAT, MLFLOW_RUN_ID, MLFLOW_MODEL_PACKAGE_PATH, MLFLOW_S3_PATH, @@ -55,7 +57,7 @@ def test_load_artifact_by_source_uri(mock_artifact_list): mock_artifact_list.return_value = mock_artifacts result = _load_artifact_by_source_uri( - source_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session + source_uri, sagemaker_session, artifact_type=LineageSourceEnum.MODEL_DATA.value ) mock_artifact_list.assert_called_once_with( @@ -77,7 +79,7 @@ def test_load_artifact_by_source_uri_no_match(mock_artifact_list): mock_artifact_list.return_value = mock_artifacts result = _load_artifact_by_source_uri( - source_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session + source_uri, sagemaker_session, artifact_type=LineageSourceEnum.MODEL_DATA.value ) mock_artifact_list.assert_called_once_with( @@ -104,7 +106,7 @@ def test_poll_lineage_artifact_found(mock_load_artifact): assert result == mock_artifact mock_load_artifact.assert_has_calls( [ - call(s3_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session), + call(s3_uri, sagemaker_session, artifact_type=LineageSourceEnum.MODEL_DATA.value), ] ) @@ -130,7 +132,7 @@ def test_poll_lineage_artifact_not_found(mock_load_artifact): @pytest.mark.parametrize( "mlflow_model_path, expected_output", [ - ("runs:/abc123", MLFLOW_RUN_ID), + ("runs:/abc123/my-model", MLFLOW_RUN_ID), ("models:/my-model/1", MLFLOW_REGISTRY_PATH), ( "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model-package", @@ -163,7 +165,8 @@ def test_get_mlflow_model_path_type_invalid(): def test_create_mlflow_model_path_lineage_artifact_success( mock_artifact_create, mock_get_mlflow_path_type ): - mlflow_model_path = "runs:/Ab12Cd34" + mlflow_model_path = "runs:/Ab12Cd34/my-model" + mock_source_types = [dict(SourceIdType="Custom", Value="ModelBuilderInputModelData")] sagemaker_session = Mock(spec=Session) mock_artifact = Mock(spec=Artifact) mock_get_mlflow_path_type.return_value = "mlflow_run_id" @@ -175,6 +178,7 @@ def test_create_mlflow_model_path_lineage_artifact_success( mock_get_mlflow_path_type.assert_called_once_with(mlflow_model_path) mock_artifact_create.assert_called_once_with( source_uri=mlflow_model_path, + source_types=mock_source_types, artifact_type=MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, artifact_name="mlflow_run_id", properties={"model_builder_input_model_data_type": "mlflow_run_id"}, @@ -187,7 +191,7 @@ def test_create_mlflow_model_path_lineage_artifact_success( def test_create_mlflow_model_path_lineage_artifact_validation_exception( mock_artifact_create, mock_get_mlflow_path_type ): - mlflow_model_path = "runs:/Ab12Cd34" + mlflow_model_path = "runs:/Ab12Cd34/my-model" sagemaker_session = Mock(spec=Session) mock_get_mlflow_path_type.return_value = "mlflow_run_id" mock_artifact_create.side_effect = ClientError( @@ -204,7 +208,7 @@ def test_create_mlflow_model_path_lineage_artifact_validation_exception( def test_create_mlflow_model_path_lineage_artifact_other_exception( mock_artifact_create, mock_get_mlflow_path_type ): - mlflow_model_path = "runs:/Ab12Cd34" + mlflow_model_path = "runs:/Ab12Cd34/my-model" sagemaker_session = Mock(spec=Session) mock_get_mlflow_path_type.return_value = "mlflow_run_id" mock_artifact_create.side_effect = ClientError( @@ -220,18 +224,33 @@ def test_create_mlflow_model_path_lineage_artifact_other_exception( def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_existing( mock_load_artifact, mock_create_artifact ): - mlflow_model_path = "runs:/Ab12Cd34" + mlflow_model_path = "runs:/Ab12Cd34/my-model" + mock_tracking_server_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + ) + mock_creation_time = datetime.datetime(2024, 5, 15, 0, 0, 0) sagemaker_session = Mock(spec=Session) + mock_sagemaker_client = Mock() + mock_describe_response = {"CreationTime": mock_creation_time} + mock_sagemaker_client.describe_mlflow_tracking_server.return_value = mock_describe_response + sagemaker_session.sagemaker_client = mock_sagemaker_client + mock_source_types_to_match = [ + "ModelBuilderInputModelData", + mock_tracking_server_arn, + mock_creation_time.strftime(TRACKING_SERVER_CREATION_TIME_FORMAT), + ] mock_artifact_summary = Mock(spec=ArtifactSummary) mock_load_artifact.return_value = mock_artifact_summary result = _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( - mlflow_model_path, sagemaker_session + mlflow_model_path, sagemaker_session, mock_tracking_server_arn ) assert result == mock_artifact_summary mock_load_artifact.assert_called_once_with( - mlflow_model_path, MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, sagemaker_session + mlflow_model_path, + sagemaker_session, + mock_source_types_to_match, ) mock_create_artifact.assert_not_called() @@ -241,21 +260,38 @@ def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_exi def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_create( mock_load_artifact, mock_create_artifact ): - mlflow_model_path = "runs:/Ab12Cd34" + mlflow_model_path = "runs:/Ab12Cd34/my-model" + mock_tracking_server_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + ) + mock_creation_time = datetime.datetime(2024, 5, 15, 0, 0, 0) sagemaker_session = Mock(spec=Session) + mock_sagemaker_client = Mock() + mock_describe_response = {"CreationTime": mock_creation_time} + mock_sagemaker_client.describe_mlflow_tracking_server.return_value = mock_describe_response + sagemaker_session.sagemaker_client = mock_sagemaker_client + mock_source_types_to_match = [ + "ModelBuilderInputModelData", + mock_tracking_server_arn, + mock_creation_time.strftime(TRACKING_SERVER_CREATION_TIME_FORMAT), + ] mock_artifact = Mock(spec=Artifact) mock_load_artifact.return_value = None mock_create_artifact.return_value = mock_artifact result = _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( - mlflow_model_path, sagemaker_session + mlflow_model_path, sagemaker_session, mock_tracking_server_arn ) assert result == mock_artifact mock_load_artifact.assert_called_once_with( - mlflow_model_path, MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, sagemaker_session + mlflow_model_path, + sagemaker_session, + mock_source_types_to_match, + ) + mock_create_artifact.assert_called_once_with( + mlflow_model_path, sagemaker_session, mock_source_types_to_match ) - mock_create_artifact.assert_called_once_with(mlflow_model_path, sagemaker_session) @patch("sagemaker.lineage.association.Association.create") @@ -320,7 +356,10 @@ def test_add_association_between_artifacts_other_exception(mock_association_crea def test_maintain_lineage_tracking_for_mlflow_model_success( mock_add_association, mock_retrieve_create_artifact, mock_poll_artifact ): - mlflow_model_path = "runs:/Ab12Cd34" + mlflow_model_path = "runs:/Ab12Cd34/my-model" + mock_tracking_server_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + ) s3_upload_path = "s3://mybucket/path/to/model" sagemaker_session = Mock(spec=Session) mock_model_data_artifact = Mock(spec=ArtifactSummary) @@ -329,7 +368,7 @@ def test_maintain_lineage_tracking_for_mlflow_model_success( mock_retrieve_create_artifact.return_value = mock_mlflow_model_artifact _maintain_lineage_tracking_for_mlflow_model( - mlflow_model_path, s3_upload_path, sagemaker_session + mlflow_model_path, s3_upload_path, sagemaker_session, mock_tracking_server_arn ) mock_poll_artifact.assert_called_once_with( @@ -338,7 +377,9 @@ def test_maintain_lineage_tracking_for_mlflow_model_success( sagemaker_session=sagemaker_session, ) mock_retrieve_create_artifact.assert_called_once_with( - mlflow_model_path=mlflow_model_path, sagemaker_session=sagemaker_session + mlflow_model_path=mlflow_model_path, + tracking_server_arn=mock_tracking_server_arn, + sagemaker_session=sagemaker_session, ) mock_add_association.assert_called_once_with( mlflow_model_path_artifact_arn=mock_mlflow_model_artifact.artifact_arn, @@ -355,14 +396,17 @@ def test_maintain_lineage_tracking_for_mlflow_model_success( def test_maintain_lineage_tracking_for_mlflow_model_no_model_data_artifact( mock_add_association, mock_retrieve_create_artifact, mock_poll_artifact ): - mlflow_model_path = "runs:/Ab12Cd34" + mlflow_model_path = "runs:/Ab12Cd34/my-model" + mock_tracking_server_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + ) s3_upload_path = "s3://mybucket/path/to/model" sagemaker_session = Mock(spec=Session) mock_poll_artifact.return_value = None mock_retrieve_create_artifact.return_value = None _maintain_lineage_tracking_for_mlflow_model( - mlflow_model_path, s3_upload_path, sagemaker_session + mlflow_model_path, s3_upload_path, sagemaker_session, mock_tracking_server_arn ) mock_poll_artifact.assert_called_once_with( diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8ab186e27c..e95359d52c 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -6282,6 +6282,24 @@ def test_create_inference_recommendations_job_propogate_other_exception( assert "AccessDeniedException" in str(error) +def test_create_presigned_mlflow_tracking_server_url(sagemaker_session): + sagemaker_session.create_presigned_mlflow_tracking_server_url("ts", 1, 2) + assert ( + sagemaker_session.sagemaker_client.create_presigned_mlflow_tracking_server_url.called_with( + TrackingServerName="ts", ExpiresInSeconds=1, SessionExpirationDurationInSeconds=2 + ) + ) + + +def test_create_presigned_mlflow_tracking_server_url_minimal(sagemaker_session): + sagemaker_session.create_presigned_mlflow_tracking_server_url("ts") + assert ( + sagemaker_session.sagemaker_client.create_presigned_mlflow_tracking_server_url.called_with( + TrackingServerName="ts" + ) + ) + + DEFAULT_LOG_EVENTS_INFERENCE_RECOMMENDER = [ MockBotoException("ResourceNotFoundException"), {"nextForwardToken": None, "events": [{"timestamp": 1, "message": "hi there #1"}]},