Skip to content

feat(sagemaker-mlflow): New features for SageMaker MLflow #4744

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/extras/test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ nbformat>=5.9,<6
accelerate>=0.24.1,<=0.27.0
schema==0.7.5
tensorflow>=2.1,<=2.16
mlflow>=2.12.2,<2.13
12 changes: 12 additions & 0 deletions src/sagemaker/mlflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
50 changes: 50 additions & 0 deletions src/sagemaker/mlflow/tracking_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.


"""This module contains code related to the Mlflow Tracking Server."""

from __future__ import absolute_import
from typing import Optional, TYPE_CHECKING
from sagemaker.apiutils import _utils

if TYPE_CHECKING:
from sagemaker import Session


def generate_mlflow_presigned_url(
name: str,
expires_in_seconds: Optional[int] = None,
session_expiration_duration_in_seconds: Optional[int] = None,
sagemaker_session: Optional["Session"] = None,
) -> str:
"""Generate a presigned url to acess the Mlflow UI.

Args:
name (str): Name of the Mlflow Tracking Server
expires_in_seconds (int): Expiration time of the first usage
of the presigned url in seconds.
session_expiration_duration_in_seconds (int): Session duration of the presigned url in
seconds after the first use.
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, one is created using the
default AWS configuration chain.
Returns:
(str): Authorized Url to acess the Mlflow UI.
"""
session = sagemaker_session or _utils.default_session()
api_response = session.create_presigned_mlflow_tracking_server_url(
name, expires_in_seconds, session_expiration_duration_in_seconds
)
return api_response["AuthorizedUrl"]
158 changes: 131 additions & 27 deletions src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
# language governing permissions and limitations under the License.
"""Holds the ModelBuilder class and the ModelServer enum."""
from __future__ import absolute_import

import importlib.util
import uuid
from typing import Any, Type, List, Dict, Optional, Union
from dataclasses import dataclass, field
import logging
import os
import re

from pathlib import Path

Expand All @@ -43,12 +46,15 @@
from sagemaker.predictor import Predictor
from sagemaker.serve.model_format.mlflow.constants import (
MLFLOW_MODEL_PATH,
MLFLOW_TRACKING_ARN,
MLFLOW_RUN_ID_REGEX,
MLFLOW_REGISTRY_PATH_REGEX,
MODEL_PACKAGE_ARN_REGEX,
MLFLOW_METADATA_FILE,
MLFLOW_PIP_DEPENDENCY_FILE,
)
from sagemaker.serve.model_format.mlflow.utils import (
_get_default_model_server_for_mlflow,
_mlflow_input_is_local_path,
_download_s3_artifacts,
_select_container_for_mlflow_model,
_generate_mlflow_artifact_path,
Expand Down Expand Up @@ -276,8 +282,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
default=None,
metadata={
"help": "Define the model metadata to override, currently supports `HF_TASK`, "
"`MLFLOW_MODEL_PATH`. HF_TASK should be set for new models without task metadata in "
"the Hub, Adding unsupported task types will throw an exception"
"`MLFLOW_MODEL_PATH`, and `MLFLOW_TRACKING_ARN`. HF_TASK should be set for new "
"models without task metadata in the Hub, Adding unsupported task types will "
"throw an exception"
},
)

Expand Down Expand Up @@ -502,6 +509,7 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
s3_upload_path=self.s3_upload_path,
sagemaker_session=self.sagemaker_session,
tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN),
)
return new_model_package

Expand Down Expand Up @@ -572,6 +580,7 @@ def _model_builder_deploy_wrapper(
mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
s3_upload_path=self.s3_upload_path,
sagemaker_session=self.sagemaker_session,
tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN),
)
return predictor

Expand Down Expand Up @@ -625,11 +634,30 @@ def wrapper(*args, **kwargs):

return wrapper

