diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index ad70ffa805..2addb0a044 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -14,7 +14,9 @@ from __future__ import absolute_import -from typing import Dict, List, Optional, Union +from functools import lru_cache +from typing import Dict, List, Optional, Union, Any +import pandas as pd from botocore.exceptions import ClientError from sagemaker import payloads @@ -36,14 +38,21 @@ get_init_kwargs, get_register_kwargs, ) -from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.types import ( + JumpStartSerializablePayload, + DeploymentConfigMetadata, + JumpStartBenchmarkStat, + JumpStartMetadataConfig, +) from sagemaker.jumpstart.utils import ( validate_model_id_and_get_type, verify_model_region_and_return_specs, + get_jumpstart_configs, + extract_metrics_from_deployment_configs, ) from sagemaker.jumpstart.constants import JUMPSTART_LOGGER from sagemaker.jumpstart.enums import JumpStartModelType -from sagemaker.utils import stringify_object, format_tags, Tags +from sagemaker.utils import stringify_object, format_tags, Tags, get_instance_rate_per_hour from sagemaker.model import ( Model, ModelPackage, @@ -352,6 +361,18 @@ def _validate_model_id_and_type(): self.model_package_arn = model_init_kwargs.model_package_arn self.init_kwargs = model_init_kwargs.to_kwargs_dict(False) + metadata_configs = get_jumpstart_configs( + region=self.region, + model_id=self.model_id, + model_version=self.model_version, + sagemaker_session=self.sagemaker_session, + model_type=self.model_type, + ) + self._deployment_configs = [ + self._convert_to_deployment_config_metadata(config_name, config) + for config_name, config in metadata_configs.items() + ] + def log_subscription_warning(self) -> None: """Log message prompting the customer to subscribe to the proprietary model.""" subscription_link = verify_model_region_and_return_specs( @@ -420,6 +441,27 @@ def set_deployment_config(self, config_name: Optional[str]) -> None: model_id=self.model_id, model_version=self.model_version, config_name=config_name ) + @property + def benchmark_metrics(self) -> pd.DataFrame: + """Benchmark Metrics for deployment configs + + Returns: + Metrics: Pandas DataFrame object. + """ + return pd.DataFrame(self._get_benchmark_data(self.config_name)) + + def display_benchmark_metrics(self) -> None: + """Display Benchmark Metrics for deployment configs.""" + print(self.benchmark_metrics.to_markdown()) + + def list_deployment_configs(self) -> List[Dict[str, Any]]: + """List deployment configs for ``This`` model. + + Returns: + List[Dict[str, Any]]: A list of deployment configs. + """ + return self._deployment_configs + def _create_sagemaker_model( self, instance_type=None, @@ -808,6 +850,67 @@ def register_deploy_wrapper(*args, **kwargs): return model_package + @lru_cache + def _get_benchmark_data(self, config_name: str) -> Dict[str, List[str]]: + """Constructs deployment configs benchmark data. + + Args: + config_name (str): The name of the selected deployment config. + Returns: + Dict[str, List[str]]: Deployment config benchmark data. + """ + return extract_metrics_from_deployment_configs( + self._deployment_configs, + config_name, + ) + + def _convert_to_deployment_config_metadata( + self, config_name: str, metadata_config: JumpStartMetadataConfig + ) -> Dict[str, Any]: + """Retrieve deployment config for config name. + + Args: + config_name (str): Name of deployment config. + metadata_config (JumpStartMetadataConfig): Metadata config for deployment config. + Returns: + A deployment metadata config for config name (dict[str, Any]). + """ + default_inference_instance_type = metadata_config.resolved_config.get( + "default_inference_instance_type" + ) + + instance_rate = get_instance_rate_per_hour( + instance_type=default_inference_instance_type, region=self.region + ) + + benchmark_metrics = ( + metadata_config.benchmark_metrics.get(default_inference_instance_type) + if metadata_config.benchmark_metrics is not None + else None + ) + if instance_rate is not None: + if benchmark_metrics is not None: + benchmark_metrics.append(JumpStartBenchmarkStat(instance_rate)) + else: + benchmark_metrics = [JumpStartBenchmarkStat(instance_rate)] + + init_kwargs = get_init_kwargs( + model_id=self.model_id, + instance_type=default_inference_instance_type, + sagemaker_session=self.sagemaker_session, + ) + deploy_kwargs = get_deploy_kwargs( + model_id=self.model_id, + instance_type=default_inference_instance_type, + sagemaker_session=self.sagemaker_session, + ) + + deployment_config_metadata = DeploymentConfigMetadata( + config_name, benchmark_metrics, init_kwargs, deploy_kwargs + ) + + return deployment_config_metadata.to_json() + def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" return stringify_object(self) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 1de0f662da..07bd769054 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2206,3 +2206,99 @@ def __init__( self.skip_model_validation = skip_model_validation self.source_uri = source_uri self.config_name = config_name + + +class BaseDeploymentConfigDataHolder(JumpStartDataHolderType): + """Base class for Deployment Config Data.""" + + def _convert_to_pascal_case(self, attr_name: str) -> str: + """Converts a snake_case attribute name into a camelCased string. + + Args: + attr_name (str): The snake_case attribute name. + Returns: + str: The PascalCased attribute name. + """ + return attr_name.replace("_", " ").title().replace(" ", "") + + def to_json(self) -> Dict[str, Any]: + """Represents ``This`` object as JSON.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + att = self._convert_to_pascal_case(att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + elif isinstance(cur_val, dict): + json_obj[att] = {} + for key, val in cur_val.items(): + if issubclass(type(val), JumpStartDataHolderType): + json_obj[att][self._convert_to_pascal_case(key)] = val.to_json() + else: + json_obj[att][key] = val + else: + json_obj[att] = cur_val + return json_obj + + +class DeploymentConfig(BaseDeploymentConfigDataHolder): + """Dataclass representing a Deployment Config.""" + + __slots__ = [ + "model_data_download_timeout", + "container_startup_health_check_timeout", + "image_uri", + "model_data", + "instance_type", + "environment", + "compute_resource_requirements", + ] + + def __init__( + self, init_kwargs: JumpStartModelInitKwargs, deploy_kwargs: JumpStartModelDeployKwargs + ): + """Instantiates DeploymentConfig object.""" + if init_kwargs is not None: + self.image_uri = init_kwargs.image_uri + self.model_data = init_kwargs.model_data + self.instance_type = init_kwargs.instance_type + self.environment = init_kwargs.env + if init_kwargs.resources is not None: + self.compute_resource_requirements = ( + init_kwargs.resources.get_compute_resource_requirements() + ) + if deploy_kwargs is not None: + self.model_data_download_timeout = deploy_kwargs.model_data_download_timeout + self.container_startup_health_check_timeout = ( + deploy_kwargs.container_startup_health_check_timeout + ) + + +class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder): + """Dataclass representing a Deployment Config Metadata""" + + __slots__ = [ + "config_name", + "benchmark_metrics", + "deployment_config", + ] + + def __init__( + self, + config_name: str, + benchmark_metrics: List[JumpStartBenchmarkStat], + init_kwargs: JumpStartModelInitKwargs, + deploy_kwargs: JumpStartModelDeployKwargs, + ): + """Instantiates DeploymentConfigMetadata object.""" + self.config_name = config_name + self.benchmark_metrics = benchmark_metrics + self.deployment_config = DeploymentConfig(init_kwargs, deploy_kwargs) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 1459594faa..905f2a18d5 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -999,3 +999,44 @@ def get_jumpstart_configs( if metadata_configs else {} ) + + +def extract_metrics_from_deployment_configs( + deployment_configs: List[Dict[str, Any]], config_name: str +) -> Dict[str, List[str]]: + """Extracts metrics from deployment configs. + + Args: + deployment_configs (list[dict[str, Any]]): List of deployment configs. + config_name (str): The name of the deployment config use by the model. + """ + + data = {"Config Name": [], "Instance Type": [], "Selected": []} + + for index, deployment_config in enumerate(deployment_configs): + if deployment_config.get("DeploymentConfig") is None: + continue + + benchmark_metrics = deployment_config.get("BenchmarkMetrics") + if benchmark_metrics is not None: + data["Config Name"].append(deployment_config.get("ConfigName")) + data["Instance Type"].append( + deployment_config.get("DeploymentConfig").get("InstanceType") + ) + data["Selected"].append( + "Yes" + if (config_name is not None and config_name == deployment_config.get("ConfigName")) + else "No" + ) + + if index == 0: + for benchmark_metric in benchmark_metrics: + column_name = f"{benchmark_metric.get('name')} ({benchmark_metric.get('unit')})" + data[column_name] = [] + + for benchmark_metric in benchmark_metrics: + column_name = f"{benchmark_metric.get('name')} ({benchmark_metric.get('unit')})" + if column_name in data.keys(): + data[column_name].append(benchmark_metric.get("value")) + + return data diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index e3368869fe..c1760311e7 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -16,7 +16,7 @@ import copy from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Type +from typing import Type, Any, List, Dict import logging from sagemaker.model import Model @@ -431,6 +431,18 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800): sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration ) + def display_benchmark_metrics(self): + """Display Markdown Benchmark Metrics for deployment configs.""" + self.pysdk_model.display_benchmark_metrics() + + def list_deployment_configs(self) -> List[Dict[str, Any]]: + """List deployment configs for ``This`` model in the current region. + + Returns: + List[Dict[str, Any]]: A list of deployment configs. + """ + return self.pysdk_model.list_deployment_configs() + def _build_for_jumpstart(self): """Placeholder docstring""" # we do not pickle for jumpstart. set to none diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 0436c0afea..35f60b37e1 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -25,6 +25,7 @@ import tarfile import tempfile import time +from functools import lru_cache from typing import Union, Any, List, Optional, Dict import json import abc @@ -33,6 +34,8 @@ from os.path import abspath, realpath, dirname, normpath, join as joinpath from importlib import import_module + +import boto3 import botocore from botocore.utils import merge_dicts from six.moves.urllib import parse @@ -1655,3 +1658,69 @@ def deep_override_dict( ) flattened_dict1.update(flattened_dict2) return unflatten_dict(flattened_dict1) if flattened_dict1 else {} + + +@lru_cache +def get_instance_rate_per_hour( + instance_type: str, + region: str, +) -> Union[Dict[str, str], None]: + """Gets instance rate per hour for the given instance type. + + Args: + instance_type (str): The instance type. + region (str): The region. + Returns: + Union[Dict[str, str], None]: Instance rate per hour. + Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.1250000000'}}. + """ + + region_name = "us-east-1" + if region.startswith("eu") or region.startswith("af"): + region_name = "eu-central-1" + elif region.startswith("ap") or region.startswith("cn"): + region_name = "ap-south-1" + + pricing_client: boto3.client = boto3.client("pricing", region_name=region_name) + try: + res = pricing_client.get_products( + ServiceCode="AmazonSageMaker", + Filters=[ + {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type}, + {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"}, + {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region}, + ], + ) + + price_list = res.get("PriceList", []) + if len(price_list) > 0: + price_data = price_list[0] + if isinstance(price_data, str): + price_data = json.loads(price_data) + + return extract_instance_rate_per_hour(price_data) + except Exception as e: # pylint: disable=W0703 + logging.exception("Error getting instance rate: %s", e) + return None + + +def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Union[Dict[str, str], None]: + """Extract instance rate per hour for the given Price JSON data. + + Args: + price_data (Dict[str, Any]): The Price JSON data. + Returns: + Union[Dict[str, str], None]: Instance rate per hour. + """ + + if price_data is not None: + price_dimensions = price_data.get("terms", {}).get("OnDemand", {}).values() + for dimension in price_dimensions: + for price in dimension.get("priceDimensions", {}).values(): + for currency in price.get("pricePerUnit", {}).keys(): + return { + "unit": f"{currency}/{price.get('unit', 'Hrs')}", + "value": price.get("pricePerUnit", {}).get(currency), + "name": "Instance Rate", + } + return None diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index f165a513a9..b83f85ffde 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -7907,3 +7907,160 @@ }, } } + + +DEPLOYMENT_CONFIGS = [ + { + "ConfigName": "neuron-inference", + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], + "DeploymentConfig": { + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + }, + }, + { + "ConfigName": "neuron-inference-budget", + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], + "DeploymentConfig": { + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + }, + }, + { + "ConfigName": "gpu-inference-budget", + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], + "DeploymentConfig": { + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + }, + }, + { + "ConfigName": "gpu-inference", + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], + "DeploymentConfig": { + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + }, + }, +] + + +INIT_KWARGS = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu" + "-py310-cu121-ubuntu20.04", + "model_data": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface-textgeneration" + "-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "instance_type": "ml.p2.xlarge", + "env": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "role": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628", + "name": "hf-textgeneration-bloom-1b1-2024-04-22-20-23-48-799", + "enable_network_isolation": True, +} diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index cb7b602fbf..2df904dce2 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -15,6 +15,7 @@ from typing import Optional, Set from unittest import mock import unittest +import pandas as pd from mock import MagicMock, Mock import pytest from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig @@ -47,6 +48,9 @@ get_special_model_spec_for_inference_component_based_endpoint, get_prototype_manifest, get_prototype_model_spec, + get_base_spec_with_prototype_configs, + get_mock_init_kwargs, + get_base_deployment_configs, ) import boto3 @@ -64,6 +68,9 @@ class ModelTest(unittest.TestCase): mock_session_empty_config = MagicMock(sagemaker_config={}) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @@ -81,6 +88,7 @@ def test_non_prepacked( mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_jumpstart_model_factory_logger: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -140,6 +148,9 @@ def test_non_prepacked( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -155,6 +166,7 @@ def test_non_prepacked_inference_component_based_endpoint( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -220,6 +232,9 @@ def test_non_prepacked_inference_component_based_endpoint( endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -235,6 +250,7 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -295,6 +311,9 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -308,6 +327,7 @@ def test_prepacked( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -354,6 +374,9 @@ def test_prepacked( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.model.LOGGER.warning") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @@ -371,6 +394,7 @@ def test_no_compiled_model_warning_log_js_models( mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, mock_warning: mock.Mock(), + mock_get_jumpstart_configs: mock.Mock, ): mock_timestamp.return_value = "1234" @@ -391,6 +415,9 @@ def test_no_compiled_model_warning_log_js_models( mock_warning.assert_not_called() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") @@ -406,6 +433,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( mock_create_model: mock.Mock, mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_timestamp.return_value = "1234" @@ -453,6 +481,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( ], ) + @mock.patch("sagemaker.jumpstart.model.get_jumpstart_configs") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @@ -470,7 +499,9 @@ def test_proprietary_model_endpoint( mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): + mock_get_jumpstart_configs.side_effect = lambda *args, **kwargs: {} mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) @@ -510,6 +541,7 @@ def test_proprietary_model_endpoint( container_startup_health_check_timeout=600, ) + @mock.patch("sagemaker.jumpstart.model.get_jumpstart_configs") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -521,7 +553,9 @@ def test_deprecated( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): + mock_get_jumpstart_configs.side_effect = lambda *args, **kwargs: {} mock_model_deploy.return_value = default_predictor mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -537,6 +571,9 @@ def test_deprecated( JumpStartModel(model_id=model_id, tolerate_deprecated_model=True).deploy() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -548,6 +585,7 @@ def test_vulnerable( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -611,6 +649,9 @@ def test_model_use_kwargs(self): deploy_kwargs=all_deploy_kwargs_used, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.factory.model.environment_variables.retrieve_default") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -626,6 +667,7 @@ def evaluate_model_workflow_with_kwargs( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_retrieve_environment_variables: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, init_kwargs: Optional[dict] = None, deploy_kwargs: Optional[dict] = None, ): @@ -729,6 +771,9 @@ def test_jumpstart_model_kwargs_match_parent_class(self): assert js_class_deploy_args - parent_class_deploy_args == set() assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @@ -737,6 +782,7 @@ def test_validate_model_id_and_get_type( mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS JumpStartModel(model_id="valid_model_id") @@ -745,6 +791,9 @@ def test_validate_model_id_and_get_type( with pytest.raises(ValueError): JumpStartModel(model_id="invalid_model_id") + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -760,6 +809,7 @@ def test_no_predictor_returns_default_predictor( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -793,6 +843,9 @@ def test_no_predictor_returns_default_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -808,6 +861,7 @@ def test_no_predictor_yes_async_inference_config( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -829,6 +883,9 @@ def test_no_predictor_yes_async_inference_config( mock_get_default_predictor.assert_not_called() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -844,6 +901,7 @@ def test_yes_predictor_returns_default_predictor( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -865,6 +923,9 @@ def test_yes_predictor_returns_default_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -880,6 +941,7 @@ def test_model_id_not_found_refeshes_cache_inference( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.side_effect = [False, False] @@ -948,6 +1010,9 @@ def test_model_id_not_found_refeshes_cache_inference( ] ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -955,6 +1020,7 @@ def test_jumpstart_model_tags( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -984,6 +1050,9 @@ def test_jumpstart_model_tags( [{"Key": "blah", "Value": "blahagain"}] + js_tags, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -991,6 +1060,7 @@ def test_jumpstart_model_tags_disabled( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1018,6 +1088,9 @@ def test_jumpstart_model_tags_disabled( [{"Key": "blah", "Value": "blahagain"}], ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1025,6 +1098,7 @@ def test_jumpstart_model_package_arn( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1052,6 +1126,9 @@ def test_jumpstart_model_package_arn( self.assertIn(tag, mock_session.create_model.call_args[1]["tags"]) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1059,6 +1136,7 @@ def test_jumpstart_model_package_arn_override( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1094,6 +1172,9 @@ def test_jumpstart_model_package_arn_override( }, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1103,6 +1184,7 @@ def test_jumpstart_model_package_arn_unsupported_region( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1120,6 +1202,9 @@ def test_jumpstart_model_package_arn_unsupported_region( "us-east-2. Please try one of the following regions: us-west-2, us-east-1." ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -1137,6 +1222,7 @@ def test_model_data_s3_prefix_override( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1186,6 +1272,9 @@ def test_model_data_s3_prefix_override( '"S3DataType": "S3Prefix", "CompressionType": "None"}}', ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1201,6 +1290,7 @@ def test_model_data_s3_prefix_model( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1230,6 +1320,9 @@ def test_model_data_s3_prefix_model( mock_js_info_logger.assert_not_called() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1245,6 +1338,7 @@ def test_model_artifact_variant_model( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1295,6 +1389,9 @@ def test_model_artifact_variant_model( enable_network_isolation=True, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1308,6 +1405,7 @@ def test_model_registry_accept_and_response_types( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1327,6 +1425,9 @@ def test_model_registry_accept_and_response_types( response_types=["application/json;verbose", "application/json"], ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -1340,6 +1441,7 @@ def test_jumpstart_model_session( mock_deploy, mock_init, get_default_predictor, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = True @@ -1373,6 +1475,9 @@ def test_jumpstart_model_session( assert len(s3_clients) == 1 assert list(s3_clients)[0] == session.s3_client + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch.dict( "sagemaker.jumpstart.cache.os.environ", { @@ -1391,6 +1496,7 @@ def test_model_local_mode( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_model_specs.side_effect = get_prototype_model_spec mock_get_manifest.side_effect = ( @@ -1417,6 +1523,9 @@ def test_model_local_mode( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1428,6 +1537,7 @@ def test_model_initialization_with_config_name( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( @@ -1454,6 +1564,9 @@ def test_model_initialization_with_config_name( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1465,6 +1578,7 @@ def test_model_set_deployment_config( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_model_specs.side_effect = get_prototype_model_spec mock_get_manifest.side_effect = ( @@ -1509,6 +1623,9 @@ def test_model_set_deployment_config( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1520,6 +1637,7 @@ def test_model_unset_deployment_config( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( @@ -1564,6 +1682,173 @@ def test_model_unset_deployment_config( endpoint_logging=False, ) + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_list_deployment_configs( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_instance_rate_per_hour: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs() + ) + mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "0.0083000000", + } + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + configs = model.list_deployment_configs() + + self.assertEqual(configs, get_base_deployment_configs()) + + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_list_deployment_configs_empty( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_instance_rate_per_hour: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_special_model_spec(model_id="gemma-model") + ) + mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "0.0083000000", + } + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + configs = model.list_deployment_configs() + + self.assertTrue(len(configs) == 0) + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_display_benchmark_metrics( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_instance_rate_per_hour: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs() + ) + mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "0.0083000000", + } + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + model.display_benchmark_metrics() + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_benchmark_metrics( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_instance_rate_per_hour: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs() + ) + mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "0.0083000000", + } + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + df = model.benchmark_metrics + + self.assertTrue(isinstance(df, pd.DataFrame)) + def test_jumpstart_model_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py index 70409704e6..2be4bde7e4 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py @@ -60,6 +60,9 @@ class IntelligentDefaultsModelTest(unittest.TestCase): region = "us-west-2" sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -77,6 +80,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -101,6 +105,9 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( assert "enable_network_isolation" not in mock_model_init.call_args[1] + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -118,6 +125,7 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -147,6 +155,9 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( override_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -164,6 +175,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -193,6 +205,9 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( config_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -210,6 +225,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -241,6 +257,9 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( override_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -258,6 +277,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -287,6 +307,9 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( metadata_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -304,6 +327,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -334,6 +358,9 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( override_enable_network_isolation, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -351,6 +378,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -375,6 +403,9 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( self.assertEquals(mock_model_init.call_args[1].get("role"), execution_role) assert "enable_network_isolation" not in mock_model_init.call_args[1] + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -392,6 +423,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 52f28f2da1..abce2bf687 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -159,6 +159,9 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( patched_get_default_predictor.assert_not_called() +@mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} +) @patch("sagemaker.predictor.get_model_id_version_from_endpoint") @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") @patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @@ -170,6 +173,7 @@ def test_jumpstart_serializable_payload_with_predictor( patched_validate_model_id_and_get_type, patched_get_object_cached, patched_get_model_id_version_from_endpoint, + patched_get_jumpstart_configs, ): patched_get_object_cached.return_value = base64.b64decode("encodedimage") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index c1ea8abcb8..85911a2854 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -49,6 +49,7 @@ get_spec_from_base_spec, get_special_model_spec, get_prototype_manifest, + get_base_deployment_configs, ) from mock import MagicMock @@ -1708,3 +1709,52 @@ def test_get_jumpstart_benchmark_stats_training( ] }, } + + +@pytest.mark.parametrize( + "config_name, expected", + [ + ( + None, + { + "Config Name": [ + "neuron-inference", + "neuron-inference-budget", + "gpu-inference-budget", + "gpu-inference", + ], + "Instance Type": ["ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge"], + "Selected": ["No", "No", "No", "No"], + "Instance Rate (USD/Hrs)": [ + "0.0083000000", + "0.0083000000", + "0.0083000000", + "0.0083000000", + ], + }, + ), + ( + "neuron-inference", + { + "Config Name": [ + "neuron-inference", + "neuron-inference-budget", + "gpu-inference-budget", + "gpu-inference", + ], + "Instance Type": ["ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge"], + "Selected": ["Yes", "No", "No", "No"], + "Instance Rate (USD/Hrs)": [ + "0.0083000000", + "0.0083000000", + "0.0083000000", + "0.0083000000", + ], + }, + ), + ], +) +def test_extract_metrics_from_deployment_configs(config_name, expected): + data = utils.extract_metrics_from_deployment_configs(get_base_deployment_configs(), config_name) + + assert data == expected diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index aee1497ec9..e0d6f645a8 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -12,9 +12,10 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import import copy -from typing import List +from typing import List, Dict, Any import boto3 +from sagemaker.compute_resource_requirements import ResourceRequirements from sagemaker.jumpstart.cache import JumpStartModelsCache from sagemaker.jumpstart.constants import ( JUMPSTART_DEFAULT_REGION_NAME, @@ -27,6 +28,7 @@ JumpStartModelSpecs, JumpStartS3FileType, JumpStartModelHeader, + JumpStartModelInitKwargs, ) from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import get_formatted_manifest @@ -43,6 +45,8 @@ SPECIAL_MODEL_SPECS_DICT, TRAINING_CONFIG_RANKINGS, TRAINING_CONFIGS, + DEPLOYMENT_CONFIGS, + INIT_KWARGS, ) @@ -297,3 +301,19 @@ def overwrite_dictionary( base_dictionary[key] = value return base_dictionary + + +def get_base_deployment_configs() -> List[Dict[str, Any]]: + return DEPLOYMENT_CONFIGS + + +def get_mock_init_kwargs(model_id) -> JumpStartModelInitKwargs: + return JumpStartModelInitKwargs( + model_id=model_id, + model_type=JumpStartModelType.OPEN_WEIGHTS, + model_data=INIT_KWARGS.get("model_data"), + image_uri=INIT_KWARGS.get("image_uri"), + instance_type=INIT_KWARGS.get("instance_type"), + env=INIT_KWARGS.get("env"), + resources=ResourceRequirements(), + ) diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 2a0c791215..3d5148772e 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -23,6 +23,7 @@ LocalModelOutOfMemoryException, LocalModelInvocationException, ) +from tests.unit.sagemaker.serve.constants import DEPLOYMENT_CONFIGS mock_model_id = "huggingface-llm-amazon-falconlite" mock_t5_model_id = "google/flan-t5-xxl" @@ -638,3 +639,87 @@ def test_js_gated_model_ex( ValueError, lambda: builder.build(), ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_list_deployment_configs( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.side_effect = ( + lambda: DEPLOYMENT_CONFIGS + ) + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + configs = model.list_deployment_configs() + + self.assertEqual(configs, DEPLOYMENT_CONFIGS) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_display_benchmark_metrics( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.display_benchmark_metrics.side_effect = ( + lambda *args, **kwargs: "metric data" + ) + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + model.display_benchmark_metrics() diff --git a/tests/unit/sagemaker/serve/constants.py b/tests/unit/sagemaker/serve/constants.py index db9dd623d8..5c40c1bf64 100644 --- a/tests/unit/sagemaker/serve/constants.py +++ b/tests/unit/sagemaker/serve/constants.py @@ -15,3 +15,153 @@ MOCK_IMAGE_CONFIG = {"RepositoryAccessMode": "Vpc"} MOCK_VPC_CONFIG = {"Subnets": ["subnet-1234"], "SecurityGroupIds": ["sg123"]} +DEPLOYMENT_CONFIGS = [ + { + "ConfigName": "neuron-inference", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentConfig": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "neuron-inference-budget", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentConfig": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "gpu-inference-budget", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentConfig": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "gpu-inference", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentConfig": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, +] diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 81d8279e6d..bf6a7cb09f 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -50,6 +50,8 @@ _is_bad_link, custom_extractall_tarfile, can_model_package_source_uri_autopopulate, + get_instance_rate_per_hour, + extract_instance_rate_per_hour, ) from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -1866,3 +1868,77 @@ def test_deep_override_skip_keys(self): expected_result = {"a": 1, "b": {"x": 20, "y": 3, "z": 30}, "c": [4, 5]} self.assertEqual(deep_override_dict(dict1, dict2, skip_keys=["c", "d"]), expected_result) + + +@pytest.mark.parametrize( + "instance, region", + [ + ("t4g.nano", "us-west-2"), + ("t4g.nano", "eu-central-1"), + ("t4g.nano", "af-south-1"), + ("t4g.nano", "ap-northeast-2"), + ("t4g.nano", "cn-north-1"), + ], +) +@patch("boto3.client") +def test_get_instance_rate_per_hour(mock_client, instance, region): + amazon_sagemaker_price_result = { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": "$0.0083 per ' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": "0.0083000000"}}}, ' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF", ' + '"termAttributes": {}}}}}' + ] + } + + mock_client.return_value.get_products.side_effect = ( + lambda *args, **kwargs: amazon_sagemaker_price_result + ) + instance_rate = get_instance_rate_per_hour(instance_type=instance, region=region) + + assert instance_rate == {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.0083000000"} + + +@patch("boto3.client") +def test_get_instance_rate_per_hour_ex(mock_client): + mock_client.return_value.get_products.side_effect = lambda *args, **kwargs: Exception() + instance_rate = get_instance_rate_per_hour(instance_type="ml.t4g.nano", region="us-west-2") + + assert instance_rate is None + + +@pytest.mark.parametrize( + "price_data, expected_result", + [ + (None, None), + ( + { + "terms": { + "OnDemand": { + "3WK7G7WSYVS3K492.JRTCKXETXF": { + "priceDimensions": { + "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7": { + "unit": "Hrs", + "endRange": "Inf", + "description": "$0.9 per Unused Reservation Linux p2.xlarge Instance Hour", + "appliesTo": [], + "rateCode": "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7", + "beginRange": "0", + "pricePerUnit": {"USD": "0.9000000000"}, + } + } + } + } + } + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.9000000000"}, + ), + ], +) +def test_extract_instance_rate_per_hour(price_data, expected_result): + out = extract_instance_rate_per_hour(price_data) + + assert out == expected_result