1212# language governing permissions and limitations under the License.
1313"""Holds the ModelBuilder class and the ModelServer enum."""
1414from __future__ import absolute_import
15+
16+ import importlib .util
1517import uuid
1618from typing import Any , Type , List , Dict , Optional , Union
1719from dataclasses import dataclass , field
1820import logging
1921import os
22+ import re
2023
2124from pathlib import Path
2225
4346from sagemaker .predictor import Predictor
4447from sagemaker .serve .model_format .mlflow .constants import (
4548 MLFLOW_MODEL_PATH ,
49+ MLFLOW_TRACKING_ARN ,
50+ MLFLOW_RUN_ID_REGEX ,
51+ MLFLOW_REGISTRY_PATH_REGEX ,
52+ MODEL_PACKAGE_ARN_REGEX ,
4653 MLFLOW_METADATA_FILE ,
4754 MLFLOW_PIP_DEPENDENCY_FILE ,
4855)
4956from sagemaker .serve .model_format .mlflow .utils import (
5057 _get_default_model_server_for_mlflow ,
51- _mlflow_input_is_local_path ,
5258 _download_s3_artifacts ,
5359 _select_container_for_mlflow_model ,
5460 _generate_mlflow_artifact_path ,
@@ -276,8 +282,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
276282 default = None ,
277283 metadata = {
278284 "help" : "Define the model metadata to override, currently supports `HF_TASK`, "
279- "`MLFLOW_MODEL_PATH`. HF_TASK should be set for new models without task metadata in "
280- "the Hub, Adding unsupported task types will throw an exception"
285+ "`MLFLOW_MODEL_PATH`, and `MLFLOW_TRACKING_ARN`. HF_TASK should be set for new "
286+ "models without task metadata in the Hub, Adding unsupported task types will "
287+ "throw an exception"
281288 },
282289 )
283290
@@ -501,6 +508,7 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
501508 _maintain_lineage_tracking_for_mlflow_model (
502509 mlflow_model_path = self .model_metadata [MLFLOW_MODEL_PATH ],
503510 s3_upload_path = self .s3_upload_path ,
511+ tracking_server_arn = self .model_metadata .get (MLFLOW_TRACKING_ARN ),
504512 sagemaker_session = self .sagemaker_session ,
505513 )
506514 return new_model_package
@@ -571,6 +579,7 @@ def _model_builder_deploy_wrapper(
571579 _maintain_lineage_tracking_for_mlflow_model (
572580 mlflow_model_path = self .model_metadata [MLFLOW_MODEL_PATH ],
573581 s3_upload_path = self .s3_upload_path ,
582+ tracking_server_arn = self .model_metadata .get (MLFLOW_TRACKING_ARN ),
574583 sagemaker_session = self .sagemaker_session ,
575584 )
576585 return predictor
@@ -625,11 +634,30 @@ def wrapper(*args, **kwargs):
625634
626635 return wrapper
627636
628- def _check_if_input_is_mlflow_model (self ) -> bool :
629- """Checks whether an MLmodel file exists in the given directory.
def _handle_mlflow_input(self):
    """Detect an MLflow model input and, if one is present, prepare to serve it.

    Records the outcome on ``self._is_mlflow_model``. The flag stays True only
    when MLflow arguments are present and the metadata check on the resolved
    artifact location succeeds; in that case the artifacts are initialized and
    the model server / flavor combination is validated.
    """
    self._is_mlflow_model = self._has_mlflow_arguments()
    if not self._is_mlflow_model:
        return

    mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
    artifact_path = self._get_artifact_path(mlflow_model_path)
    if not self._mlflow_metadata_exists(artifact_path):
        # Nothing was initialized for MLflow, so the flag must not stay True:
        # any later code that branches on it would treat this build as an
        # MLflow model even though the log below says it is not handled as one.
        self._is_mlflow_model = False
        logger.info(
            "MLflow model metadata not detected in %s. ModelBuilder is not "
            "handling MLflow model input",
            mlflow_model_path,
        )
        return

    self._initialize_for_mlflow(artifact_path)
    _validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))
655+
656+ def _has_mlflow_arguments (self ) -> bool :
657+ """Check whether MLflow model arguments are present
630658
631659 Returns:
632- bool: True if the MLmodel file exists , False otherwise.
660+ bool: True if MLflow arguments are present , False otherwise.
633661 """
634662 if self .inference_spec or self .model :
635663 logger .info (
@@ -644,16 +672,80 @@ def _check_if_input_is_mlflow_model(self) -> bool:
644672 )
645673 return False
646674
647- path = self .model_metadata .get (MLFLOW_MODEL_PATH )
648- if not path :
675+ mlflow_model_path = self .model_metadata .get (MLFLOW_MODEL_PATH )
676+ if not mlflow_model_path :
649677 logger .info (
650678 "%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model "
651679 "input" ,
652680 MLFLOW_MODEL_PATH ,
653681 )
654682 return False
655683
656- # Check for S3 path
684+ return True
685+
def _get_artifact_path(self, mlflow_model_path: str) -> str:
    """Resolve an MLflow model path input to the location of its artifacts.

    Inputs matching the run-ID or registry-path patterns are looked up through
    the MLflow tracking server; an input matching the model package ARN pattern
    is resolved to that package's ``SourceUri``; anything else is returned
    unchanged.

    Args:
        mlflow_model_path (str): The MLflow model path input.

    Returns:
        str: The path to the model artifact.

    Raises:
        ValueError: If a run-ID or registry path is given but no tracking
            server ARN is present in the model metadata.
        ImportError: If ``awsmlflow`` cannot be found.
    """
    run_id_match = re.match(MLFLOW_RUN_ID_REGEX, mlflow_model_path)
    needs_tracking_server = bool(run_id_match) or bool(
        re.match(MLFLOW_REGISTRY_PATH_REGEX, mlflow_model_path)
    )

    if not needs_tracking_server:
        if not re.match(MODEL_PACKAGE_ARN_REGEX, mlflow_model_path):
            # Not a recognized MLflow or model package reference: use as given.
            return mlflow_model_path
        sagemaker_client = self.sagemaker_session.sagemaker_client
        package_description = sagemaker_client.describe_model_package(
            ModelPackageName=mlflow_model_path
        )
        return package_description["SourceUri"]

    tracking_arn = self.model_metadata.get(MLFLOW_TRACKING_ARN)
    if not tracking_arn:
        raise ValueError(
            "%s is not provided in ModelMetadata or through set_tracking_arn "
            "but MLflow model path was provided." % MLFLOW_TRACKING_ARN,
        )

    if not importlib.util.find_spec("awsmlflow"):
        raise ImportError("Unable to import awsmlflow, check if awsmlflow is installed")

    import mlflow

    # The tracking URI must be set before any run or registry lookup below.
    mlflow.set_tracking_uri(tracking_arn)

    if run_id_match:
        # Split into the scheme prefix, the run id and the path inside the run.
        _scheme, run_id, relative_path = mlflow_model_path.split("/", 2)
        run_root = mlflow.get_run(run_id).info.artifact_uri
        separator = "" if run_root.endswith("/") else "/"
        return run_root + separator + relative_path

    client = mlflow.MlflowClient()
    # A trailing slash guarantees the splits below always yield a final
    # (possibly empty) sub-path element.
    registry_path = mlflow_model_path
    if not registry_path.endswith("/"):
        registry_path += "/"

    if "@" in registry_path:
        # Alias form: second segment is "<name>@<alias>".
        _scheme, name_with_alias, sub_path = registry_path.split("/", 2)
        registered_name, alias = name_with_alias.split("@")
        version_info = client.get_model_version_by_alias(registered_name, alias)
    else:
        # Version form: second and third segments are name and version.
        _scheme, registered_name, version, sub_path = registry_path.split("/", 3)
        version_info = client.get_model_version(registered_name, version)

    source_root = version_info.source
    separator = "" if source_root.endswith("/") else "/"
    return source_root + separator + sub_path
742+
743+ def _mlflow_metadata_exists (self , path : str ) -> bool :
744+ """Checks whether an MLmodel file exists in the given directory.
745+
746+ Returns:
747+ bool: True if the MLmodel file exists, False otherwise.
748+ """
657749 if path .startswith ("s3://" ):
658750 s3_downloader = S3Downloader ()
659751 if not path .endswith ("/" ):
@@ -665,17 +757,18 @@ def _check_if_input_is_mlflow_model(self) -> bool:
665757 file_path = os .path .join (path , MLFLOW_METADATA_FILE )
666758 return os .path .isfile (file_path )
667759
668- def _initialize_for_mlflow (self ) -> None :
669- """Initialize mlflow model artifacts, image uri and model server."""
670- mlflow_path = self .model_metadata .get (MLFLOW_MODEL_PATH )
671- if not _mlflow_input_is_local_path (mlflow_path ):
672- # TODO: extend to package arn, run id and etc.
673- logger .info (
674- "Start downloading model artifacts from %s to %s" , mlflow_path , self .model_path
675- )
676- _download_s3_artifacts (mlflow_path , self .model_path , self .sagemaker_session )
760+ def _initialize_for_mlflow (self , artifact_path : str ) -> None :
761+ """Initialize mlflow model artifacts, image uri and model server.
762+
763+ Args:
764+ artifact_path (str): The path to the artifact store.
765+ """
766+ if artifact_path .startswith ("s3://" ):
767+ _download_s3_artifacts (artifact_path , self .model_path , self .sagemaker_session )
768+ elif os .path .exists (artifact_path ):
769+ _copy_directory_contents (artifact_path , self .model_path )
677770 else :
678- _copy_directory_contents ( mlflow_path , self . model_path )
771+ raise ValueError ( "Invalid path: %s" % artifact_path )
679772 mlflow_model_metadata_path = _generate_mlflow_artifact_path (
680773 self .model_path , MLFLOW_METADATA_FILE
681774 )
@@ -728,6 +821,8 @@ def build( # pylint: disable=R0911
728821 self .role_arn = role_arn
729822 self .sagemaker_session = sagemaker_session or Session ()
730823
824+ self .sagemaker_session .settings ._local_download_dir = self .model_path
825+
731826 # https://github.com/boto/botocore/blob/develop/botocore/useragent.py#L258
732827 # decorate to_string() due to
733828 # https://github.com/boto/botocore/blob/develop/botocore/client.py#L1014-L1015
@@ -739,14 +834,8 @@ def build( # pylint: disable=R0911
739834 self .serve_settings = self ._get_serve_setting ()
740835
741836 self ._is_custom_image_uri = self .image_uri is not None
742- self ._is_mlflow_model = self ._check_if_input_is_mlflow_model ()
743- if self ._is_mlflow_model :
744- logger .warning (
745- "Support of MLflow format models is experimental and is not intended"
746- " for production at this moment."
747- )
748- self ._initialize_for_mlflow ()
749- _validate_input_for_mlflow (self .model_server , self .env_vars .get ("MLFLOW_MODEL_FLAVOR" ))
837+
838+ self ._handle_mlflow_input ()
750839
751840 if isinstance (self .model , str ):
752841 model_task = None
@@ -836,6 +925,17 @@ def validate(self, model_dir: str) -> Type[bool]:
836925
837926 return get_metadata (model_dir )
838927
def set_tracking_arn(self, arn: str):
    """Set the MLflow tracking server ARN.

    Passes ``arn`` to ``mlflow.set_tracking_uri`` and stores it in
    ``self.model_metadata`` under ``MLFLOW_TRACKING_ARN``.

    Args:
        arn (str): The tracking server ARN.

    Raises:
        ImportError: If the ``awsmlflow`` package cannot be found.
    """
    # TODO: support native MLflow URIs
    if not importlib.util.find_spec("awsmlflow"):
        raise ImportError("Unable to import awsmlflow, check if awsmlflow is installed")

    import mlflow

    mlflow.set_tracking_uri(arn)
    # model_metadata may be None when the caller never supplied it; create the
    # dict here instead of failing with a TypeError after the tracking URI has
    # already been changed.
    if self.model_metadata is None:
        self.model_metadata = {}
    self.model_metadata[MLFLOW_TRACKING_ARN] = arn
938+
839939 def _hf_schema_builder_init (self , model_task : str ):
840940 """Initialize the schema builder for the given HF_TASK
841941
0 commit comments