Skip to content

Commit a3deb08

Browse files
authored
Merge branch 'master' into feat/instance-specific-jumpstart-host-requirements
2 parents 39d3fa6 + 8b206ba commit a3deb08

File tree

9 files changed

+1221
-24
lines changed

9 files changed

+1221
-24
lines changed

src/sagemaker/image_uri_config/instance_gpu_info.json

Lines changed: 782 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Accessors to retrieve instance types GPU info."""
14+
from __future__ import absolute_import
15+
16+
import json
17+
import os
18+
from typing import Dict
19+
20+
21+
def retrieve(region: str) -> Dict[str, Dict[str, int]]:
    """Retrieves instance types GPU info of the given region.

    The info is read from ``image_uri_config/instance_gpu_info.json`` located
    next to this module.

    Args:
        region (str): The AWS region.

    Returns:
        dict[str, dict[str, int]]: A dictionary that contains instance types as keys
            and GPU info as values, or an empty dictionary if the config has no
            entry for the given region.

    Raises:
        ValueError: If the GPU info config file cannot be found.
    """
    config_path = os.path.join(
        os.path.dirname(__file__), "image_uri_config", "instance_gpu_info.json"
    )
    try:
        # JSON is UTF-8 by specification; do not depend on the platform default encoding.
        with open(config_path, encoding="utf-8") as f:
            instance_types_gpu_info_config = json.load(f)
    except FileNotFoundError as e:
        # Chain the original error so the missing path stays visible in the traceback.
        raise ValueError("Could not find instance types gpu info.") from e

    return instance_types_gpu_info_config.get(region, {})

src/sagemaker/model_monitor/clarify_model_monitoring.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,7 @@ def create_monitoring_schedule(
669669
self.monitoring_schedule_name = monitor_schedule_name
670670
except Exception:
671671
logger.exception("Failed to create monitoring schedule.")
672+
self.monitoring_schedule_name = None
672673
# noinspection PyBroadException
673674
try:
674675
self.sagemaker_session.sagemaker_client.delete_model_bias_job_definition(
@@ -1109,6 +1110,7 @@ def create_monitoring_schedule(
11091110
self.monitoring_schedule_name = monitor_schedule_name
11101111
except Exception:
11111112
logger.exception("Failed to create monitoring schedule.")
1113+
self.monitoring_schedule_name = None
11121114
# noinspection PyBroadException
11131115
try:
11141116
self.sagemaker_session.sagemaker_client.delete_model_explainability_job_definition(

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -415,30 +415,34 @@ def create_monitoring_schedule(
415415
if arguments is not None:
416416
self.arguments = arguments
417417

418-
self.sagemaker_session.create_monitoring_schedule(
419-
monitoring_schedule_name=self.monitoring_schedule_name,
420-
schedule_expression=schedule_cron_expression,
421-
statistics_s3_uri=statistics_s3_uri,
422-
constraints_s3_uri=constraints_s3_uri,
423-
monitoring_inputs=[normalized_monitoring_input],
424-
monitoring_output_config=monitoring_output_config,
425-
instance_count=self.instance_count,
426-
instance_type=self.instance_type,
427-
volume_size_in_gb=self.volume_size_in_gb,
428-
volume_kms_key=self.volume_kms_key,
429-
image_uri=self.image_uri,
430-
entrypoint=self.entrypoint,
431-
arguments=self.arguments,
432-
record_preprocessor_source_uri=None,
433-
post_analytics_processor_source_uri=None,
434-
max_runtime_in_seconds=self.max_runtime_in_seconds,
435-
environment=self.env,
436-
network_config=network_config_dict,
437-
role_arn=self.sagemaker_session.expand_role(self.role),
438-
tags=self.tags,
439-
data_analysis_start_time=data_analysis_start_time,
440-
data_analysis_end_time=data_analysis_end_time,
441-
)
418+
try:
419+
self.sagemaker_session.create_monitoring_schedule(
420+
monitoring_schedule_name=self.monitoring_schedule_name,
421+
schedule_expression=schedule_cron_expression,
422+
statistics_s3_uri=statistics_s3_uri,
423+
constraints_s3_uri=constraints_s3_uri,
424+
monitoring_inputs=[normalized_monitoring_input],
425+
monitoring_output_config=monitoring_output_config,
426+
instance_count=self.instance_count,
427+
instance_type=self.instance_type,
428+
volume_size_in_gb=self.volume_size_in_gb,
429+
volume_kms_key=self.volume_kms_key,
430+
image_uri=self.image_uri,
431+
entrypoint=self.entrypoint,
432+
arguments=self.arguments,
433+
record_preprocessor_source_uri=None,
434+
post_analytics_processor_source_uri=None,
435+
max_runtime_in_seconds=self.max_runtime_in_seconds,
436+
environment=self.env,
437+
network_config=network_config_dict,
438+
role_arn=self.sagemaker_session.expand_role(self.role),
439+
tags=self.tags,
440+
data_analysis_start_time=data_analysis_start_time,
441+
data_analysis_end_time=data_analysis_end_time,
442+
)
443+
except Exception:
444+
self.monitoring_schedule_name = None
445+
raise
442446

443447
def update_monitoring_schedule(
444448
self,
@@ -2054,6 +2058,7 @@ def create_monitoring_schedule(
20542058
self.monitoring_schedule_name = monitor_schedule_name
20552059
except Exception:
20562060
logger.exception("Failed to create monitoring schedule.")
2061+
self.monitoring_schedule_name = None
20572062
# noinspection PyBroadException
20582063
try:
20592064
self.sagemaker_session.sagemaker_client.delete_data_quality_job_definition(
@@ -3173,6 +3178,7 @@ def create_monitoring_schedule(
31733178
self.monitoring_schedule_name = monitor_schedule_name
31743179
except Exception:
31753180
logger.exception("Failed to create monitoring schedule.")
3181+
self.monitoring_schedule_name = None
31763182
# noinspection PyBroadException
31773183
try:
31783184
self.sagemaker_session.sagemaker_client.delete_model_quality_job_definition(
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Utilities for detecting available GPUs and Aggregate GPU Memory size of an instance"""
14+
from __future__ import absolute_import
15+
16+
import logging
17+
from typing import Tuple
18+
19+
from botocore.exceptions import ClientError
20+
21+
from sagemaker import Session
22+
from sagemaker import instance_types_gpu_info
23+
24+
logger = logging.getLogger(__name__)
25+
26+
27+
def _get_gpu_info(instance_type: str, session: Session) -> Tuple[int, int]:
    """Get GPU info for the provided instance

    The instance type is looked up through the EC2 ``describe_instance_types``
    call, using an EC2 client created from ``session.boto_session``.

    Args:
        instance_type (str): Instance type name. It is passed through
            ``_format_instance_type`` before being sent to EC2.
        session: The session to use.

    Returns: tuple[int, int]: A tuple that contains number of GPUs available at index 0,
        and aggregate memory size in MiB at index 1.

    Raises:
        ValueError: If The given instance type does not exist or GPU is not enabled.
    """
    ec2_client = session.boto_session.client("ec2")
    ec2_instance = _format_instance_type(instance_type)
    error_message = f"Provided instance_type is not GPU enabled: [{ec2_instance}]"

    try:
        response = ec2_client.describe_instance_types(InstanceTypes=[ec2_instance])
    except ClientError as e:
        # Keep the EC2 error as the cause so the real failure is not hidden.
        raise ValueError(error_message) from e

    # Treat a missing or empty "InstanceTypes" list like any other lookup miss so
    # callers only ever see the documented ValueError, not IndexError/TypeError.
    instance_types = response.get("InstanceTypes") or [None]
    instance_info = instance_types[0] or {}
    gpus_info = instance_info.get("GpuInfo") or {}
    gpus = gpus_info.get("Gpus") or []
    if not gpus:
        raise ValueError(error_message)

    # Only the first entry of "Gpus" is consulted for the GPU count.
    count = gpus[0].get("Count")
    total_gpu_memory_in_mib = gpus_info.get("TotalGpuMemoryInMiB")
    if not (count and total_gpu_memory_in_mib):
        raise ValueError(error_message)

    instance_gpu_info = (count, total_gpu_memory_in_mib)
    logger.info("GPU Info [%s]: %s", ec2_instance, instance_gpu_info)
    return instance_gpu_info
