-
Notifications
You must be signed in to change notification settings - Fork 1.2k
MultiPartCopy with Sync Algorithm #4475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
liujiaorr
merged 15 commits into
aws:master-jumpstart-curated-hub
from
bencrabtree:feat/jsch-sync
Mar 12, 2024
Merged
Changes from 7 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
344d26b
first pass at sync function with util classes
bencrabtree 374c638
adding tests and update clases
bencrabtree 67d8ec8
linting
bencrabtree 2fa0503
file generator class inheritance
bencrabtree 30c2b91
lint
bencrabtree ef57f14
multipart copy and algorithm updates
bencrabtree c44acd2
modularize sync
bencrabtree 297d1b6
reformatting folders
bencrabtree 18a5728
testing for sync
bencrabtree 4554d34
do not tolerate vulnerable
bencrabtree c7f3f96
remove prints
bencrabtree 28c9186
handle multithreading progress bar
bencrabtree 97001cc
update tests
bencrabtree 27240a9
optimize function and add hub bucket prefix
bencrabtree ce73f62
docstrings and linting
bencrabtree File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
95 changes: 95 additions & 0 deletions
95
src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains important utilities related to HubContent data files.""" | ||
from __future__ import absolute_import | ||
from typing import Any, Dict, List | ||
|
||
from datetime import datetime | ||
from botocore.client import BaseClient | ||
|
||
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo, HubContentDependencyType | ||
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation | ||
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor | ||
from sagemaker.jumpstart.types import JumpStartModelSpecs | ||
|
||
|
||
def generate_file_infos_from_s3_location(
    location: S3ObjectLocation, s3_client: BaseClient
) -> List[FileInfo]:
    """Lists objects from an S3 bucket and formats into FileInfo.

    Every object in ``location.bucket`` whose key starts with ``location.key``
    is listed through ``list_objects_v2``. A truncated listing is followed with
    its continuation token so that the result is not limited to the first page.

    Args:
        location: Bucket and key prefix to list.
        s3_client: Client on which ``list_objects_v2`` is called.

    Returns:
        A list of ``FileInfo`` objects from the specified bucket location,
        or an empty list when nothing is listed under the prefix.
    """
    parameters = {"Bucket": location.bucket, "Prefix": location.key}
    files = []

    while True:
        response = s3_client.list_objects_v2(**parameters)

        # "Contents" is missing from the response when no key matches the prefix.
        for s3_obj in response.get("Contents") or []:
            files.append(
                FileInfo(
                    location.bucket,
                    s3_obj.get("Key"),
                    s3_obj.get("Size", None),
                    s3_obj.get("LastModified", None),
                )
            )

        # Keep requesting pages until the listing is no longer truncated.
        next_token = response.get("NextContinuationToken")
        if not response.get("IsTruncated") or not next_token:
            return files
        parameters["ContinuationToken"] = next_token
|
||
|
||
def generate_file_infos_from_model_specs(
    model_specs: JumpStartModelSpecs,
    studio_specs: Dict[str, Any],
    region: str,
    s3_client: BaseClient,
) -> List[FileInfo]:
    """Collects data locations from JumpStart public model specs and converts into `FileInfo`.

    For every ``HubContentDependencyType`` the S3 reference is resolved through
    ``PublicModelDataAccessor.get_s3_reference``. A reference whose key ends
    with ``/`` is treated as a prefix and expanded with ``list_objects_v2``
    (following continuation tokens); any other reference is treated as a single
    object and described with ``head_object``.

    Args:
        model_specs: Public JumpStart model specs handed to the accessor.
        studio_specs: Studio specs handed to the accessor.
        region: Region handed to the accessor.
        s3_client: Client used for the ``list_objects_v2`` / ``head_object`` calls.

    Returns:
        A list of `FileInfo` objects from dependencies found in the public
        model specs, each tagged with the dependency type it was found for.
    """
    public_model_data_accessor = PublicModelDataAccessor(
        region=region, model_specs=model_specs, studio_specs=studio_specs
    )
    files = []
    for dependency in HubContentDependencyType:
        location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency)

        if location.key.endswith("/"):
            # Prefix: one FileInfo per object listed underneath it.
            parameters = {"Bucket": location.bucket, "Prefix": location.key}
            while True:
                response = s3_client.list_objects_v2(**parameters)

                # "Contents" is missing when the prefix matches nothing; iterate
                # an empty list instead of failing on None.
                for s3_obj in response.get("Contents") or []:
                    files.append(
                        FileInfo(
                            location.bucket,
                            s3_obj.get("Key"),
                            s3_obj.get("Size", None),
                            s3_obj.get("LastModified", None),
                            dependency,
                        )
                    )

                # Keep requesting pages until the listing is no longer truncated.
                next_token = response.get("NextContinuationToken")
                if not response.get("IsTruncated") or not next_token:
                    break
                parameters["ContinuationToken"] = next_token
        else:
            # Single object: read its size and timestamp from the object metadata.
            parameters = {"Bucket": location.bucket, "Key": location.key}
            response = s3_client.head_object(**parameters)
            files.append(
                FileInfo(
                    location.bucket,
                    location.key,
                    response.get("ContentLength", None),
                    response.get("LastModified", None),
                    dependency,
                )
            )
    return files
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains important details related to HubContent data files.""" | ||
from __future__ import absolute_import | ||
|
||
from enum import Enum | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
from datetime import datetime | ||
|
||
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation | ||
|
||
|
||
class HubContentDependencyType(str, Enum):
    """Enum class for HubContent dependency names.

    Members are also ``str`` instances, so each one compares equal to its
    string value.
    """

    # NOTE(review): the values look like the names of the S3-reference entries
    # that the public model data accessor resolves -- confirm against
    # PublicModelDataAccessor.get_s3_reference.
    INFERENCE_ARTIFACT = "inference_artifact_s3_reference"
    TRAINING_ARTIFACT = "training_artifact_s3_reference"
    INFERENCE_SCRIPT = "inference_script_s3_reference"
    TRAINING_SCRIPT = "training_script_s3_reference"
    DEFAULT_TRAINING_DATASET = "default_training_dataset_s3_reference"
    DEMO_NOTEBOOK = "demo_notebook_s3_reference"
    MARKDOWN = "markdown_s3_reference"
