Skip to content

Commit ce73f62

Browse files
committed
docstrings and linting
1 parent 27240a9 commit ce73f62

File tree

7 files changed

+33
-32
lines changed

7 files changed

+33
-32
lines changed

src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from __future__ import absolute_import
1515
from typing import Any, Dict, List
1616

17-
from datetime import datetime
1817
from botocore.client import BaseClient
1918

2019
from sagemaker.jumpstart.curated_hub.types import (
@@ -90,7 +89,5 @@ def generate_file_infos_from_model_specs(
9089
response = s3_client.head_object(**parameters)
9190
size = response.get("ContentLength")
9291
last_updated = response.get("LastModified")
93-
files.append(
94-
FileInfo(location.bucket, location.key, size, last_updated, dependency)
95-
)
92+
files.append(FileInfo(location.bucket, location.key, size, last_updated, dependency))
9693
return files

src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from sagemaker.jumpstart.curated_hub.types import (
1818
HubContentDependencyType,
1919
S3ObjectLocation,
20-
create_s3_object_reference_from_uri,
2120
)
21+
from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri
2222
from sagemaker.jumpstart.enums import JumpStartScriptScope
2323
from sagemaker.jumpstart.types import JumpStartModelSpecs
2424
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@
4646
from sagemaker.jumpstart.curated_hub.utils import (
4747
create_hub_bucket_if_it_does_not_exist,
4848
generate_default_hub_bucket_name,
49+
create_s3_object_reference_from_uri,
4950
)
5051
from sagemaker.jumpstart.curated_hub.types import (
5152
HubContentDocument_v2,
5253
JumpStartModelInfo,
5354
S3ObjectLocation,
54-
create_s3_object_reference_from_uri,
5555
)
5656

5757

@@ -302,7 +302,7 @@ def sync(self, model_list: List[Dict[str, str]]):
302302
model_version_list.append(JumpStartModelInfo(model["model_id"], model["version"]))
303303

304304
js_models_in_hub = self._get_jumpstart_models_in_hub()
305-
mapped_models_in_hub = { model["name"]: model for model in js_models_in_hub }
305+
mapped_models_in_hub = {model["name"]: model for model in js_models_in_hub}
306306

307307
models_to_sync = self._determine_models_to_sync(model_version_list, mapped_models_in_hub)
308308
JUMPSTART_LOGGER.warning(
@@ -316,9 +316,9 @@ def sync(self, model_list: List[Dict[str, str]]):
316316
with futures.ThreadPoolExecutor(
317317
max_workers=self._default_thread_pool_size,
318318
thread_name_prefix="import-models-to-curated-hub",
319-
) as deploy_executor:
319+
) as import_executor:
320320
for thread_num, model in enumerate(models_to_sync):
321-
task = deploy_executor.submit(self._sync_public_model_to_hub, model, thread_num)
321+
task = import_executor.submit(self._sync_public_model_to_hub, model, thread_num)
322322
tasks.append(task)
323323

324324
# Handle failed imports
@@ -353,7 +353,7 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
353353

354354
dest_location = S3ObjectLocation(
355355
bucket=self.hub_storage_location.bucket,
356-
key=f"{self.hub_storage_location.key}/{model.model_id}/{model.version}"
356+
key=f"{self.hub_storage_location.key}/curated_models/{model.model_id}/{model.version}",
357357
)
358358
src_files = file_generator.generate_file_infos_from_model_specs(
359359
model_specs, studio_specs, self.region, self._s3_client

src/sagemaker/jumpstart/curated_hub/sync/request.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,21 @@ def __init__(
3535
):
3636
"""Contains information required to sync data into a Hub.
3737
38-
Returns:
39-
:var: files (List[FileInfo]): Files that shoudl be synced.
40-
:var: destination (S3ObjectLocation): Location to which to sync the files.
38+
Attrs:
39+
files (List[FileInfo]): Files that should be synced.
40+
destination (S3ObjectLocation): Location to which to sync the files.
4141
"""
4242
self.files = list(files_to_copy)
4343
self.destination = destination
4444

4545

4646
class HubSyncRequestFactory:
47-
"""Generates a ``HubSyncRequest`` which is required to sync data into a Hub."""
47+
"""Generates a ``HubSyncRequest`` which is required to sync data into a Hub.
48+
49+
Creates a ``HubSyncRequest`` class containing:
50+
:var: files (List[FileInfo]): Files that should be synced.
51+
:var: destination (S3ObjectLocation): Location to which to sync the files.
52+
"""
4853

4954
def __init__(
5055
self,
@@ -59,11 +64,6 @@ def __init__(
5964
src_files (List[FileInfo]): List of files to sync to destination bucket
6065
dest_files (List[FileInfo]): List of files already in destination bucket
6166
destination (S3ObjectLocation): S3 destination for copied data
62-
63-
Returns:
64-
``HubSyncRequest`` class containing:
65-
:var: files (List[FileInfo]): Files that shoudl be synced.
66-
:var: destination (S3ObjectLocation): Location to which to sync the files.
6767
"""
6868
self.comparator = comparator
6969
self.destination = destination

src/sagemaker/jumpstart/curated_hub/types.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from datetime import datetime
1919

2020
from sagemaker.jumpstart.types import JumpStartDataHolderType, JumpStartModelSpecs
21-
from sagemaker.s3_utils import parse_s3_url
2221

2322

2423
@dataclass
@@ -40,16 +39,6 @@ def get_uri(self) -> str:
4039
return f"s3://{self.bucket}/{self.key}"
4140

4241

43-
def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation:
44-
"""Utiity to help generate an S3 object reference"""
45-
bucket, key = parse_s3_url(s3_uri)
46-
47-
return S3ObjectLocation(
48-
bucket=bucket,
49-
key=key,
50-
)
51-
52-
5342
@dataclass
5443
class JumpStartModelInfo:
5544
"""Helper class for storing JumpStart model info."""

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from __future__ import absolute_import
1515
import re
1616
from typing import Optional
17+
from sagemaker.jumpstart.curated_hub.types import S3ObjectLocation
18+
from sagemaker.s3_utils import parse_s3_url
1719
from sagemaker.session import Session
1820
from sagemaker.utils import aws_partition
1921
from sagemaker.jumpstart.types import (
@@ -131,6 +133,16 @@ def generate_default_hub_bucket_name(
131133
return f"sagemaker-hubs-{region}-{account_id}"
132134

133135

136+
def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation:
137+
"""Utiity to help generate an S3 object reference"""
138+
bucket, key = parse_s3_url(s3_uri)
139+
140+
return S3ObjectLocation(
141+
bucket=bucket,
142+
key=key,
143+
)
144+
145+
134146
def create_hub_bucket_if_it_does_not_exist(
135147
bucket_name: Optional[str] = None,
136148
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,

tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
FAKE_TIME = datetime.datetime(1997, 8, 14, 00, 00, 00)
3434

35+
3536
@pytest.fixture()
3637
def sagemaker_session():
3738
boto_mock = Mock(name="boto_session")
@@ -80,7 +81,9 @@ def test_create_with_no_bucket_name(
8081
hub_search_keywords,
8182
tags,
8283
):
83-
storage_location = S3ObjectLocation("sagemaker-hubs-us-east-1-123456789123",f"{hub_name}-{FAKE_TIME.timestamp()}")
84+
storage_location = S3ObjectLocation(
85+
"sagemaker-hubs-us-east-1-123456789123", f"{hub_name}-{FAKE_TIME.timestamp()}"
86+
)
8487
mock_generate_hub_storage_location.return_value = storage_location
8588
create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
8689
sagemaker_session.create_hub = Mock(return_value=create_hub)
@@ -133,7 +136,7 @@ def test_create_with_bucket_name(
133136
hub_search_keywords,
134137
tags,
135138
):
136-
storage_location = S3ObjectLocation(hub_bucket_name,f"{hub_name}-{FAKE_TIME.timestamp()}")
139+
storage_location = S3ObjectLocation(hub_bucket_name, f"{hub_name}-{FAKE_TIME.timestamp()}")
137140
mock_generate_hub_storage_location.return_value = storage_location
138141
create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
139142
sagemaker_session.create_hub = Mock(return_value=create_hub)

0 commit comments

Comments
 (0)