def _check_if_input_is_mlflow_model(self) -> bool:
"""Checks whether an MLmodel file exists in the given directory.
def _handle_mlflow_input(self):
"""Check whether an MLflow model is present and handle accordingly"""
self._is_mlflow_model = self._has_mlflow_arguments()
if not self._is_mlflow_model:
return

mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
artifact_path = self._get_artifact_path(mlflow_model_path)
if not self._mlflow_metadata_exists(artifact_path):
logger.info(
"MLflow model metadata not detected in %s. ModelBuilder is not "
"handling MLflow model input",
mlflow_model_path,
)
return

self._initialize_for_mlflow(artifact_path)
_validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))

def _has_mlflow_arguments(self) -> bool:
"""Check whether MLflow model arguments are present

Returns:
bool: True if the MLmodel file exists, False otherwise.
bool: True if MLflow arguments are present, False otherwise.
"""
if self.inference_spec or self.model:
logger.info(
Expand All @@ -644,16 +672,82 @@ def _check_if_input_is_mlflow_model(self) -> bool:
)
return False

path = self.model_metadata.get(MLFLOW_MODEL_PATH)
if not path:
mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
if not mlflow_model_path:
logger.info(
"%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model "
"input",
MLFLOW_MODEL_PATH,
)
return False

# Check for S3 path
return True

def _get_artifact_path(self, mlflow_model_path: str) -> str:
"""Retrieves the model artifact location given the Mlflow model input.

Args:
mlflow_model_path (str): The MLflow model path input.

Returns:
str: The path to the model artifact.
"""
if (is_run_id_type := re.match(MLFLOW_RUN_ID_REGEX, mlflow_model_path)) or re.match(
MLFLOW_REGISTRY_PATH_REGEX, mlflow_model_path
):
mlflow_tracking_arn = self.model_metadata.get(MLFLOW_TRACKING_ARN)
if not mlflow_tracking_arn:
raise ValueError(
"%s is not provided in ModelMetadata or through set_tracking_arn "
"but MLflow model path was provided." % MLFLOW_TRACKING_ARN,
)

if not importlib.util.find_spec("sagemaker_mlflow"):
raise ImportError(
"Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed"
)

import mlflow

mlflow.set_tracking_uri(mlflow_tracking_arn)
if is_run_id_type:
_, run_id, model_path = mlflow_model_path.split("/", 2)
artifact_uri = mlflow.get_run(run_id).info.artifact_uri
if not artifact_uri.endswith("/"):
artifact_uri += "/"
return artifact_uri + model_path

mlflow_client = mlflow.MlflowClient()
if not mlflow_model_path.endswith("/"):
mlflow_model_path += "/"

if "@" in mlflow_model_path:
_, model_name_and_alias, artifact_uri = mlflow_model_path.split("/", 2)
model_name, model_alias = model_name_and_alias.split("@")
model_metadata = mlflow_client.get_model_version_by_alias(model_name, model_alias)
else:
_, model_name, model_version, artifact_uri = mlflow_model_path.split("/", 3)
model_metadata = mlflow_client.get_model_version(model_name, model_version)

source = model_metadata.source
if not source.endswith("/"):
source += "/"
return source + artifact_uri

if re.match(MODEL_PACKAGE_ARN_REGEX, mlflow_model_path):
model_package = self.sagemaker_session.sagemaker_client.describe_model_package(
ModelPackageName=mlflow_model_path
)
return model_package["SourceUri"]

return mlflow_model_path

def _mlflow_metadata_exists(self, path: str) -> bool:
"""Checks whether an MLmodel file exists in the given directory.

Returns:
bool: True if the MLmodel file exists, False otherwise.
"""
if path.startswith("s3://"):
s3_downloader = S3Downloader()
if not path.endswith("/"):
Expand All @@ -665,17 +759,18 @@ def _check_if_input_is_mlflow_model(self) -> bool:
file_path = os.path.join(path, MLFLOW_METADATA_FILE)
return os.path.isfile(file_path)

def _initialize_for_mlflow(self) -> None:
"""Initialize mlflow model artifacts, image uri and model server."""
mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
if not _mlflow_input_is_local_path(mlflow_path):
# TODO: extend to package arn, run id and etc.
logger.info(
"Start downloading model artifacts from %s to %s", mlflow_path, self.model_path
)
_download_s3_artifacts(mlflow_path, self.model_path, self.sagemaker_session)
def _initialize_for_mlflow(self, artifact_path: str) -> None:
"""Initialize mlflow model artifacts, image uri and model server.

