-
Notifications
You must be signed in to change notification settings - Fork 1.2k
MultiPartCopy with Sync Algorithm #4475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
344d26b
374c638
67d8ec8
2fa0503
30c2b91
ef57f14
c44acd2
297d1b6
18a5728
4554d34
c7f3f96
28c9186
97001cc
27240a9
ce73f62
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains important utilities related to HubContent data files.""" | ||
from __future__ import absolute_import | ||
from typing import Any, Dict, List | ||
|
||
from botocore.client import BaseClient | ||
|
||
from sagemaker.jumpstart.curated_hub.types import ( | ||
FileInfo, | ||
HubContentDependencyType, | ||
S3ObjectLocation, | ||
) | ||
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor | ||
from sagemaker.jumpstart.types import JumpStartModelSpecs | ||
|
||
|
||
def generate_file_infos_from_s3_location(
    location: S3ObjectLocation, s3_client: BaseClient
) -> List[FileInfo]:
    """Lists objects from an S3 bucket and formats into FileInfo.

    Uses a paginator so that prefixes containing more than 1000 objects
    (the single-call ``list_objects_v2`` limit) are fully enumerated.

    Args:
        location (S3ObjectLocation): bucket and key prefix to list.
        s3_client (BaseClient): boto3 S3 client used for the listing.

    Returns:
        List[FileInfo]: one entry per object under the prefix; empty list
        if the prefix matches no objects.
    """
    files = []
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=location.bucket, Prefix=location.key):
        # ``Contents`` is absent from the page when there are no matches.
        for s3_obj in page.get("Contents", []):
            files.append(
                FileInfo(
                    location.bucket,
                    s3_obj.get("Key"),
                    s3_obj.get("Size"),
                    s3_obj.get("LastModified"),
                )
            )
    return files
|
||
|
||
def generate_file_infos_from_model_specs(
    model_specs: JumpStartModelSpecs,
    studio_specs: Dict[str, Any],
    region: str,
    s3_client: BaseClient,
) -> List[FileInfo]:
    """Collects data locations from JumpStart public model specs and converts into `FileInfo`.

    Args:
        model_specs (JumpStartModelSpecs): public model specs to read dependencies from.
        studio_specs (Dict[str, Any]): Studio metadata specs for the model.
        region (str): AWS region used to resolve the public data bucket.
        s3_client (BaseClient): boto3 S3 client used for listing/HEAD calls.

    Returns:
        List[FileInfo]: `FileInfo` objects from dependencies found in the public
        model specs.
    """
    public_model_data_accessor = PublicModelDataAccessor(
        region=region, model_specs=model_specs, studio_specs=studio_specs
    )
    files = []
    for dependency in HubContentDependencyType:
        location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency)
        # NOTE(review): a trailing "/" is assumed to reliably distinguish a prefix
        # from a single object; a failed HEAD would ideally be handled too.
        if location.key.endswith("/"):
            files.extend(_file_infos_for_prefix(s3_client, location, dependency))
        else:
            files.append(_file_info_for_object(s3_client, location, dependency))
    return files


def _file_infos_for_prefix(
    s3_client: BaseClient,
    location: S3ObjectLocation,
    dependency: HubContentDependencyType,
) -> List[FileInfo]:
    """Lists every object under ``location`` and tags each ``FileInfo`` with ``dependency``."""
    response = s3_client.list_objects_v2(Bucket=location.bucket, Prefix=location.key)
    # ``Contents`` is absent when the prefix matches no objects; guard against None.
    contents = response.get("Contents") or []
    return [
        FileInfo(
            location.bucket,
            s3_obj.get("Key"),
            s3_obj.get("Size"),
            s3_obj.get("LastModified"),
            dependency,
        )
        for s3_obj in contents
    ]


def _file_info_for_object(
    s3_client: BaseClient,
    location: S3ObjectLocation,
    dependency: HubContentDependencyType,
) -> FileInfo:
    """Builds a ``FileInfo`` for a single object using ``head_object`` metadata."""
    response = s3_client.head_object(Bucket=location.bucket, Key=location.key)
    return FileInfo(
        location.bucket,
        location.key,
        response.get("ContentLength"),
        response.get("LastModified"),
        dependency,
    )
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module provides a class that perfrms functionalities similar to ``S3:Copy``.""" | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from __future__ import absolute_import | ||
from typing import Optional | ||
|
||
import boto3 | ||
import botocore | ||
import tqdm | ||
|
||
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER | ||
from sagemaker.jumpstart.curated_hub.types import FileInfo | ||
from sagemaker.jumpstart.curated_hub.sync.request import HubSyncRequest | ||
|
||
s3transfer = boto3.s3.transfer | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
# pylint: disable=R1705,R1710 | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def human_readable_size(value: int) -> str:
    """Convert a size in bytes into a human readable format.

    For example::

        >>> human_readable_size(1)
        '1 Byte'
        >>> human_readable_size(10)
        '10 Bytes'
        >>> human_readable_size(1024)
        '1.0 KiB'
        >>> human_readable_size(1024 * 1024)
        '1.0 MiB'

    :param value: The size in bytes.
    :return: The size in a human readable format based on base-2 units.

    """
    base = 1024
    bytes_float = float(value)

    if bytes_float == 1:
        return "1 Byte"
    if bytes_float < base:
        return "%d Bytes" % bytes_float

    for i, suffix in enumerate(("KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB")):
        unit = base ** (i + 2)
        # Stop at the largest unit for which the scaled value stays below 1024.
        if round((bytes_float / unit) * base) < base:
            return "%.1f %s" % ((base * bytes_float / unit), suffix)

    # Anything larger is expressed in the largest supported unit so the
    # function always returns a string (the original fell off the loop
    # and implicitly returned None for values >= ~1 YiB).
    return "%.1f YiB" % (bytes_float / base**8)
