Skip to content

MultiPartCopy with Sync Algorithm #4475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 12, 2024
Merged
11 changes: 8 additions & 3 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def __init__(
Default: None (no config).
s3_client (Optional[boto3.client]): s3 client to use. Default: None.
sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object,
used for SageMaker interactions. Default: Session in region associated with boto3 session.
used for SageMaker interactions. Default: Session in region associated with boto3
session.
"""

self._region = region
Expand Down Expand Up @@ -358,7 +359,9 @@ def _retrieval_function(
hub_content_type=data_type
)

model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True)
model_specs = JumpStartModelSpecs(
DescribeHubContentsResponse(hub_model_description), is_hub_content=True
)

utils.emit_logs_based_on_model_specs(
model_specs,
Expand All @@ -372,7 +375,9 @@ def _retrieval_function(
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name)
hub_description = DescribeHubResponse(response)
return JumpStartCachedContentValue(formatted_content=DescribeHubResponse(hub_description))
return JumpStartCachedContentValue(
formatted_content=DescribeHubResponse(hub_description)
)
raise ValueError(
f"Bad value for key '{key}': must be in ",
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}"
Expand Down
Empty file.
87 changes: 87 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains important utilities related to HubContent data files."""
from __future__ import absolute_import
from typing import Any, Dict, List, Optional

from botocore.client import BaseClient

from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo, HubContentDependencyType
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor
from sagemaker.jumpstart.types import JumpStartModelSpecs


class FileGenerator:
    """Base class for helpers that format an input into HubContent ``FileInfo`` records."""

    def __init__(
        self,
        region: str,
        s3_client: BaseClient,
        studio_specs: Optional[Dict[str, Any]] = None,
    ):
        """Stores the region, the S3 client and the optional Studio specs on the instance."""
        self.studio_specs = studio_specs
        self.s3_client = s3_client
        self.region = region

    def format(self, file_input) -> List[FileInfo]:
        """Formats ``file_input`` into a list of ``FileInfo``; subclasses override this.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # NOTE(review): a reviewer asked whether this base method is needed at all.
        raise NotImplementedError()


class S3PathFileGenerator(FileGenerator):
    """Utility class to help format all objects found under an S3 bucket location."""

    def format(self, file_input: S3ObjectLocation) -> List[FileInfo]:
        """Retrieves object listings from an S3 bucket and formats them into FileInfo.

        ``list_objects_v2`` returns one page of results per call, so the listing is
        repeated with the continuation token for as long as the response is marked
        as truncated. Without this, only the first page would be returned.

        Args:
            file_input (S3ObjectLocation): Bucket and key prefix to list.

        Returns:
            List[FileInfo]: One ``FileInfo`` per listed object, or an empty list when
            the listing has no contents.
        """
        parameters = {"Bucket": file_input.bucket, "Prefix": file_input.key}
        files = []

        while True:
            response = self.s3_client.list_objects_v2(**parameters)
            for s3_obj in response.get("Contents", None) or []:
                key = s3_obj.get("Key")
                size = s3_obj.get("Size", None)
                last_modified = s3_obj.get("LastModified", None)
                files.append(FileInfo(key, size, last_modified))

            next_token = response.get("NextContinuationToken")
            # Stop on the last page; also stop if a truncated response carries no
            # token, since there would be nothing valid to continue from.
            if not response.get("IsTruncated") or not next_token:
                break
            parameters["ContinuationToken"] = next_token

        if not files:
            # TODO: placeholder output per the review thread; the author intended to
            # check whether this should go through a logger instead.
            print("Nothing to download")
        return files


class ModelSpecsFileGenerator(FileGenerator):
    """Utility class to help format all data paths from JumpStart public model specs."""

    def format(self, file_input: JumpStartModelSpecs) -> List[FileInfo]:
        """Collects data locations from JumpStart public model specs as ``FileInfo``.

        For every ``HubContentDependencyType`` the S3 location is resolved through
        ``PublicModelDataAccessor`` and the object's metadata is read with
        ``head_object``.

        Args:
            file_input (JumpStartModelSpecs): Public model specs to resolve.

        Returns:
            List[FileInfo]: One ``FileInfo`` per dependency type, in enum order.
        """
        public_model_data_accessor = PublicModelDataAccessor(
            region=self.region, model_specs=file_input, studio_specs=self.studio_specs
        )
        files = []
        for dependency in HubContentDependencyType:
            location = public_model_data_accessor.get_s3_reference(dependency)
            # ``head_object`` addresses a single object and takes ``Key``; ``Prefix``
            # is only a parameter of the list APIs.
            # NOTE(review): the default training dataset reference is described as a
            # directory; confirm ``head_object`` is the right call for that entry.
            response = self.s3_client.head_object(Bucket=location.bucket, Key=location.key)
            size = response.get("ContentLength", None)
            last_updated = response.get("LastModified", None)
            files.append(FileInfo(location.key, size, last_updated, dependency))
        return files
47 changes: 47 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains important details related to HubContent data files."""
from __future__ import absolute_import

from enum import Enum
from dataclasses import dataclass
from typing import Optional


class HubContentDependencyType(str, Enum):
    """Names of the S3-hosted dependencies that a HubContent entry can reference.

    Every value is the name of the matching ``*_s3_reference`` attribute, and the
    members are also plain ``str`` instances.
    """

    # Model artifacts.
    INFERENCE_ARTIFACT = "inference_artifact_s3_reference"
    TRAINING_ARTIFACT = "training_artifact_s3_reference"

    # Scripts.
    INFERENCE_SCRIPT = "inference_script_s3_reference"
    TRAINING_SCRIPT = "training_script_s3_reference"

    # Supporting data and documentation.
    DEFAULT_TRAINING_DATASET = "default_training_dataset_s3_reference"
    DEMO_NOTEBOOK = "demo_notebook_s3_reference"
    MARKDOWN = "markdown_s3_reference"


@dataclass
class FileInfo:
    """Data class for additional S3 file info.

    The attributes are declared as dataclass fields so that the generated
    ``__eq__`` and ``__repr__`` reflect the stored values. With a hand-written
    ``__init__`` and no fields, every instance compared equal to every other.
    The constructor signature is unchanged.
    """

    # Object key of the file.
    name: str
    # Object size as reported by S3, if known.
    size: Optional[int]
    # Last-modified value as reported by S3, if known.
    # NOTE(review): boto3 may return this as a datetime rather than a str - confirm.
    last_updated: Optional[str]
    # Which HubContent dependency this file is, if any. The misspelled name is kept
    # so existing keyword callers and attribute readers continue to work.
    dependecy_type: Optional["HubContentDependencyType"] = None
46 changes: 46 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/objectlocation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains utilities to assist S3 client calls for the Curated Hub."""
from __future__ import absolute_import
from dataclasses import dataclass
from typing import Dict


@dataclass
class S3ObjectLocation:
    """Bucket/key pair identifying an S3 object."""

    bucket: str
    key: str

    def format_for_s3_copy(self) -> Dict[str, str]:
        """Returns the location as a ``{"Bucket": ..., "Key": ...}`` dict for S3 copy calls."""
        return dict(Bucket=self.bucket, Key=self.key)

    def get_uri(self) -> str:
        """Returns the location as an ``s3://<bucket>/<key>`` URI."""
        return "s3://{}/{}".format(self.bucket, self.key)


def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation:
    """Utility to help generate an S3 object reference from an S3 URI.

    The first ``s3://`` occurrence is removed. In the remaining text, everything
    before the first ``/`` becomes the bucket and everything after it becomes the
    key; the key is an empty string when there is no ``/``.
    """
    bucket_name, _, object_key = s3_uri.replace("s3://", "", 1).partition("/")
    return S3ObjectLocation(bucket=bucket_name, key=object_key)
123 changes: 123 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains accessors for the SageMaker JumpStart Public Hub."""
from __future__ import absolute_import
from typing import Dict, Any
from sagemaker import model_uris, script_uris
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import HubContentDependencyType
from sagemaker.jumpstart.curated_hub.utils import (
get_model_framework,
)
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.types import JumpStartModelSpecs
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import (
S3ObjectLocation,
create_s3_object_reference_from_uri,
)


class PublicModelDataAccessor:
    """Accessor class for JumpStart model data s3 locations."""

    def __init__(
        self,
        region: str,
        model_specs: JumpStartModelSpecs,
        studio_specs: Dict[str, Dict[str, Any]],
    ):
        """Stores the inputs and resolves the JumpStart content bucket for ``region``.

        Args:
            region (str): Region used for the content bucket and the URI lookups.
            model_specs (JumpStartModelSpecs): Specs of the public model.
            studio_specs (Dict[str, Dict[str, Any]]): Studio metadata; its
                ``"defaultDataKey"`` entry supplies the training dataset prefix.
        """
        self._region = region
        self._bucket = get_jumpstart_content_bucket(region)
        self.model_specs = model_specs
        self.studio_specs = studio_specs  # Necessary for SDK - Studio metadata drift

    def get_s3_reference(self, dependency_type: HubContentDependencyType) -> S3ObjectLocation:
        """Retrieves S3 reference given a HubContentDependencyType.

        The enum value is the name of the property on this class that is read.
        """
        return getattr(self, dependency_type.value)

    @property
    def inference_artifact_s3_reference(self) -> S3ObjectLocation:
        """Retrieves s3 reference for model inference artifact"""
        return create_s3_object_reference_from_uri(
            self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE)
        )

    @property
    def training_artifact_s3_reference(self) -> S3ObjectLocation:
        """Retrieves s3 reference for model training artifact"""
        return create_s3_object_reference_from_uri(
            self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING)
        )

    @property
    def inference_script_s3_reference(self) -> S3ObjectLocation:
        """Retrieves s3 reference for model inference script"""
        return create_s3_object_reference_from_uri(
            self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE)
        )

    @property
    def training_script_s3_reference(self) -> S3ObjectLocation:
        """Retrieves s3 reference for model training script"""
        return create_s3_object_reference_from_uri(
            self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING)
        )

    @property
    def default_training_dataset_s3_reference(self) -> S3ObjectLocation:
        """Retrieves s3 reference for s3 directory containing model training datasets"""
        return S3ObjectLocation(self._get_bucket_name(), self._get_training_dataset_prefix())

    @property
    def demo_notebook_s3_reference(self) -> S3ObjectLocation:
        """Retrieves s3 reference for model demo jupyter notebook"""
        framework = get_model_framework(self.model_specs)
        key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb"
        return S3ObjectLocation(self._get_bucket_name(), key)

    @property
    def markdown_s3_reference(self) -> S3ObjectLocation:
        """Retrieves s3 reference for model markdown"""
        framework = get_model_framework(self.model_specs)
        key = f"{framework}-metadata/{self.model_specs.model_id}.md"
        return S3ObjectLocation(self._get_bucket_name(), key)

    def _get_bucket_name(self) -> str:
        """Retrieves s3 bucket"""
        return self._bucket

    # Single leading underscore, matching the other private helpers on this class;
    # the earlier double-underscore spelling triggered name mangling for no benefit.
    def _get_training_dataset_prefix(self) -> str:
        """Retrieves training dataset location"""
        # NOTE(review): this lookup fails if studio_specs is None - confirm callers
        # always supply it.
        return self.studio_specs["defaultDataKey"]

    def _jumpstart_script_s3_uri(self, model_scope: str) -> str:
        """Retrieves JumpStart script s3 location"""
        return script_uris.retrieve(
            region=self._region,
            model_id=self.model_specs.model_id,
            model_version=self.model_specs.version,
            script_scope=model_scope,
            tolerate_vulnerable_model=True,
            tolerate_deprecated_model=True,
        )

    def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str:
        """Retrieves JumpStart artifact s3 location"""
        return model_uris.retrieve(
            region=self._region,
            model_id=self.model_specs.model_id,
            model_version=self.model_specs.version,
            model_scope=model_scope,
            tolerate_vulnerable_model=True,
            tolerate_deprecated_model=True,
        )
Loading