Skip to content

Commit d2f72a2

Browse files
committed
add hub_arn support for accept_types, content_types, serializers, deserializers, and predictor (aws#4463)
1 parent 92e35c8 commit d2f72a2

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,7 @@ def get_init_kwargs(
726726
model_version: Optional[str] = None,
727727
hub_arn: Optional[str] = None,
728728
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
729+
hub_arn: Optional[str] = None,
729730
tolerate_vulnerable_model: Optional[bool] = None,
730731
tolerate_deprecated_model: Optional[bool] = None,
731732
instance_type: Optional[str] = None,
@@ -759,6 +760,7 @@ def get_init_kwargs(
759760
model_version=model_version,
760761
hub_arn=hub_arn,
761762
model_type=model_type,
763+
hub_arn=hub_arn,
762764
instance_type=instance_type,
763765
region=region,
764766
image_uri=image_uri,

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
13101310
"model_version",
13111311
"hub_arn",
13121312
"model_type",
1313+
"hub_arn",
13131314
"instance_type",
13141315
"tolerate_vulnerable_model",
13151316
"tolerate_deprecated_model",
@@ -1342,6 +1343,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
13421343
"model_version",
13431344
"hub_arn",
13441345
"model_type",
1346+
"hub_arn",
13451347
"tolerate_vulnerable_model",
13461348
"tolerate_deprecated_model",
13471349
"region",
@@ -1355,6 +1357,7 @@ def __init__(
13551357
model_version: Optional[str] = None,
13561358
hub_arn: Optional[str] = None,
13571359
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
1360+
hub_arn: Optional[str] = None,
13581361
region: Optional[str] = None,
13591362
instance_type: Optional[str] = None,
13601363
image_uri: Optional[Union[str, Any]] = None,
@@ -1386,6 +1389,7 @@ def __init__(
13861389
self.model_version = model_version
13871390
self.hub_arn = hub_arn
13881391
self.model_type = model_type
1392+
self.hub_arn = hub_arn
13891393
self.instance_type = instance_type
13901394
self.region = region
13911395
self.image_uri = image_uri

0 commit comments

Comments
 (0)