Skip to content

Commit 761fa1d

Browse files
committed
set custom field from HCD config to model spec data class
1 parent 97990c6 commit 761fa1d

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/sagemaker/jumpstart/types.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,10 +1363,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
13631363
self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {}))
13641364
self.predictor_specs: Optional[JumpStartPredictorSpecs] = (
13651365
JumpStartPredictorSpecs(
1366-
json_obj.get("sage_maker_sdk_predictor_specifications"),
1366+
json_obj.get("predictor_specs"),
13671367
is_hub_content=self._is_hub_content,
13681368
)
1369-
if json_obj.get("sage_maker_sdk_predictor_specifications")
1369+
if json_obj.get("predictor_specs")
13701370
else None
13711371
)
13721372
self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = (
@@ -1502,6 +1502,9 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields):
15021502
"incremental_training_supported",
15031503
]
15041504

1505+
# Map of HubContent fields that map to custom names in MetadataBaseFields
1506+
CUSTOM_FIELD_MAP = {"sage_maker_sdk_predictor_specifications": "predictor_specs"}
1507+
15051508
__slots__ = slots + JumpStartMetadataBaseFields.__slots__
15061509

15071510
def __init__(
@@ -1532,6 +1535,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
15321535
for field in json_obj.keys():
15331536
if field in self.__slots__:
15341537
setattr(self, field, json_obj[field])
1538+
1539+
# Handle custom fields
1540+
for field, custom_field in self.CUSTOM_FIELD_MAP.items():
1541+
if field in json_obj:
1542+
setattr(self, custom_field, json_obj[field])
15351543

15361544

15371545
class JumpStartMetadataConfig(JumpStartDataHolderType):

0 commit comments

Comments
 (0)