66+
67+
68+
def _get_gpu_info_fallback(instance_type: str, region: str) -> Tuple[int, int]:
    """Get GPU info for the provided instance from the config

    Args:
        instance_type (str): Instance type name, used as given as the lookup
            key in the region's config.
        region: The AWS region.

    Returns: tuple[int, int]: A tuple that contains number of GPUs available at index 0,
        and aggregate memory size in MiB at index 1.

    Raises:
        ValueError: If the region's config has no entry for the given instance type.
    """
    region_config = instance_types_gpu_info.retrieve(region)
    config_entry = region_config.get(instance_type)

    # The formatted name is only used for the error and log messages below.
    display_name = _format_instance_type(instance_type)
    if config_entry is None:
        raise ValueError(f"Provided instance_type is not GPU enabled: [{display_name}]")

    gpu_count = config_entry.get("Count")
    gpu_memory_in_mib = config_entry.get("TotalGpuMemoryInMiB")
    gpu_info = (gpu_count, gpu_memory_in_mib)
    logger.info("GPU Info [%s]: %s", display_name, gpu_info)
    return gpu_info
94+
95+
96+
def _format_instance_type(instance_type: str) -> str:
    """Drops the leading dot-separated segment of a long instance type name

    Args:
        instance_type (str): Instance type name, e.g. ``ml.g5.48xlarge``.

    Returns: the name without its first segment when it has more than two
        dot-separated segments, otherwise the name unchanged.
    """
    # Fewer than two dots means at most two segments: nothing to strip.
    if instance_type.count(".") < 2:
        return instance_type

    _, _, remainder = instance_type.partition(".")
    return remainder
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
17+
from sagemaker.serve.utils import hardware_detector
18+
19+
REGION = "us-west-2"
20+
VALID_INSTANCE_TYPE = "ml.g5.48xlarge"
21+
INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge"
22+
EXPECTED_INSTANCE_GPU_INFO = (8, 196608)
23+
24+
25+
def test_get_gpu_info_success(sagemaker_session):
    """The session-based lookup returns the expected GPU info for a valid instance type."""
    assert (
        hardware_detector._get_gpu_info(VALID_INSTANCE_TYPE, sagemaker_session)
        == EXPECTED_INSTANCE_GPU_INFO
    )