|
||
|
||
@dataclass
class FileInfo:
    """Data class for additional S3 file info.

    All four attributes are declared as dataclass fields so that the generated
    ``__repr__`` and ``__eq__`` take the size, timestamp and dependency type
    into account, not only the location. The hand-written ``__init__`` is kept
    because callers pass ``bucket`` and ``key`` rather than a ready-made
    ``S3ObjectLocation``.
    """

    # Bucket/key pair built from the ``bucket`` and ``key`` constructor arguments.
    location: S3ObjectLocation
    # Object size; None when the caller could not determine it.
    size: Optional[int]
    # Last-modified timestamp; None when the caller could not determine it.
    last_updated: Optional[datetime]
    # NOTE: "dependecy" is misspelled, but the name is part of the public
    # interface (keyword argument and attribute), so it is kept as is.
    dependecy_type: Optional[HubContentDependencyType] = None

    def __init__(
        self,
        bucket: str,
        key: str,
        size: Optional[int],
        last_updated: Optional[datetime],
        dependecy_type: Optional[HubContentDependencyType] = None,
    ):
        """Builds the file info from its bucket, key and object metadata.

        Args:
            bucket: Name of the bucket holding the object.
            key: Key of the object inside ``bucket``.
            size: Object size, or None if unknown.
            last_updated: Last-modified timestamp, or None if unknown.
            dependecy_type: Dependency this file belongs to, if any.
        """
        self.location = S3ObjectLocation(bucket, key)
        self.size = size
        self.last_updated = last_updated
        self.dependecy_type = dependecy_type