|
||
|
||
class MultiPartCopyHandler(object):
    """Multi Part Copy Handler class.

    Copies a set of S3 objects (from a ``HubSyncRequest``) into a destination
    bucket/prefix using the s3transfer TransferManager, with a tqdm progress bar.
    """

    # Number of connections/threads used by the transfer manager.
    WORKERS = 20
    # Part-size threshold and chunk size (8 MiB). Config values from in S3:Copy.
    MULTIPART_CONFIG = 8 * (1024**2)

    def __init__(
        self,
        region: str,
        sync_request: HubSyncRequest,
        label: Optional[str] = None,
        thread_num: Optional[int] = 0,
    ):
        """Multi-part S3:Copy Handler initializer.

        Args:
            region (str): Region for the S3 Client
            sync_request (HubSyncRequest): sync request object containing
                the files to copy and the destination location
            label (Optional[str]): label displayed on the progress bar
            thread_num (Optional[int]): position of the progress bar, used
                when multiple handlers render bars concurrently
        """
        self.label = label
        self.region = region
        self.files = sync_request.files
        self.dest_location = sync_request.destination
        self.thread_num = thread_num

        config = botocore.config.Config(max_pool_connections=self.WORKERS)
        self.s3_client = boto3.client("s3", region_name=self.region, config=config)
        transfer_config = s3transfer.TransferConfig(
            multipart_threshold=self.MULTIPART_CONFIG,
            multipart_chunksize=self.MULTIPART_CONFIG,
            # ``max_bandwidth`` expects bytes/sec (int) or None for unlimited.
            # The previous value ``True`` (== 1) throttled copies to 1 B/s.
            max_bandwidth=None,
            use_threads=True,
            max_concurrency=self.WORKERS,
        )
        self.transfer_manager = s3transfer.create_transfer_manager(
            client=self.s3_client, config=transfer_config
        )

    def _copy_file(self, file: FileInfo, progress_cb):
        """Performs the actual MultiPart S3:Copy of the object.

        Args:
            file (FileInfo): source object to copy.
            progress_cb: callable invoked with the number of bytes
                transferred; used to advance the progress bar.
        """
        copy_source = {"Bucket": file.location.bucket, "Key": file.location.key}
        result = self.transfer_manager.copy(
            bucket=self.dest_location.bucket,
            key=f"{self.dest_location.key}/{file.location.key}",
            copy_source=copy_source,
            subscribers=[
                s3transfer.ProgressCallbackInvoker(progress_cb),
            ],
        )
        # Attempt to access result to throw error if exists. Silently calls if successful.
        result.result()

    def execute(self):
        """Executes the MultiPart S3:Copy on the class.

        Sets up progress bar and kicks off each copy request.
        """
        total_size = sum([file.size for file in self.files])
        JUMPSTART_LOGGER.warning(
            "Copying %s files (%s) into %s/%s",
            len(self.files),
            human_readable_size(total_size),
            self.dest_location.bucket,
            self.dest_location.key,
        )

        progress = tqdm.tqdm(
            desc=self.label,
            total=total_size,
            unit="B",
            unit_scale=1,
            position=self.thread_num,
            bar_format="{desc:<10}{percentage:3.0f}%|{bar:10}{r_bar}",
        )

        for file in self.files:
            self._copy_file(file, progress.update)

        # Call `shutdown` to wait for copy results
        self.transfer_manager.shutdown()
        progress.close()
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module accessors for the SageMaker JumpStart Public Hub.""" | ||
from __future__ import absolute_import | ||
from typing import Dict, Any | ||
from sagemaker import model_uris, script_uris | ||
from sagemaker.jumpstart.curated_hub.types import ( | ||
HubContentDependencyType, | ||
S3ObjectLocation, | ||
) | ||
from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri | ||
from sagemaker.jumpstart.enums import JumpStartScriptScope | ||
from sagemaker.jumpstart.types import JumpStartModelSpecs | ||
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket | ||
|
||
|
||
class PublicModelDataAccessor:
    """Accessor class for JumpStart model data s3 locations."""

    def __init__(
        self,
        region: str,
        model_specs: JumpStartModelSpecs,
        studio_specs: Dict[str, Dict[str, Any]],
    ):
        """Initializes the accessor.

        Args:
            region (str): AWS region used to resolve the JumpStart content bucket.
            model_specs (JumpStartModelSpecs): public model specs.
            studio_specs (Dict[str, Dict[str, Any]]): Studio metadata specs.
        """
        self._region = region
        self._bucket = get_jumpstart_content_bucket(region)
        self.model_specs = model_specs
        self.studio_specs = studio_specs  # Necessary for SDK - Studio metadata drift

    def get_s3_reference(self, dependency_type: HubContentDependencyType) -> S3ObjectLocation:
        """Retrieves S3 reference given a HubContentDependencyType."""
        # Dependency enum values map 1:1 onto the property names below.
        return getattr(self, dependency_type.value)

    @property
    def inference_artifact_s3_reference(self) -> S3ObjectLocation:
        """Retrieves s3 reference for model inference artifact"""
        return create_s3_object_reference_from_uri(
            self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE)
        )

    @property
    def training_artifact_s3_reference(self) -> S3ObjectLocation:
        """Retrieves s3 reference for model training artifact"""
        return create_s3_object_reference_from_uri(
            self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING)
        )

    @property
    def inference_script_s3_reference(self) -> S3ObjectLocation:
        """Retrieves s3 reference for model inference script"""
        return create_s3_object_reference_from_uri(
            self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE)
        )

    @property
    def training_script_s3_reference(self) -> S3ObjectLocation:
        """Retrieves s3 reference for model training script"""
        return create_s3_object_reference_from_uri(
            self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING)
        )

    @property
    def default_training_dataset_s3_reference(self) -> S3ObjectLocation:
        """Retrieves s3 reference for s3 directory containing model training datasets"""
        return S3ObjectLocation(self._get_bucket_name(), self._get_training_dataset_prefix())

    @property
    def demo_notebook_s3_reference(self) -> S3ObjectLocation:
        """Retrieves s3 reference for model demo jupyter notebook"""
        # NOTE(review): the key layout below is hard-coded; it will break if the
        # public bucket's file organization changes.
        framework = self.model_specs.get_framework()
        key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb"
        return S3ObjectLocation(self._get_bucket_name(), key)

    @property
    def markdown_s3_reference(self) -> S3ObjectLocation:
        """Retrieves s3 reference for model markdown"""
        framework = self.model_specs.get_framework()
        key = f"{framework}-metadata/{self.model_specs.model_id}.md"
        return S3ObjectLocation(self._get_bucket_name(), key)

    def _get_bucket_name(self) -> str:
        """Retrieves s3 bucket"""
        return self._bucket

    def _get_training_dataset_prefix(self) -> str:
        """Retrieves training dataset location.

        Single leading underscore for consistency with the other private
        helpers (was previously name-mangled with a double underscore).
        """
        return self.studio_specs["defaultDataKey"]

    def _jumpstart_script_s3_uri(self, model_scope: str) -> str:
        """Retrieves JumpStart script s3 location"""
        return script_uris.retrieve(
            region=self._region,
            model_id=self.model_specs.model_id,
            model_version=self.model_specs.version,
            script_scope=model_scope,
        )

    def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str:
        """Retrieves JumpStart artifact s3 location"""
        return model_uris.retrieve(
            region=self._region,
            model_id=self.model_specs.model_id,
            model_version=self.model_specs.version,
            model_scope=model_scope,
        )
Uh oh!
There was an error while loading. Please reload this page.