diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 534e93c285..81efc1f17a 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -469,9 +469,15 @@ def benchmark_metrics(self) -> pd.DataFrame: df.index = blank_index return df - def display_benchmark_metrics(self, *args, **kwargs) -> None: + def display_benchmark_metrics(self, **kwargs) -> None: """Display deployment configs benchmark metrics.""" - print(self.benchmark_metrics.to_markdown(index=False, floatfmt=".2f"), *args, **kwargs) + df = self.benchmark_metrics + + instance_type = kwargs.get("instance_type") + if instance_type: + df = df[df["Instance Type"].str.contains(instance_type)] + + print(df.to_markdown(index=False, floatfmt=".2f")) def list_deployment_configs(self) -> List[Dict[str, Any]]: """List deployment configs for ``This`` model. @@ -898,11 +904,12 @@ def _get_deployment_configs( err = None for config_name, metadata_config in self._metadata_configs.items(): - resolved_config = metadata_config.resolved_config if selected_config_name == config_name: instance_type_to_use = selected_instance_type else: - instance_type_to_use = resolved_config.get("default_inference_instance_type") + instance_type_to_use = metadata_config.resolved_config.get( + "default_inference_instance_type" + ) if metadata_config.benchmark_metrics: err, metadata_config.benchmark_metrics = ( @@ -941,8 +948,7 @@ def _get_deployment_configs( deployment_config_metadata = DeploymentConfigMetadata( config_name, - metadata_config.benchmark_metrics, - resolved_config, + metadata_config, init_kwargs, deploy_kwargs, ) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 67d1622977..04e8b91e26 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1078,6 +1078,7 @@ class JumpStartMetadataConfig(JumpStartDataHolderType): __slots__ = [ "base_fields", "benchmark_metrics", + "acceleration_configs", "config_components", "resolved_metadata_config", "config_name", @@ -1115,6 +1116,7 @@ def __init__( if config and config.get("benchmark_metrics") else None ) + self.acceleration_configs = config.get("acceleration_configs") self.resolved_metadata_config: Optional[Dict[str, Any]] = None self.config_name: Optional[str] = config_name self.default_inference_config: Optional[str] = config.get("default_inference_config") @@ -2293,6 +2295,7 @@ class DeploymentArgs(BaseDeploymentConfigDataHolder): __slots__ = [ "image_uri", "model_data", + "model_package_arn", "environment", "instance_type", "compute_resource_requirements", @@ -2310,6 +2313,7 @@ def __init__( if init_kwargs is not None: self.image_uri = init_kwargs.image_uri self.model_data = init_kwargs.model_data + self.model_package_arn = init_kwargs.model_package_arn self.instance_type = init_kwargs.instance_type self.environment = init_kwargs.env if init_kwargs.resources is not None: @@ -2341,14 +2345,14 @@ class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder): def __init__( self, config_name: Optional[str] = None, - benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]] = None, - resolved_config: Optional[Dict[str, Any]] = None, + metadata_config: Optional[JumpStartMetadataConfig] = None, init_kwargs: Optional[JumpStartModelInitKwargs] = None, deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None, ): """Instantiates DeploymentConfigMetadata object.""" self.deployment_config_name = config_name - self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs, resolved_config) - self.benchmark_metrics = benchmark_metrics - if resolved_config is not None: - self.acceleration_configs = resolved_config.get("acceleration_configs") + self.deployment_args = DeploymentArgs( + init_kwargs, deploy_kwargs, metadata_config.resolved_config + ) + self.benchmark_metrics = metadata_config.benchmark_metrics + self.acceleration_configs = metadata_config.acceleration_configs diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 7de79204aa..a1b5d7fa9a 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1284,7 +1284,7 @@ def wrapped_f(*args, **kwargs): break elif isinstance(res, dict): keys = list(res.keys()) - if "Instance Rate" not in keys[-1]: + if len(keys) == 0 or "Instance Rate" not in keys[-1]: f.cache_clear() elif len(res[keys[1]]) > len(res[keys[-1]]): del res[keys[-1]] diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 7b9d935fb6..e80ce020f7 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1932,6 +1932,7 @@ def test_model_display_benchmark_metrics( model = JumpStartModel(model_id=model_id) model.display_benchmark_metrics() + model.display_benchmark_metrics(instance_type="g5.12xlarge") @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 7f2c7b2aad..23fa42c09a 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -1372,8 +1372,7 @@ def test_deployment_config_metadata(): deployment_config_metadata = DeploymentConfigMetadata( jumpstart_config.config_name, - jumpstart_config.benchmark_metrics, - jumpstart_config.resolved_config, + jumpstart_config, JumpStartModelInitKwargs( model_id=specs.model_id, model_data=INIT_KWARGS.get("model_data"), diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 97ee36e998..cc4ef71cee 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -378,8 +378,7 @@ def get_base_deployment_configs_metadata( configs.append( DeploymentConfigMetadata( config_name=config_name, - benchmark_metrics=jumpstart_config.benchmark_metrics, - resolved_config=jumpstart_config.resolved_config, + metadata_config=jumpstart_config, init_kwargs=get_mock_init_kwargs( get_base_spec_with_prototype_configs().model_id, config_name ),