Args:
artifact_path (str): The path to the artifact store.
"""
if artifact_path.startswith("s3://"):
_download_s3_artifacts(artifact_path, self.model_path, self.sagemaker_session)
elif os.path.exists(artifact_path):
_copy_directory_contents(artifact_path, self.model_path)
else:
_copy_directory_contents(mlflow_path, self.model_path)
raise ValueError("Invalid path: %s" % artifact_path)
mlflow_model_metadata_path = _generate_mlflow_artifact_path(
self.model_path, MLFLOW_METADATA_FILE
)
Expand Down Expand Up @@ -728,6 +823,8 @@ def build( # pylint: disable=R0911
self.role_arn = role_arn
self.sagemaker_session = sagemaker_session or Session()

self.sagemaker_session.settings._local_download_dir = self.model_path

# https://github.com/boto/botocore/blob/develop/botocore/useragent.py#L258
# decorate to_string() due to
# https://github.com/boto/botocore/blob/develop/botocore/client.py#L1014-L1015
Expand All @@ -739,14 +836,8 @@ def build( # pylint: disable=R0911
self.serve_settings = self._get_serve_setting()

self._is_custom_image_uri = self.image_uri is not None
self._is_mlflow_model = self._check_if_input_is_mlflow_model()
if self._is_mlflow_model:
logger.warning(
"Support of MLflow format models is experimental and is not intended"
" for production at this moment."
)
self._initialize_for_mlflow()
_validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))

self._handle_mlflow_input()

if isinstance(self.model, str):
model_task = None
Expand Down Expand Up @@ -836,6 +927,19 @@ def validate(self, model_dir: str) -> Type[bool]:

return get_metadata(model_dir)

def set_tracking_arn(self, arn: str):
"""Set tracking server ARN"""
# TODO: support native MLflow URIs
if importlib.util.find_spec("sagemaker_mlflow"):
import mlflow

mlflow.set_tracking_uri(arn)
self.model_metadata[MLFLOW_TRACKING_ARN] = arn
else:
raise ImportError(
"Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed"
)

def _hf_schema_builder_init(self, model_task: str):
"""Initialize the schema builder for the given HF_TASK

Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/serve/model_format/mlflow/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
MODEL_PACKAGE_ARN_REGEX = (
r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/(.*?)(?:/(\d+))?$"
)
MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$"
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$"
MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9\-_\.]*)+$"
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+[@/]?[a-zA-Z0-9\-_\.][/a-zA-Z0-9\-_\.]*$"
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+(?:\/[a-zA-Z0-9\-_\/\.]*)?$"
MLFLOW_TRACKING_ARN = "MLFLOW_TRACKING_ARN"
MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH"
MLFLOW_METADATA_FILE = "MLmodel"
MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt"
Expand Down
22 changes: 0 additions & 22 deletions src/sagemaker/serve/model_format/mlflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,28 +227,6 @@ def _get_python_version_from_parsed_mlflow_model_file(
raise ValueError(f"{MLFLOW_PYFUNC} cannot be found in MLmodel file.")


def _mlflow_input_is_local_path(model_path: str) -> bool:
"""Checks if the given model_path is a local filesystem path.

Args:
- model_path (str): The model path to check.

Returns:
- bool: True if model_path is a local path, False otherwise.
"""
if model_path.startswith("s3://"):
return False

if "/runs/" in model_path or model_path.startswith("runs:"):
return False

# Check if it's not a local file path
if not os.path.exists(model_path):
return False

return True


def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> None:
"""Downloads all artifacts from a specified S3 path to a local destination path.

Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/serve/utils/lineage_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

LINEAGE_POLLER_INTERVAL_SECS = 15
LINEAGE_POLLER_MAX_TIMEOUT_SECS = 120
TRACKING_SERVER_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):mlflow-tracking-server/(.*?)$"
TRACKING_SERVER_CREATION_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE = "ModelBuilderInputModelData"
MLFLOW_S3_PATH = "S3"
MLFLOW_MODEL_PACKAGE_PATH = "ModelPackage"
Expand Down
Loading