Skip to content

Commit ed80f56

Browse files
jiapinwknikure
authored andcommitted
[Fix] regex for RunId to handle empty artifact path and change mlflow plugin name (aws#1455)
* [Fix] run id regex pattern such that empty artifact path is handled * Change mlflow plugin name as per legal team requirement
1 parent 509652c commit ed80f56

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -702,8 +702,10 @@ def _get_artifact_path(self, mlflow_model_path: str) -> str:
702702
"but MLflow model path was provided." % MLFLOW_TRACKING_ARN,
703703
)
704704

705-
if not importlib.util.find_spec("awsmlflow"):
706-
raise ImportError("Unable to import awsmlflow, check if awsmlflow is installed")
705+
if not importlib.util.find_spec("mlflow_sagemaker"):
706+
raise ImportError(
707+
"Unable to import mlflow_sagemaker, check if mlflow_sagemaker is installed"
708+
)
707709

708710
import mlflow
709711

@@ -928,13 +930,15 @@ def validate(self, model_dir: str) -> Type[bool]:
928930
def set_tracking_arn(self, arn: str):
929931
"""Set tracking server ARN"""
930932
# TODO: support native MLflow URIs
931-
if importlib.util.find_spec("awsmlflow"):
933+
if importlib.util.find_spec("mlflow_sagemaker"):
932934
import mlflow
933935

934936
mlflow.set_tracking_uri(arn)
935937
self.model_metadata[MLFLOW_TRACKING_ARN] = arn
936938
else:
937-
raise ImportError("Unable to import awsmlflow, check if awsmlflow is installed")
939+
raise ImportError(
940+
"Unable to import mlflow_sagemaker, check if mlflow_sagemaker is installed"
941+
)
938942

939943
def _hf_schema_builder_init(self, model_task: str):
940944
"""Initialize the schema builder for the given HF_TASK

src/sagemaker/serve/model_format/mlflow/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
MODEL_PACKAGE_ARN_REGEX = (
2323
r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/(.*?)(?:/(\d+))?$"
2424
)
25-
MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+/[/a-zA-Z0-9\-_\.]+$"
25+
MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9\-_\.]*)+$"
2626
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+[@/]?[a-zA-Z0-9\-_\.][/a-zA-Z0-9\-_\.]*$"
2727
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+(?:\/[a-zA-Z0-9\-_\/\.]*)?$"
2828
MLFLOW_TRACKING_ARN = "MLFLOW_TRACKING_ARN"

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2438,7 +2438,7 @@ def test_set_tracking_arn_mlflow_not_installed(self):
24382438
tracking_arn = "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test"
24392439
self.assertRaisesRegex(
24402440
ImportError,
2441-
"Unable to import awsmlflow, check if awsmlflow is installed",
2441+
"Unable to import mlflow_sagemaker, check if mlflow_sagemaker is installed",
24422442
builder.set_tracking_arn,
24432443
tracking_arn,
24442444
)

0 commit comments

Comments
 (0)