-
Notifications
You must be signed in to change notification settings - Fork 1.2k
MultiPartCopy with Sync Algorithm #4475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
344d26b
374c638
67d8ec8
2fa0503
30c2b91
ef57f14
c44acd2
297d1b6
18a5728
4554d34
c7f3f96
28c9186
97001cc
27240a9
ce73f62
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains important utilities related to HubContent data files.""" | ||
from __future__ import absolute_import | ||
from typing import Any, Dict, List, Optional | ||
|
||
from datetime import datetime | ||
from botocore.client import BaseClient | ||
|
||
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo, HubContentDependencyType | ||
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation | ||
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor | ||
from sagemaker.jumpstart.types import JumpStartModelSpecs | ||
|
||
|
||
class FileGenerator: | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Utility class to help format HubContent data files.""" | ||
|
||
def __init__( | ||
self, region: str, s3_client: BaseClient, studio_specs: Optional[Dict[str, Any]] = None | ||
): | ||
self.region = region | ||
self.s3_client = s3_client | ||
self.studio_specs = studio_specs | ||
|
||
def format(self, file_input) -> List[FileInfo]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need this? |
||
"""Dispatch method that is implemented in below registered functions.""" | ||
raise NotImplementedError | ||
|
||
|
||
class S3PathFileGenerator(FileGenerator): | ||
"""Utility class to help format all objects in an S3 bucket.""" | ||
|
||
def format(self, file_input: S3ObjectLocation) -> List[FileInfo]: | ||
"""Retrieves data from an S3 bucket and formats into FileInfo. | ||
|
||
Returns a list of ``FileInfo`` objects from the specified bucket location. | ||
""" | ||
parameters = {"Bucket": file_input.bucket, "Prefix": file_input.key} | ||
response = self.s3_client.list_objects_v2(**parameters) | ||
contents = response.get("Contents", None) | ||
|
||
if not contents: | ||
print("Nothing to download") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we are fine with regular print statements? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Placeholder for now, I'll double check if we want to use logger (prob the case) |
||
return [] | ||
|
||
files = [] | ||
for s3_obj in contents: | ||
key: str = s3_obj.get("Key") | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
size: bytes = s3_obj.get("Size", None) | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
last_modified: str = s3_obj.get("LastModified", None) | ||
files.append(FileInfo(file_input.bucket, key, size, last_modified)) | ||
return files | ||
|
||
|
||
class ModelSpecsFileGenerator(FileGenerator): | ||
"""Utility class to help format all data paths from JumpStart public model specs.""" | ||
|
||
def format(self, file_input: JumpStartModelSpecs) -> List[FileInfo]: | ||
"""Collects data locations from JumpStart public model specs and converts into FileInfo`. | ||
|
||
Returns a list of ``FileInfo`` objects from dependencies found in the public | ||
model specs. | ||
""" | ||
public_model_data_accessor = PublicModelDataAccessor( | ||
region=self.region, model_specs=file_input, studio_specs=self.studio_specs | ||
) | ||
files = [] | ||
for dependency in HubContentDependencyType: | ||
location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency) | ||
|
||
# Prefix | ||
if location.key[-1] == "/": | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
parameters = {"Bucket": location.bucket, "Prefix": location.key} | ||
response = self.s3_client.list_objects_v2(**parameters) | ||
contents = response.get("Contents", None) | ||
for s3_obj in contents: | ||
key: str = s3_obj.get("Key") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isn't the typing |
||
size: bytes = s3_obj.get("Size", None) | ||
last_modified: datetime = s3_obj.get("LastModified", None) | ||
dependency_type: HubContentDependencyType = dependency | ||
files.append( | ||
FileInfo( | ||
location.bucket, | ||
key, | ||
size, | ||
last_modified, | ||
dependency_type, | ||
) | ||
) | ||
else: | ||
parameters = {"Bucket": location.bucket, "Key": location.key} | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
response = self.s3_client.head_object(**parameters) | ||
size: bytes = response.get("ContentLength", None) | ||
last_updated: datetime = response.get("LastModified", None) | ||
dependency_type: HubContentDependencyType = dependency | ||
files.append( | ||
FileInfo(location.bucket, location.key, size, last_updated, dependency_type) | ||
) | ||
return files |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains important details related to HubContent data files.""" | ||
from __future__ import absolute_import | ||
|
||
from enum import Enum | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
from datetime import datetime | ||
|
||
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation | ||
|
||
|
||
class HubContentDependencyType(str, Enum): | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Enum class for HubContent dependency names""" | ||
|
||
INFERENCE_ARTIFACT = "inference_artifact_s3_reference" | ||
TRAINING_ARTIFACT = "training_artifact_s3_reference" | ||
INFERENCE_SCRIPT = "inference_script_s3_reference" | ||
TRAINING_SCRIPT = "training_script_s3_reference" | ||
DEFAULT_TRAINING_DATASET = "default_training_dataset_s3_reference" | ||
DEMO_NOTEBOOK = "demo_notebook_s3_reference" | ||
MARKDOWN = "markdown_s3_reference" | ||
|
||
|
||
@dataclass | ||
class FileInfo: | ||
"""Data class for additional S3 file info.""" | ||
|
||
location: S3ObjectLocation | ||
|
||
def __init__( | ||
self, | ||
bucket: str, | ||
key: str, | ||
size: Optional[bytes], | ||
last_updated: Optional[datetime], | ||
dependecy_type: Optional[HubContentDependencyType] = None, | ||
): | ||
self.location = S3ObjectLocation(bucket, key) | ||
self.size = size | ||
self.last_updated = last_updated | ||
self.dependecy_type = dependecy_type |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module provides a class that perfrms functionalities similar to ``S3:Copy``.""" | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from __future__ import absolute_import | ||
from typing import List | ||
|
||
import boto3 | ||
import botocore | ||
import tqdm | ||
|
||
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER | ||
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo | ||
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation | ||
|
||
s3transfer = boto3.s3.transfer | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
# pylint: disable=R1705,R1710 | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def human_readable_size(value): | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Convert a size in bytes into a human readable format. | ||
|
||
For example:: | ||
|
||
>>> human_readable_size(1) | ||
'1 Byte' | ||
>>> human_readable_size(10) | ||
'10 Bytes' | ||
>>> human_readable_size(1024) | ||
'1.0 KiB' | ||
>>> human_readable_size(1024 * 1024) | ||
'1.0 MiB' | ||
|
||
:param value: The size in bytes. | ||
:return: The size in a human readable format based on base-2 units. | ||
|
||
""" | ||
base = 1024 | ||
bytes_int = float(value) | ||
|
||
if bytes_int == 1: | ||
return "1 Byte" | ||
elif bytes_int < base: | ||
return "%d Bytes" % bytes_int | ||
|
||
for i, suffix in enumerate(("KiB", "MiB", "GiB", "TiB", "PiB", "EiB")): | ||
unit = base ** (i + 2) | ||
if round((bytes_int / unit) * base) < base: | ||
return "%.1f %s" % ((base * bytes_int / unit), suffix) | ||
|
||
|
||
class MultiPartCopyHandler(object): | ||
"""Multi Part Copy Handler class.""" | ||
|
||
WORKERS = 20 | ||
MULTIPART_CONFIG = 8 * (1024**2) | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __init__( | ||
self, | ||
region: str, | ||
files: List[FileInfo], | ||
dest_location: S3ObjectLocation, | ||
): | ||
"""Something.""" | ||
self.region = region | ||
self.files = files | ||
self.dest_location = dest_location | ||
|
||
config = botocore.config.Config(max_pool_connections=self.WORKERS) | ||
self.s3_client = boto3.client("s3", region_name=self.region, config=config) | ||
transfer_config = s3transfer.TransferConfig( | ||
multipart_threshold=self.MULTIPART_CONFIG, | ||
multipart_chunksize=self.MULTIPART_CONFIG, | ||
max_bandwidth=True, | ||
use_threads=True, | ||
max_concurrency=self.WORKERS, | ||
) | ||
self.transfer_manager = s3transfer.create_transfer_manager( | ||
client=self.s3_client, config=transfer_config | ||
) | ||
|
||
def _copy_file(self, file: FileInfo, update_fn): | ||
"""Something.""" | ||
copy_source = {"Bucket": file.location.bucket, "Key": file.location.key} | ||
result = self.transfer_manager.copy( | ||
bucket=self.dest_location.bucket, | ||
key=f"{self.dest_location.key}/{file.location.key}", | ||
copy_source=copy_source, | ||
subscribers=[ | ||
s3transfer.ProgressCallbackInvoker(update_fn), | ||
], | ||
) | ||
result.result() | ||
|
||
def call(self): | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Something.""" | ||
total_size = sum([file.size for file in self.files]) | ||
JUMPSTART_LOGGER.warning( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i don't think should be a warning. can we modify |
||
"Copying %s files (%s) into %s/%s", | ||
len(self.files), | ||
human_readable_size(total_size), | ||
self.dest_location.bucket, | ||
self.dest_location.key, | ||
) | ||
|
||
progress = tqdm.tqdm( | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
desc="JumpStart Sync", | ||
total=total_size, | ||
unit="B", | ||
unit_scale=1, | ||
position=0, | ||
bar_format="{desc:<10}{percentage:3.0f}%|{bar:10}{r_bar}", | ||
) | ||
|
||
for file in self.files: | ||
self._copy_file(file, progress.update) | ||
|
||
self.transfer_manager.shutdown() | ||
progress.close() | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module utilites to assist S3 client calls for the Curated Hub.""" | ||
from __future__ import absolute_import | ||
from dataclasses import dataclass | ||
from typing import Dict | ||
|
||
from sagemaker.s3_utils import parse_s3_url | ||
|
||
|
||
@dataclass | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
class S3ObjectLocation: | ||
"""Helper class for S3 object references""" | ||
|
||
bucket: str | ||
key: str | ||
|
||
def format_for_s3_copy(self) -> Dict[str, str]: | ||
"""Returns a dict formatted for S3 copy calls""" | ||
return { | ||
"Bucket": self.bucket, | ||
"Key": self.key, | ||
} | ||
|
||
def get_uri(self) -> str: | ||
"""Returns the s3 URI""" | ||
return f"s3://{self.bucket}/{self.key}" | ||
|
||
|
||
def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation: | ||
"""Utiity to help generate an S3 object reference""" | ||
bucket, key = parse_s3_url(s3_uri) | ||
|
||
return S3ObjectLocation( | ||
bucket=bucket, | ||
key=key, | ||
) |
Uh oh!
There was an error while loading. Please reload this page.