29+
30+
31+
def test_get_gpu_info_throws(sagemaker_session):
    """The session-based lookup raises ``ValueError`` for an invalid instance type."""
    with pytest.raises(ValueError):
        hardware_detector._get_gpu_info(INVALID_INSTANCE_TYPE, sagemaker_session)
34+
35+
36+
def test_get_gpu_info_fallback_success():
    """The config-based fallback returns the expected GPU info for a valid instance type."""
    assert (
        hardware_detector._get_gpu_info_fallback(VALID_INSTANCE_TYPE, REGION)
        == EXPECTED_INSTANCE_GPU_INFO
    )
40+
41+
42+
def test_get_gpu_info_fallback_throws():
    """The config-based fallback raises ``ValueError`` for an invalid instance type."""
    with pytest.raises(ValueError):
        hardware_detector._get_gpu_info_fallback(INVALID_INSTANCE_TYPE, REGION)

tests/integ/test_model_monitor.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2488,3 +2488,85 @@ def test_one_time_monitoring_schedule(sagemaker_session):
24882488
my_default_monitor.stop_monitoring_schedule()
24892489
my_default_monitor.delete_monitoring_schedule()
24902490
raise e
2491+
2492+
2493+
def test_create_monitoring_schedule_with_validation_error(sagemaker_session):
    """A monitor can create a schedule after an earlier attempt failed validation.

    The first ``create_monitoring_schedule`` call uses a schedule name longer
    than 63 characters and must fail with a ``ValidationException``. The same
    monitor object is then reused to create a schedule with a valid name, which
    is stopped and deleted at the end of the test.
    """
    my_default_monitor = DefaultModelMonitor(
        role=ROLE,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        volume_size_in_gb=VOLUME_SIZE_IN_GB,
        max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
        sagemaker_session=sagemaker_session,
        env=ENVIRONMENT,
        tags=TAGS,
        network_config=NETWORK_CONFIG,
    )

    output_s3_uri = os.path.join(
        "s3://",
        sagemaker_session.default_bucket(),
        "integ-test-monitoring-output-bucket",
        str(uuid.uuid4()),
    )

    data_captured_destination_s3_uri = os.path.join(
        "s3://",
        sagemaker_session.default_bucket(),
        "sagemaker-serving-batch-transform",
        str(uuid.uuid4()),
    )

    batch_transform_input = BatchTransformInput(
        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
        destination="/opt/ml/processing/output",
        dataset_format=MonitoringDatasetFormat.csv(header=False),
    )

    statistics = Statistics.from_file_path(
        statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
        sagemaker_session=sagemaker_session,
    )

    constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
        sagemaker_session=sagemaker_session,
    )

    try:
        my_default_monitor.create_monitoring_schedule(
            monitor_schedule_name="schedule-name-more-than-63-characters-to-get-a-validation-exception",
            batch_transform_input=batch_transform_input,
            output_s3_uri=output_s3_uri,
            statistics=statistics,
            constraints=constraints,
            schedule_cron_expression=CronExpressionGenerator.now(),
            data_analysis_start_time="-PT1H",
            data_analysis_end_time="-PT0H",
            enable_cloudwatch_metrics=ENABLE_CLOUDWATCH_METRICS,
        )
    except Exception as e:
        assert "ValidationException" in str(e)
    else:
        # Without this branch the test would pass silently if the invalid
        # name were accepted, and the failure path would never be exercised.
        raise AssertionError(
            "Expected create_monitoring_schedule to fail with a ValidationException "
            "for a schedule name longer than 63 characters."
        )

    # Reuse the same monitor: a failed attempt must not block a later valid one.
    my_default_monitor.create_monitoring_schedule(
        monitor_schedule_name=unique_name_from_base("valid-schedule-name"),
        batch_transform_input=batch_transform_input,
        output_s3_uri=output_s3_uri,
        statistics=statistics,
        constraints=constraints,
        schedule_cron_expression=CronExpressionGenerator.now(),
        data_analysis_start_time="-PT1H",
        data_analysis_end_time="-PT0H",
        enable_cloudwatch_metrics=ENABLE_CLOUDWATCH_METRICS,
    )
    try:

        _wait_for_schedule_changes_to_apply(monitor=my_default_monitor)

        my_default_monitor.stop_monitoring_schedule()
        my_default_monitor.delete_monitoring_schedule()

    except Exception as e:
        # Best-effort cleanup of the schedule before surfacing the failure.
        my_default_monitor.stop_monitoring_schedule()
        my_default_monitor.delete_monitoring_schedule()
        raise e

0 commit comments

Comments
 (0)