-
Notifications
You must be signed in to change notification settings - Fork 1.2k
MultiPartCopy with Sync Algorithm #4475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
344d26b
374c638
67d8ec8
2fa0503
30c2b91
ef57f14
c44acd2
297d1b6
18a5728
4554d34
c7f3f96
28c9186
97001cc
27240a9
ce73f62
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains important utilities related to HubContent data files.""" | ||
from __future__ import absolute_import | ||
from typing import Any, Dict, List, Optional | ||
|
||
from botocore.client import BaseClient | ||
|
||
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo, HubContentDependencyType | ||
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation | ||
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor | ||
from sagemaker.jumpstart.types import JumpStartModelSpecs | ||
|
||
|
||
class FileGenerator: | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Utility class to help format HubContent data files.""" | ||
|
||
def __init__( | ||
self, region: str, s3_client: BaseClient, studio_specs: Optional[Dict[str, Any]] = None | ||
): | ||
self.region = region | ||
self.s3_client = s3_client | ||
self.studio_specs = studio_specs | ||
|
||
def format(self, file_input) -> List[FileInfo]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need this? |
||
"""Dispatch method that is implemented in below registered functions.""" | ||
raise NotImplementedError | ||
|
||
|
||
class S3PathFileGenerator(FileGenerator): | ||
"""Utility class to help format all objects in an S3 bucket.""" | ||
|
||
def format(self, file_input: S3ObjectLocation) -> List[FileInfo]: | ||
"""Retrieves data from an S3 bucket and formats into FileInfo. | ||
|
||
Returns a list of ``FileInfo`` objects from the specified bucket location. | ||
""" | ||
parameters = {"Bucket": file_input.bucket, "Prefix": file_input.key} | ||
response = self.s3_client.list_objects_v2(**parameters) | ||
contents = response.get("Contents", None) | ||
|
||
if not contents: | ||
print("Nothing to download") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we are fine with regular print statements? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Placeholder for now, I'll double check if we want to use logger (prob the case) |
||
return [] | ||
|
||
files = [] | ||
for s3_obj in contents: | ||
key: str = s3_obj.get("Key") | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
size: bytes = s3_obj.get("Size", None) | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
last_modified: str = s3_obj.get("LastModified", None) | ||
files.append(FileInfo(key, size, last_modified)) | ||
return files | ||
|
||
|
||
class ModelSpecsFileGenerator(FileGenerator): | ||
"""Utility class to help format all data paths from JumpStart public model specs.""" | ||
|
||
def format(self, file_input: JumpStartModelSpecs) -> List[FileInfo]: | ||
"""Collects data locations from JumpStart public model specs and converts into FileInfo`. | ||
|
||
Returns a list of ``FileInfo`` objects from dependencies found in the public | ||
model specs. | ||
""" | ||
public_model_data_accessor = PublicModelDataAccessor( | ||
region=self.region, model_specs=file_input, studio_specs=self.studio_specs | ||
) | ||
files = [] | ||
for dependency in HubContentDependencyType: | ||
location = public_model_data_accessor.get_s3_reference(dependency) | ||
parameters = {"Bucket": location.bucket, "Prefix": location.key} | ||
response = self.s3_client.head_object(**parameters) | ||
key: str = location.key | ||
size: bytes = response.get("ContentLength", None) | ||
last_updated: str = response.get("LastModified", None) | ||
dependency_type: HubContentDependencyType = dependency | ||
files.append(FileInfo(key, size, last_updated, dependency_type)) | ||
return files |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains important details related to HubContent data files.""" | ||
from __future__ import absolute_import | ||
|
||
from enum import Enum | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
|
||
class HubContentDependencyType(str, Enum): | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Enum class for HubContent dependency names""" | ||
|
||
INFERENCE_ARTIFACT = "inference_artifact_s3_reference" | ||
TRAINING_ARTIFACT = "training_artifact_s3_reference" | ||
INFERENCE_SCRIPT = "inference_script_s3_reference" | ||
TRAINING_SCRIPT = "training_script_s3_reference" | ||
DEFAULT_TRAINING_DATASET = "default_training_dataset_s3_reference" | ||
DEMO_NOTEBOOK = "demo_notebook_s3_reference" | ||
MARKDOWN = "markdown_s3_reference" | ||
|
||
|
||
@dataclass | ||
class FileInfo: | ||
"""Data class for additional S3 file info.""" | ||
|
||
def __init__( | ||
self, | ||
name: str, | ||
size: Optional[bytes], | ||
last_updated: Optional[str], | ||
dependecy_type: Optional[HubContentDependencyType] = None, | ||
): | ||
self.name = name | ||
self.size = size | ||
self.last_updated = last_updated | ||
self.dependecy_type = dependecy_type |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module utilites to assist S3 client calls for the Curated Hub.""" | ||
from __future__ import absolute_import | ||
from dataclasses import dataclass | ||
from typing import Dict | ||
|
||
|
||
@dataclass | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
class S3ObjectLocation: | ||
"""Helper class for S3 object references""" | ||
|
||
bucket: str | ||
key: str | ||
|
||
def format_for_s3_copy(self) -> Dict[str, str]: | ||
"""Returns a dict formatted for S3 copy calls""" | ||
return { | ||
"Bucket": self.bucket, | ||
"Key": self.key, | ||
} | ||
|
||
def get_uri(self) -> str: | ||
"""Returns the s3 URI""" | ||
return f"s3://{self.bucket}/{self.key}" | ||
|
||
|
||
def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation: | ||
"""Utiity to help generate an S3 object reference""" | ||
uri_with_s3_prefix_removed = s3_uri.replace("s3://", "", 1) | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
uri_split = uri_with_s3_prefix_removed.split("/") | ||
|
||
return S3ObjectLocation( | ||
bucket=uri_split[0], | ||
key="/".join(uri_split[1:]) if len(uri_split) > 1 else "", | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module accessors for the SageMaker JumpStart Public Hub.""" | ||
from __future__ import absolute_import | ||
from typing import Dict, Any | ||
from sagemaker import model_uris, script_uris | ||
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import HubContentDependencyType | ||
from sagemaker.jumpstart.curated_hub.utils import ( | ||
get_model_framework, | ||
) | ||
from sagemaker.jumpstart.enums import JumpStartScriptScope | ||
from sagemaker.jumpstart.types import JumpStartModelSpecs | ||
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket | ||
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import ( | ||
S3ObjectLocation, | ||
create_s3_object_reference_from_uri, | ||
) | ||
|
||
|
||
class PublicModelDataAccessor: | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Accessor class for JumpStart model data s3 locations.""" | ||
|
||
def __init__( | ||
self, | ||
region: str, | ||
model_specs: JumpStartModelSpecs, | ||
studio_specs: Dict[str, Dict[str, Any]], | ||
): | ||
self._region = region | ||
self._bucket = get_jumpstart_content_bucket(region) | ||
self.model_specs = model_specs | ||
self.studio_specs = studio_specs # Necessary for SDK - Studio metadata drift | ||
|
||
def get_s3_reference(self, dependency_type: HubContentDependencyType): | ||
"""Retrieves S3 reference given a HubContentDependencyType.""" | ||
return getattr(self, dependency_type.value) | ||
|
||
@property | ||
def inference_artifact_s3_reference(self): | ||
"""Retrieves s3 reference for model inference artifact""" | ||
return create_s3_object_reference_from_uri( | ||
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE) | ||
) | ||
|
||
@property | ||
def training_artifact_s3_reference(self): | ||
"""Retrieves s3 reference for model training artifact""" | ||
return create_s3_object_reference_from_uri( | ||
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING) | ||
) | ||
|
||
@property | ||
def inference_script_s3_reference(self): | ||
"""Retrieves s3 reference for model inference script""" | ||
return create_s3_object_reference_from_uri( | ||
self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE) | ||
) | ||
|
||
@property | ||
def training_script_s3_reference(self): | ||
"""Retrieves s3 reference for model training script""" | ||
return create_s3_object_reference_from_uri( | ||
self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING) | ||
) | ||
|
||
@property | ||
def default_training_dataset_s3_reference(self): | ||
"""Retrieves s3 reference for s3 directory containing model training datasets""" | ||
return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why are there 2 underscores for |
||
|
||
@property | ||
def demo_notebook_s3_reference(self): | ||
"""Retrieves s3 reference for model demo jupyter notebook""" | ||
framework = get_model_framework(self.model_specs) | ||
key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb" | ||
return S3ObjectLocation(self._get_bucket_name(), key) | ||
|
||
@property | ||
def markdown_s3_reference(self): | ||
"""Retrieves s3 reference for model markdown""" | ||
framework = get_model_framework(self.model_specs) | ||
key = f"{framework}-metadata/{self.model_specs.model_id}.md" | ||
return S3ObjectLocation(self._get_bucket_name(), key) | ||
|
||
def _get_bucket_name(self) -> str: | ||
"""Retrieves s3 bucket""" | ||
return self._bucket | ||
|
||
def __get_training_dataset_prefix(self) -> str: | ||
"""Retrieves training dataset location""" | ||
return self.studio_specs["defaultDataKey"] | ||
|
||
def _jumpstart_script_s3_uri(self, model_scope: str) -> str: | ||
"""Retrieves JumpStart script s3 location""" | ||
return script_uris.retrieve( | ||
region=self._region, | ||
model_id=self.model_specs.model_id, | ||
model_version=self.model_specs.version, | ||
script_scope=model_scope, | ||
tolerate_vulnerable_model=True, | ||
tolerate_deprecated_model=True, | ||
bencrabtree marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
|
||
def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str: | ||
"""Retrieves JumpStart artifact s3 location""" | ||
return model_uris.retrieve( | ||
region=self._region, | ||
model_id=self.model_specs.model_id, | ||
model_version=self.model_specs.version, | ||
model_scope=model_scope, | ||
tolerate_vulnerable_model=True, | ||
tolerate_deprecated_model=True, | ||
) |
Uh oh!
There was an error while loading. Please reload this page.