128 changes: 128 additions & 0 deletions
128
src/sagemaker/jumpstart/curated_hub/accessors/multipartcopy.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module provides a class that performs functionalities similar to ``S3:Copy``.""" | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from __future__ import absolute_import | ||
from typing import List | ||
|
||
import boto3 | ||
import botocore | ||
import tqdm | ||
|
||
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER | ||
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo | ||
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation | ||
|
||
# Alias for boto3's S3 transfer module; the handler below takes TransferConfig,
# create_transfer_manager and ProgressCallbackInvoker from it.
s3transfer = boto3.s3.transfer
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
# pylint: disable=R1705,R1710 | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def human_readable_size(value: int) -> str:
    """Convert a size in bytes into a human readable format.

    For example::

        >>> human_readable_size(1)
        '1 Byte'
        >>> human_readable_size(10)
        '10 Bytes'
        >>> human_readable_size(1024)
        '1.0 KiB'
        >>> human_readable_size(1024 * 1024)
        '1.0 MiB'
        >>> human_readable_size(1024 ** 7)
        '1024.0 EiB'

    :param value: The size in bytes.
    :return: The size in a human readable format based on base-2 units.
        Sizes of 1024 EiB and above are still expressed in EiB, the largest
        supported unit, so a string is always returned.

    """
    base = 1024
    bytes_int = float(value)

    if bytes_int == 1:
        return "1 Byte"
    if bytes_int < base:
        return f"{int(bytes_int)} Bytes"

    suffixes = ("KiB", "MiB", "GiB", "TiB", "PiB", "EiB")
    for i, suffix in enumerate(suffixes):
        unit = base ** (i + 2)
        # Pick the first unit in which the rounded quantity stays below 1024.
        if round((bytes_int / unit) * base) < base:
            return f"{base * bytes_int / unit:.1f} {suffix}"

    # Larger than any listed unit can keep below 1024: report it in the largest
    # unit instead of falling off the end and returning None.
    return f"{base * bytes_int / unit:.1f} {suffixes[-1]}"
|
||
|
||
class MultiPartCopyHandler(object):
    """Copies a list of S3 files into a destination location.

    Copies are submitted to a transfer manager created from the boto3 S3
    transfer module and configured with ``WORKERS`` and ``MULTIPART_CONFIG``.
    A ``tqdm`` progress bar reports the bytes copied.
    """

    # Used both as the S3 client connection pool size and as the transfer
    # manager's max_concurrency.
    WORKERS = 20
    # 8 MiB: used as both the multipart threshold and the multipart chunk size.
    MULTIPART_CONFIG = 8 * (1024**2)

    def __init__(
        self,
        region: str,
        files: List[FileInfo],
        dest_location: S3ObjectLocation,
    ):
        """Creates the S3 client and transfer manager used for the copies.

        Args:
            region: Region the S3 client is created in.
            files: Files to copy; each is read for its location and size.
            dest_location: Destination bucket and key prefix.
        """
        self.region = region
        self.files = files
        self.dest_location = dest_location

        config = botocore.config.Config(max_pool_connections=self.WORKERS)
        self.s3_client = boto3.client("s3", region_name=self.region, config=config)
        transfer_config = s3transfer.TransferConfig(
            multipart_threshold=self.MULTIPART_CONFIG,
            multipart_chunksize=self.MULTIPART_CONFIG,
            # max_bandwidth is a bytes-per-second limit; ``True`` would be read
            # as 1 byte/s. ``None`` means no limit.
            max_bandwidth=None,
            use_threads=True,
            max_concurrency=self.WORKERS,
        )
        self.transfer_manager = s3transfer.create_transfer_manager(
            client=self.s3_client, config=transfer_config
        )

    def _copy_file(self, file: FileInfo, update_fn):
        """Copies one file under the destination prefix and waits for it to finish.

        Args:
            file: File to copy; its bucket and key form the copy source.
            update_fn: Progress callback wrapped in a ``ProgressCallbackInvoker``
                subscriber for this copy.
        """
        copy_source = {"Bucket": file.location.bucket, "Key": file.location.key}
        # NOTE(review): the destination key is always "<dest key>/<source key>",
        # so an empty dest key or one ending in "/" yields a leading or doubled
        # slash -- confirm callers never pass such a key.
        result = self.transfer_manager.copy(
            bucket=self.dest_location.bucket,
            key=f"{self.dest_location.key}/{file.location.key}",
            copy_source=copy_source,
            subscribers=[
                s3transfer.ProgressCallbackInvoker(update_fn),
            ],
        )
        # Block until this copy completes; a failed copy raises here.
        result.result()

    def execute(self):
        """Copies every file in ``self.files`` into the destination location.

        Logs the number of files and their total size, then copies the files
        one after another while updating a progress bar. The transfer manager
        is shut down and the progress bar closed even when a copy fails.
        """
        # A FileInfo size may be None; count unknown sizes as 0 so the total
        # can still be computed.
        total_size = sum(file.size or 0 for file in self.files)
        JUMPSTART_LOGGER.warning(
            "Copying %s files (%s) into %s/%s",
            len(self.files),
            human_readable_size(total_size),
            self.dest_location.bucket,
            self.dest_location.key,
        )

        progress = tqdm.tqdm(
            desc="JumpStart Sync",
            total=total_size,
            unit="B",
            unit_scale=1,
            position=0,
            bar_format="{desc:<10}{percentage:3.0f}%|{bar:10}{r_bar}",
        )

        try:
            for file in self.files:
                self._copy_file(file, progress.update)
        finally:
            # Release the transfer manager and the progress bar on failure too.
            self.transfer_manager.shutdown()
            progress.close()
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
47 changes: 47 additions & 0 deletions
47
src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module provides utilities to assist S3 client calls for the Curated Hub.""" | ||
from __future__ import absolute_import | ||
from dataclasses import dataclass | ||
from typing import Dict | ||
|
||
from sagemaker.s3_utils import parse_s3_url | ||
|
||
|
||
@dataclass
class S3ObjectLocation:
    """Bucket and key pair that points at an S3 object."""

    bucket: str
    key: str

    def format_for_s3_copy(self) -> Dict[str, str]:
        """Builds the ``Bucket``/``Key`` dict describing this location for S3 copy calls."""
        copy_source: Dict[str, str] = {}
        copy_source["Bucket"] = self.bucket
        copy_source["Key"] = self.key
        return copy_source

    def get_uri(self) -> str:
        """Renders this location as an ``s3://<bucket>/<key>`` URI."""
        uri = f"s3://{self.bucket}/{self.key}"
        return uri
|
||
|
||
def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation:
    """Utility that turns an S3 URI into an ``S3ObjectLocation``.

    The URI is split with ``parse_s3_url`` and the two resulting parts become
    the location's bucket and key.
    """
    parsed_bucket, parsed_key = parse_s3_url(s3_uri)
    return S3ObjectLocation(bucket=parsed_bucket, key=parsed_key)
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.