Skip to content

Commit 3d08909

Browse files
authored
MultiPartCopy with Sync Algorithm (#4475)
* first pass at sync function with util classes * adding tests and update clases * linting * file generator class inheritance * lint * multipart copy and algorithm updates * modularize sync * reformatting folders * testing for sync * do not tolerate vulnerable * remove prints * handle multithreading progress bar * update tests * optimize function and add hub bucket prefix * docstrings and linting
1 parent 352a5c1 commit 3d08909

20 files changed

+1728
-24
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ def __init__(
101101
Default: None (no config).
102102
s3_client (Optional[boto3.client]): s3 client to use. Default: None.
103103
sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object,
104-
used for SageMaker interactions. Default: Session in region associated with boto3 session.
104+
used for SageMaker interactions. Default: Session in region associated with boto3
105+
session.
105106
"""
106107

107108
self._region = region
@@ -358,7 +359,9 @@ def _retrieval_function(
358359
hub_content_type=data_type
359360
)
360361

361-
model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True)
362+
model_specs = JumpStartModelSpecs(
363+
DescribeHubContentsResponse(hub_model_description), is_hub_content=True
364+
)
362365

363366
utils.emit_logs_based_on_model_specs(
364367
model_specs,
@@ -372,7 +375,9 @@ def _retrieval_function(
372375
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
373376
response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name)
374377
hub_description = DescribeHubResponse(response)
375-
return JumpStartCachedContentValue(formatted_content=DescribeHubResponse(hub_description))
378+
return JumpStartCachedContentValue(
379+
formatted_content=DescribeHubResponse(hub_description)
380+
)
376381
raise ValueError(
377382
f"Bad value for key '{key}': must be in ",
378383
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}"

src/sagemaker/jumpstart/curated_hub/accessors/__init__.py

Whitespace-only changes.
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains important utilities related to HubContent data files."""
14+
from __future__ import absolute_import
15+
from typing import Any, Dict, List
16+
17+
from botocore.client import BaseClient
18+
19+
from sagemaker.jumpstart.curated_hub.types import (
20+
FileInfo,
21+
HubContentDependencyType,
22+
S3ObjectLocation,
23+
)
24+
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor
25+
from sagemaker.jumpstart.types import JumpStartModelSpecs
26+
27+
28+
def generate_file_infos_from_s3_location(
    location: S3ObjectLocation, s3_client: BaseClient
) -> List[FileInfo]:
    """Lists objects from an S3 bucket and formats into FileInfo.

    The listing is followed through every continuation token, because a single
    ``list_objects_v2`` call only returns one page of keys and would otherwise
    silently drop the remainder of a large prefix.

    Args:
        location (S3ObjectLocation): Bucket and key prefix to list.
        s3_client (BaseClient): S3 client used to issue ``list_objects_v2``.

    Returns:
        List[FileInfo]: One ``FileInfo`` per object found under the prefix, or an
        empty list when the prefix contains no objects.
    """
    parameters = {"Bucket": location.bucket, "Prefix": location.key}

    files = []
    while True:
        response = s3_client.list_objects_v2(**parameters)

        # "Contents" is absent from the response when nothing matches the prefix.
        for s3_obj in response.get("Contents") or []:
            files.append(
                FileInfo(
                    location.bucket,
                    s3_obj.get("Key"),
                    s3_obj.get("Size"),
                    s3_obj.get("LastModified"),
                )
            )

        next_token = response.get("NextContinuationToken")
        # Stop on the last page; the token check also prevents re-requesting the
        # same page forever if a truncated response carries no token.
        if not response.get("IsTruncated") or not next_token:
            break
        parameters["ContinuationToken"] = next_token

    return files
49+
50+
51+
def generate_file_infos_from_model_specs(
    model_specs: JumpStartModelSpecs,
    studio_specs: Dict[str, Any],
    region: str,
    s3_client: BaseClient,
) -> List[FileInfo]:
    """Collects data locations from JumpStart public model specs and converts into `FileInfo`.

    Every ``HubContentDependencyType`` is resolved to an S3 location through a
    ``PublicModelDataAccessor``. A location whose key ends with ``/`` is treated
    as a prefix and listed (following continuation tokens so that no page of
    results is dropped); any other location is treated as a single object and
    inspected with ``head_object``.

    Args:
        model_specs (JumpStartModelSpecs): Public model specs, forwarded to the accessor.
        studio_specs (Dict[str, Any]): Studio specs, forwarded to the accessor.
        region (str): Region forwarded to the accessor.
        s3_client (BaseClient): S3 client used for ``list_objects_v2`` and ``head_object``.

    Returns:
        List[FileInfo]: `FileInfo` objects for the dependencies found in the public
        model specs. A prefix that contains no objects contributes no entries.
    """
    public_model_data_accessor = PublicModelDataAccessor(
        region=region, model_specs=model_specs, studio_specs=studio_specs
    )
    files = []
    for dependency in HubContentDependencyType:
        location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency)

        if location.key.endswith("/"):
            # Prefix: enumerate every object underneath it.
            parameters = {"Bucket": location.bucket, "Prefix": location.key}
            while True:
                response = s3_client.list_objects_v2(**parameters)

                # "Contents" is absent when the prefix is empty; iterating the raw
                # ``None`` would raise TypeError.
                for s3_obj in response.get("Contents") or []:
                    files.append(
                        FileInfo(
                            location.bucket,
                            s3_obj.get("Key"),
                            s3_obj.get("Size"),
                            s3_obj.get("LastModified"),
                            dependency,
                        )
                    )

                next_token = response.get("NextContinuationToken")
                # Stop on the last page, or if a truncated page carries no token.
                if not response.get("IsTruncated") or not next_token:
                    break
                parameters["ContinuationToken"] = next_token
        else:
            # Single object: read its size and timestamp without downloading it.
            parameters = {"Bucket": location.bucket, "Key": location.key}
            response = s3_client.head_object(**parameters)
            size = response.get("ContentLength")
            last_updated = response.get("LastModified")
            files.append(FileInfo(location.bucket, location.key, size, last_updated, dependency))
    return files
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module provides a class that perfrms functionalities similar to ``S3:Copy``."""
14+
from __future__ import absolute_import
15+
from typing import Optional
16+
17+
import boto3
18+
import botocore
19+
import tqdm
20+
21+
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
22+
from sagemaker.jumpstart.curated_hub.types import FileInfo
23+
from sagemaker.jumpstart.curated_hub.sync.request import HubSyncRequest
24+
25+
s3transfer = boto3.s3.transfer
26+
27+
28+
# pylint: disable=R1705,R1710
29+
def human_readable_size(value: int) -> str:
    """Convert a size in bytes into a human readable format.

    For example::

        >>> human_readable_size(1)
        '1 Byte'
        >>> human_readable_size(10)
        '10 Bytes'
        >>> human_readable_size(1024)
        '1.0 KiB'
        >>> human_readable_size(1024 * 1024)
        '1.0 MiB'

    Sizes too large for the biggest known unit are still reported in that
    unit (``EiB``) instead of returning ``None``.

    :param value: The size in bytes.
    :return: The size in a human readable format based on base-2 units.

    """
    base = 1024
    bytes_int = float(value)

    if bytes_int == 1:
        return "1 Byte"
    if bytes_int < base:
        return "%d Bytes" % bytes_int

    suffixes = ("KiB", "MiB", "GiB", "TiB", "PiB", "EiB")
    for i, suffix in enumerate(suffixes):
        unit = base ** (i + 2)
        # Compare the rounded value so that e.g. 1023.99 KiB rolls over to the
        # next unit rather than being shown as "1024.0 KiB".
        if round((bytes_int / unit) * base) < base:
            return "%.1f %s" % ((base * bytes_int / unit), suffix)

    # No unit was large enough: express the value in the largest one so the
    # function always honours its ``str`` return type.
    return "%.1f %s" % (bytes_int / base ** len(suffixes), suffixes[-1])
59+
60+
61+
class MultiPartCopyHandler(object):
    """Multi Part Copy Handler class.

    Copies every file listed in a ``HubSyncRequest`` into the request's
    destination location through an S3 transfer manager, reporting overall
    progress on a ``tqdm`` progress bar.
    """

    # Size of the S3 client connection pool and the transfer manager's max concurrency.
    WORKERS = 20
    # Multipart threshold and chunk size in bytes (8 MiB); value taken from S3:Copy.
    MULTIPART_CONFIG = 8 * (1024**2)

    def __init__(
        self,
        region: str,
        sync_request: HubSyncRequest,
        label: Optional[str] = None,
        thread_num: Optional[int] = 0,
    ):
        """Multi-part S3:Copy Handler initializer.

        Args:
            region (str): Region for the S3 Client
            sync_request (HubSyncRequest): sync request object containing
                information required to perform the copy
            label (Optional[str]): Description shown next to the progress bar.
                Default: None.
            thread_num (Optional[int]): Position of the progress bar, so that
                several handlers can render their bars on separate rows. Default: 0.
        """
        self.label = label
        self.region = region
        self.files = sync_request.files
        self.dest_location = sync_request.destination
        self.thread_num = thread_num

        config = botocore.config.Config(max_pool_connections=self.WORKERS)
        self.s3_client = boto3.client("s3", region_name=self.region, config=config)
        # NOTE(review): ``max_bandwidth`` is documented as an integer number of
        # bytes per second; passing ``True`` looks unintended. Kept unchanged
        # here -- confirm before removing.
        transfer_config = s3transfer.TransferConfig(
            multipart_threshold=self.MULTIPART_CONFIG,
            multipart_chunksize=self.MULTIPART_CONFIG,
            max_bandwidth=True,
            use_threads=True,
            max_concurrency=self.WORKERS,
        )
        self.transfer_manager = s3transfer.create_transfer_manager(
            client=self.s3_client, config=transfer_config
        )

    def _copy_file(self, file: FileInfo, progress_cb):
        """Performs the actual MultiPart S3:Copy of the object.

        Args:
            file (FileInfo): Source file; its bucket and key form the copy source,
                and its key is appended to the destination key.
            progress_cb: Callable handed to the transfer as a progress subscriber.
        """
        copy_source = {"Bucket": file.location.bucket, "Key": file.location.key}
        result = self.transfer_manager.copy(
            bucket=self.dest_location.bucket,
            key=f"{self.dest_location.key}/{file.location.key}",
            copy_source=copy_source,
            subscribers=[
                s3transfer.ProgressCallbackInvoker(progress_cb),
            ],
        )
        # Attempt to access result to throw error if exists. Silently calls if successful.
        result.result()

    def execute(self):
        """Executes the MultiPart S3:Copy on the class.

        Sets up progress bar and kicks off each copy request. The transfer
        manager is shut down and the progress bar closed even when a copy
        raises, so a failed sync does not leave either of them open.
        """
        total_size = sum(file.size for file in self.files)
        JUMPSTART_LOGGER.warning(
            "Copying %s files (%s) into %s/%s",
            len(self.files),
            human_readable_size(total_size),
            self.dest_location.bucket,
            self.dest_location.key,
        )

        progress = tqdm.tqdm(
            desc=self.label,
            total=total_size,
            unit="B",
            unit_scale=1,
            position=self.thread_num,
            bar_format="{desc:<10}{percentage:3.0f}%|{bar:10}{r_bar}",
        )

        try:
            for file in self.files:
                self._copy_file(file, progress.update)
        finally:
            # Call `shutdown` to wait for copy results
            self.transfer_manager.shutdown()
            progress.close()
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module accessors for the SageMaker JumpStart Public Hub."""
14+
from __future__ import absolute_import
15+
from typing import Dict, Any
16+
from sagemaker import model_uris, script_uris
17+
from sagemaker.jumpstart.curated_hub.types import (
18+
HubContentDependencyType,
19+
S3ObjectLocation,
20+
)
21+
from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri
22+
from sagemaker.jumpstart.enums import JumpStartScriptScope
23+
from sagemaker.jumpstart.types import JumpStartModelSpecs
24+
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
25+
26+
27+
class PublicModelDataAccessor:
    """Accessor class for JumpStart model data s3 locations."""

    def __init__(
        self,
        region: str,
        model_specs: JumpStartModelSpecs,
        studio_specs: Dict[str, Dict[str, Any]],
    ):
        """Stores the region, specs and the JumpStart content bucket for that region."""
        self._region = region
        self._bucket = get_jumpstart_content_bucket(region)
        self.model_specs = model_specs
        self.studio_specs = studio_specs  # Necessary for SDK - Studio metadata drift

    def get_s3_reference(self, dependency_type: HubContentDependencyType):
        """Looks up the S3 reference for a HubContentDependencyType.

        The enum member's value is used as the name of the attribute (property)
        on this accessor that produces the reference.
        """
        attribute_name = dependency_type.value
        return getattr(self, attribute_name)

    @property
    def inference_artifact_s3_reference(self):
        """S3 reference of the model inference artifact."""
        artifact_uri = self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE)
        return create_s3_object_reference_from_uri(artifact_uri)

    @property
    def training_artifact_s3_reference(self):
        """S3 reference of the model training artifact."""
        artifact_uri = self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING)
        return create_s3_object_reference_from_uri(artifact_uri)

    @property
    def inference_script_s3_reference(self):
        """S3 reference of the model inference script."""
        script_uri = self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE)
        return create_s3_object_reference_from_uri(script_uri)

    @property
    def training_script_s3_reference(self):
        """S3 reference of the model training script."""
        script_uri = self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING)
        return create_s3_object_reference_from_uri(script_uri)

    @property
    def default_training_dataset_s3_reference(self):
        """S3 reference of the s3 directory containing model training datasets."""
        bucket = self._get_bucket_name()
        prefix = self.__get_training_dataset_prefix()
        return S3ObjectLocation(bucket, prefix)

    @property
    def demo_notebook_s3_reference(self):
        """S3 reference of the model demo jupyter notebook."""
        bucket = self._get_bucket_name()
        framework = self.model_specs.get_framework()
        return S3ObjectLocation(
            bucket, f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb"
        )

    @property
    def markdown_s3_reference(self):
        """S3 reference of the model markdown."""
        bucket = self._get_bucket_name()
        framework = self.model_specs.get_framework()
        return S3ObjectLocation(bucket, f"{framework}-metadata/{self.model_specs.model_id}.md")

    def _get_bucket_name(self) -> str:
        """Returns the s3 bucket resolved at construction time."""
        return self._bucket

    def __get_training_dataset_prefix(self) -> str:
        """Returns the training dataset location stored under ``defaultDataKey``."""
        return self.studio_specs["defaultDataKey"]

    def _jumpstart_script_s3_uri(self, model_scope: str) -> str:
        """Returns the JumpStart script s3 location for the given scope."""
        specs = self.model_specs
        return script_uris.retrieve(
            region=self._region,
            model_id=specs.model_id,
            model_version=specs.version,
            script_scope=model_scope,
        )

    def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str:
        """Returns the JumpStart artifact s3 location for the given scope."""
        specs = self.model_specs
        return model_uris.retrieve(
            region=self._region,
            model_id=specs.model_id,
            model_version=specs.version,
            model_scope=model_scope,
        )

0 commit comments

Comments
 (0)