Skip to content

Commit 374c638

Browse files
committed
adding tests and update clases
1 parent 344d26b commit 374c638

File tree

8 files changed

+260
-71
lines changed

8 files changed

+260
-71
lines changed

src/sagemaker/jumpstart/curated_hub/accessors/filegenerator.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,28 @@ def __init__(
3535

3636
@singledispatchmethod
3737
def format(self, file_input) -> List[FileInfo]:
38-
"""Implement."""
38+
"""Dispatch method that takes in an input of either ``S3ObjectLocation`` or
39+
``JumpStartModelSpecs`` and is implemented in below registered functions.
40+
"""
3941
# pylint: disable=W0107
4042
pass
4143

4244
@format.register
4345
def _(self, file_input: S3ObjectLocation) -> List[FileInfo]:
44-
"""Something."""
46+
"""Implements ``.format`` when the input is of type ``S3ObjectLocation``.
47+
48+
Returns a list of ``FileInfo`` objects from the specified bucket location.
49+
"""
4550
files = self.s3_format(file_input)
4651
return files
4752

4853
@format.register
4954
def _(self, file_input: JumpStartModelSpecs) -> List[FileInfo]:
50-
"""Something."""
55+
"""Implements ``.format`` when the input is of type ``JumpStartModelSpecs``.
56+
57+
Returns a list of ``FileInfo`` objects from dependencies found in the public
58+
model specs.
59+
"""
5160
files = self.specs_format(file_input, self.studio_specs)
5261
return files
5362

@@ -72,36 +81,16 @@ def s3_format(self, file_input: S3ObjectLocation) -> List[FileInfo]:
7281
def specs_format(
7382
self, file_input: JumpStartModelSpecs, studio_specs: Dict[str, Any]
7483
) -> List[FileInfo]:
75-
"""Collects data locations from JumpStart public model specs and
76-
converts into FileInfo.
84+
"""
85+
Collects data locations from JumpStart public model specs and
86+
converts into FileInfo.
7787
"""
7888
public_model_data_accessor = PublicModelDataAccessor(
7989
region=self.region, model_specs=file_input, studio_specs=studio_specs
8090
)
81-
function_table = {
82-
HubContentDependencyType.INFERENCE_ARTIFACT: (
83-
public_model_data_accessor.get_inference_artifact_s3_reference
84-
),
85-
HubContentDependencyType.TRAINING_ARTIFACT: (
86-
public_model_data_accessor.get_training_artifact_s3_reference
87-
),
88-
HubContentDependencyType.INFERNECE_SCRIPT: (
89-
public_model_data_accessor.get_inference_script_s3_reference
90-
),
91-
HubContentDependencyType.TRAINING_SCRIPT: (
92-
public_model_data_accessor.get_training_script_s3_reference
93-
),
94-
HubContentDependencyType.DEFAULT_TRAINING_DATASET: (
95-
public_model_data_accessor.get_default_training_dataset_s3_reference
96-
),
97-
HubContentDependencyType.DEMO_NOTEBOOK: (
98-
public_model_data_accessor.get_demo_notebook_s3_reference
99-
),
100-
HubContentDependencyType.MARKDOWN: public_model_data_accessor.get_markdown_s3_reference,
101-
}
10291
files = []
10392
for dependency in HubContentDependencyType:
104-
location = function_table[dependency]()
93+
location = public_model_data_accessor.get_s3_reference(dependency)
10594
parameters = {"Bucket": location.bucket, "Prefix": location.key}
10695
response = self.s3_client.head_object(**parameters)
10796
key: str = location.key

src/sagemaker/jumpstart/curated_hub/accessors/fileinfo.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121
class HubContentDependencyType(str, Enum):
2222
"""Enum class for HubContent dependency names"""
2323

24-
INFERENCE_ARTIFACT = "INFERENCE_ARTIFACT"
25-
TRAINING_ARTIFACT = "TRAINING_ARTIFACT"
26-
INFERNECE_SCRIPT = "INFERENCE_SCRIPT"
27-
TRAINING_SCRIPT = "TRAINING_SCRIPT"
28-
DEFAULT_TRAINING_DATASET = "DEFAULT_TRAINING_DATASET"
29-
DEMO_NOTEBOOK = "DEMO_NOTEBOOK"
30-
MARKDOWN = "MARKDOWN"
24+
INFERENCE_ARTIFACT = "inference_artifact_s3_reference"
25+
TRAINING_ARTIFACT = "training_artifact_s3_reference"
26+
INFERENCE_SCRIPT = "inference_script_s3_reference"
27+
TRAINING_SCRIPT = "training_script_s3_reference"
28+
DEFAULT_TRAINING_DATASET = "default_training_dataset_s3_reference"
29+
DEMO_NOTEBOOK = "demo_notebook_s3_reference"
30+
MARKDOWN = "markdown_s3_reference"
3131

3232

3333
@dataclass

src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515
from typing import Dict, Any
1616
from sagemaker import model_uris, script_uris
17+
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import HubContentDependencyType
1718
from sagemaker.jumpstart.curated_hub.utils import (
1819
get_model_framework,
1920
)
@@ -40,53 +41,92 @@ def __init__(
4041
self.model_specs = model_specs
4142
self.studio_specs = studio_specs # Necessary for SDK - Studio metadata drift
4243

43-
def get_bucket_name(self) -> str:
44+
def get_s3_reference(self, dependency_type: HubContentDependencyType):
45+
"""Retrieves S3 reference given a HubContentDependencyType."""
46+
return getattr(self, dependency_type.value)
47+
48+
@property
49+
def inference_artifact_s3_reference(self):
50+
"""Retrieves s3 reference for model inference artifact"""
51+
return self._get_inference_artifact_s3_reference()
52+
53+
@property
54+
def training_artifact_s3_reference(self):
55+
"""Retrieves s3 reference for model training artifact"""
56+
return self._get_training_artifact_s3_reference()
57+
58+
@property
59+
def inference_script_s3_reference(self):
60+
"""Retrieves s3 reference for model inference script"""
61+
return self._get_inference_script_s3_reference()
62+
63+
@property
64+
def training_script_s3_reference(self):
65+
"""Retrieves s3 reference for model training script"""
66+
return self._get_training_script_s3_reference()
67+
68+
@property
69+
def default_training_dataset_s3_reference(self):
70+
"""Retrieves s3 reference for s3 directory containing model training datasets"""
71+
return self._get_default_training_dataset_s3_reference()
72+
73+
@property
74+
def demo_notebook_s3_reference(self):
75+
"""Retrieves s3 reference for model demo jupyter notebook"""
76+
return self._get_demo_notebook_s3_reference()
77+
78+
@property
79+
def markdown_s3_reference(self):
80+
"""Retrieves s3 reference for model markdown"""
81+
return self._get_markdown_s3_reference()
82+
83+
def _get_bucket_name(self) -> str:
4484
"""Retrieves s3 bucket"""
4585
return self._bucket
4686

47-
def get_inference_artifact_s3_reference(self) -> S3ObjectLocation:
87+
def _get_inference_artifact_s3_reference(self) -> S3ObjectLocation:
4888
"""Retrieves s3 reference for model inference artifact"""
4989
return create_s3_object_reference_from_uri(
5090
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE)
5191
)
5292

53-
def get_training_artifact_s3_reference(self) -> S3ObjectLocation:
93+
def _get_training_artifact_s3_reference(self) -> S3ObjectLocation:
5494
"""Retrieves s3 reference for model training artifact"""
5595
return create_s3_object_reference_from_uri(
5696
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING)
5797
)
5898

59-
def get_inference_script_s3_reference(self) -> S3ObjectLocation:
99+
def _get_inference_script_s3_reference(self) -> S3ObjectLocation:
60100
"""Retrieves s3 reference for model inference script"""
61101
return create_s3_object_reference_from_uri(
62102
self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE)
63103
)
64104

65-
def get_training_script_s3_reference(self) -> S3ObjectLocation:
105+
def _get_training_script_s3_reference(self) -> S3ObjectLocation:
66106
"""Retrieves s3 reference for model training script"""
67107
return create_s3_object_reference_from_uri(
68108
self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING)
69109
)
70110

71-
def get_default_training_dataset_s3_reference(self) -> S3ObjectLocation:
111+
def _get_default_training_dataset_s3_reference(self) -> S3ObjectLocation:
72112
"""Retrieves s3 reference for s3 directory containing model training datasets"""
73-
return S3ObjectLocation(self.get_bucket_name(), self._get_training_dataset_prefix())
113+
return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix())
74114

75-
def _get_training_dataset_prefix(self) -> str:
115+
def __get_training_dataset_prefix(self) -> str:
76116
"""Retrieves training dataset location"""
77117
return self.studio_specs["defaultDataKey"]
78118

79-
def get_demo_notebook_s3_reference(self) -> S3ObjectLocation:
119+
def _get_demo_notebook_s3_reference(self) -> S3ObjectLocation:
80120
"""Retrieves s3 reference for model demo jupyter notebook"""
81121
framework = get_model_framework(self.model_specs)
82122
key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb"
83-
return S3ObjectLocation(self.get_bucket_name(), key)
123+
return S3ObjectLocation(self._get_bucket_name(), key)
84124

85-
def get_markdown_s3_reference(self) -> S3ObjectLocation:
125+
def _get_markdown_s3_reference(self) -> S3ObjectLocation:
86126
"""Retrieves s3 reference for model markdown"""
87127
framework = get_model_framework(self.model_specs)
88128
key = f"{framework}-metadata/{self.model_specs.model_id}.md"
89-
return S3ObjectLocation(self.get_bucket_name(), key)
129+
return S3ObjectLocation(self._get_bucket_name(), key)
90130

91131
def _jumpstart_script_s3_uri(self, model_scope: str) -> str:
92132
"""Retrieves JumpStart script s3 location"""

src/sagemaker/jumpstart/curated_hub/accessors/sync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""This module provides a class to help copy HubContent dependencies."""
13+
"""This module provides a class that perfrms functionalities similar to ``S3:Copy``."""
1414
from __future__ import absolute_import
1515
from typing import Generator, List
1616

src/sagemaker/jumpstart/curated_hub/accessors/synccomparator.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module provides comparators for syncing s3 files."""
1414
from __future__ import absolute_import
15+
from datetime import timedelta
1516

1617
from sagemaker.jumpstart.curated_hub.accessors.fileinfo import FileInfo
1718

@@ -36,14 +37,6 @@ def determine_should_sync(self, src_file: FileInfo, dest_file: FileInfo) -> bool
3637
)
3738
return should_sync
3839

39-
def total_seconds(self, td):
40-
"""
41-
timedelta's time_seconds() function for python 2.6 users
42-
43-
:param td: The difference between two datetime objects.
44-
"""
45-
return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6
46-
4740
def compare_size(self, src_file: FileInfo, dest_file: FileInfo):
4841
"""
4942
:returns: True if the sizes are the same.
@@ -62,7 +55,7 @@ def compare_time(self, src_file: FileInfo, dest_file: FileInfo):
6255
dest_time = dest_file.last_updated
6356
delta = dest_time - src_time
6457
# pylint: disable=R1703,R1705
65-
if self.total_seconds(delta) >= 0:
58+
if timedelta.total_seconds(delta) >= 0:
6659
# Destination is newer than source.
6760
return True
6861
else:

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import boto3
2020
from botocore.client import BaseClient
21+
from packaging.version import Version
2122

2223
from sagemaker.jumpstart.curated_hub.accessors.filegenerator import FileGenerator
2324
from sagemaker.jumpstart.curated_hub.accessors.objectlocation import S3ObjectLocation
@@ -176,53 +177,69 @@ def sync(self, model_list: List[Dict[str, str]]):
176177
# Retrieve latest version of unspecified JumpStart model versions
177178
model_version_list = []
178179
for model in model_list:
179-
# TODO: Uncomment and implement
180-
# if not model["version"] or model["version"] == "*":
181-
# model["version"] = self._find_latest_version(model_name=model.model_id)
180+
version = model.get("version", "*")
181+
if not version or version == "*":
182+
model_specs = verify_model_region_and_return_specs(
183+
model["model_id"], version, JumpStartScriptScope.INFERENCE, self.region
184+
)
185+
model["version"] = model_specs.version
182186
model_version_list.append(model)
183187

184188
# Find synced JumpStart model versions in the Hub
185189
js_models_in_hub = []
186-
for model in hub_models:
190+
for hub_model in hub_models:
187191
# TODO: extract both in one pass
188192
jumpstart_model_id = next(
189-
(tag for tag in model.search_keywords if tag.startswith("@jumpstart-model-id")),
193+
(
194+
tag
195+
for tag in hub_model["search_keywords"]
196+
if tag.startswith("@jumpstart-model-id")
197+
),
190198
None,
191199
)
192200
jumpstart_model_version = next(
193201
(
194202
tag
195-
for tag in model.search_keywords
203+
for tag in hub_model["search_keywords"]
196204
if tag.startswith("@jumpstart-model-version")
197205
),
198206
None,
199207
)
200208

201209
if jumpstart_model_id and jumpstart_model_version:
202-
js_models_in_hub.append(model)
210+
js_models_in_hub.append(hub_model)
203211

204212
# Match inputted list of model versions with synced JumpStart model versions in the Hub
205213
models_to_sync = []
206214
for model in model_version_list:
207-
model_id, version = model
208-
matched_model = next((model for model in js_models_in_hub if model.name == model_id))
215+
matched_model = next(
216+
(
217+
hub_model
218+
for hub_model in js_models_in_hub
219+
if hub_model and hub_model["name"] == model["model_id"]
220+
),
221+
None,
222+
)
209223

210224
# Model does not exist in Hub, sync
211225
if not matched_model:
212226
models_to_sync.append(model)
213227

214228
if matched_model:
229+
model_version = Version(model["version"])
230+
hub_model_version = Version(matched_model["version"])
231+
215232
# 1. Model version exists in Hub, pass
216-
if matched_model.version == version:
233+
if hub_model_version == model_version:
217234
pass
218235

219236
# 2. Invalid model version exists in Hub, pass
220237
# This will only happen if something goes wrong in our metadata
221-
if matched_model.version > version:
238+
if hub_model_version > model_version:
222239
pass
223240

224241
# 3. Old model version exists in Hub, update
225-
if matched_model.version < version:
242+
if hub_model_version < model_version:
226243
# Check minSDKVersion against current SDK version, emit log
227244
models_to_sync.append(model)
228245

@@ -269,6 +286,7 @@ def _sync_public_model_to_hub(self, model: Dict[str, str]):
269286
scope=JumpStartScriptScope.INFERENCE,
270287
sagemaker_session=self._sagemaker_session,
271288
)
289+
272290
# TODO: Uncomment and implement
273291
# studio_specs = self.fetch_studio_specs(model_id=model_name, version=model_version)
274292
studio_specs = {}
@@ -282,12 +300,17 @@ def _sync_public_model_to_hub(self, model: Dict[str, str]):
282300
src_files = file_generator.format(model_specs)
283301
dest_files = file_generator.format(dest_location)
284302

285-
files_to_copy = list(FileSync(src_files, dest_files, dest_location).call())
303+
files_to_copy = FileSync(src_files, dest_files, dest_location).call()
286304

287305
if len(files_to_copy) > 0:
288-
# Copy files with MPU
306+
# TODO: Copy files with MPU
289307
print("hi")
290308

309+
# Tag model if specs say it is deprecated or training/inference vulnerable
310+
# Update tag of HubContent ARN without version. Versioned ARNs are not
311+
# onboarded to Tagris.
312+
tags = []
313+
291314
hub_content_document = HubContentDocument_v2(spec=model_specs)
292315

293316
self._sagemaker_session.import_hub_content(
@@ -301,4 +324,5 @@ def _sync_public_model_to_hub(self, model: Dict[str, str]):
301324
hub_content_description="",
302325
hub_content_markdown="",
303326
hub_content_search_keywords=[],
327+
tags=tags,
304328
)

0 commit comments

Comments
 (0)