@@ -1156,6 +1156,149 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
1156
1156
)
1157
1157
self .model_subscription_link = json_obj .get ("model_subscription_link" )
1158
1158
1159
+ def from_describe_hub_content_response (self , response : DescribeHubContentResponse ) -> None :
1160
+ """Sets fields in object based on values in HubContentDocument
1161
+
1162
+ Args:
1163
+ hub_content_doc (Dict[str, any]): parsed HubContentDocument returned
1164
+ from SageMaker:DescribeHubContent
1165
+ """
1166
+ self .model_id : str = response .hub_content_name
1167
+ self .version : str = response .hub_content_version
1168
+ hub_content_document : HubModelDocument = response .hub_content_document
1169
+ self .url : str = hub_content_document .url
1170
+ self .min_sdk_version : str = hub_content_document .min_sdk_version
1171
+ self .training_supported : bool = hub_content_document .training_supported
1172
+ self .incremental_training_supported : bool = bool (
1173
+ hub_content_document ["IncrementalTrainingSupported" ]
1174
+ )
1175
+ self .hosting_ecr_uri : Optional [str ] = hub_content_document .hosting_ecr_uri
1176
+ self ._non_serializable_slots .append ("hosting_ecr_specs" )
1177
+
1178
+ hosting_artifact_bucket , hosting_artifact_key = parse_s3_url (
1179
+ hub_content_document .hosting_artifact_uri
1180
+ )
1181
+ self .hosting_artifact_key : str = hosting_artifact_key
1182
+ hosting_script_bucket , hosting_script_key = parse_s3_url (
1183
+ hub_content_document .hosting_script_uri
1184
+ )
1185
+ self .hosting_script_key : str = hosting_script_key
1186
+ self .inference_environment_variables = hub_content_document .inference_environment_variables
1187
+ self .inference_vulnerable : bool = False
1188
+ self .inference_dependencies : List [str ] = hub_content_document .inference_dependencies
1189
+ self .inference_vulnerabilities : List [str ] = []
1190
+ self .training_vulnerable : bool = False
1191
+ self .training_dependencies : List [str ] = hub_content_document .training_dependencies
1192
+ self .training_vulnerabilities : List [str ] = []
1193
+ self .deprecated : bool = False
1194
+ self .deprecated_message : Optional [str ] = None
1195
+ self .deprecate_warn_message : Optional [str ] = None
1196
+ self .usage_info_message : Optional [str ] = None
1197
+ self .default_inference_instance_type : Optional [
1198
+ str
1199
+ ] = hub_content_document .default_inference_instance_type
1200
+ self .default_training_instance_type : Optional [
1201
+ str
1202
+ ] = hub_content_document .default_training_instance_type
1203
+ self .supported_inference_instance_types : Optional [
1204
+ List [str ]
1205
+ ] = hub_content_document .supported_inference_instance_types
1206
+ self .supported_training_instance_types : Optional [
1207
+ List [str ]
1208
+ ] = hub_content_document .supported_training_instance_types
1209
+ self .dynamic_container_deployment_supported : Optional [
1210
+ bool
1211
+ ] = hub_content_document .dynamic_container_deployment_supported
1212
+ self .hosting_resource_requirements : Optional [
1213
+ Dict [str , int ]
1214
+ ] = hub_content_document .hosting_resource_requirements
1215
+ self .metrics : Optional [List [Dict [str , str ]]] = hub_content_document .training_metrics
1216
+ self .training_prepacked_script_key : Optional [str ] = None
1217
+ if hub_content_document .training_prepacked_script_uri is not None :
1218
+ training_prepacked_script_bucket , training_prepacked_script_key = parse_s3_url (
1219
+ hub_content_document .training_prepacked_script_uri
1220
+ )
1221
+ self .training_prepacked_script_key = training_prepacked_script_key
1222
+
1223
+ self .hosting_prepacked_artifact_key : Optional [str ] = None
1224
+ if hub_content_document .hosting_prepacked_artifact_uri is not None :
1225
+ hosting_prepacked_artifact_bucket , hosting_prepacked_artifact_key = parse_s3_url (
1226
+ hub_content_document .hosting_prepacked_artifact_uri
1227
+ )
1228
+ self .hosting_prepacked_artifact_key = hosting_prepacked_artifact_key
1229
+
1230
+ self .fit_kwargs = get_model_spec_kwargs_from_hub_content_document (
1231
+ ModelSpecKwargType .FIT , hub_content_document
1232
+ )
1233
+ self .model_kwargs = get_model_spec_kwargs_from_hub_content_document (
1234
+ ModelSpecKwargType .MODEL , hub_content_document
1235
+ )
1236
+ self .deploy_kwargs = get_model_spec_kwargs_from_hub_content_document (
1237
+ ModelSpecKwargType .DEPLOY , hub_content_document
1238
+ )
1239
+ self .estimator_kwargs = get_model_spec_kwargs_from_hub_content_document (
1240
+ ModelSpecKwargType .ESTIMATOR , hub_content_document
1241
+ )
1242
+
1243
+ self .predictor_specs : Optional [
1244
+ JumpStartPredictorSpecs
1245
+ ] = hub_content_document .sage_maker_sdk_predictor_specifications
1246
+ self .default_payloads : Optional [
1247
+ Dict [str , JumpStartSerializablePayload ]
1248
+ ] = hub_content_document .default_payloads
1249
+ self .gated_bucket = hub_content_document .gated_bucket
1250
+ self .inference_volume_size : Optional [int ] = hub_content_document .inference_volume_size
1251
+ self .inference_enable_network_isolation : bool = (
1252
+ hub_content_document .inference_enable_network_isolation
1253
+ )
1254
+ self .resource_name_base : Optional [str ] = hub_content_document .resource_name_base
1255
+
1256
+ self .hosting_eula_key : Optional [str ] = None
1257
+ if hub_content_document .hosting_eula_uri is not None :
1258
+ hosting_eula_bucket , hosting_eula_key = parse_s3_url (
1259
+ hub_content_document .hosting_eula_uri
1260
+ )
1261
+ self .hosting_eula_key = hosting_eula_key
1262
+
1263
+ self .hosting_model_package_arns : Optional [Dict ] = None # TODO: Missing from shcema?
1264
+ self .hosting_use_script_uri : bool = hub_content_document .hosting_use_script_uri
1265
+
1266
+ self .hosting_instance_type_variants : Optional [JumpStartInstanceTypeVariants ] = (
1267
+ JumpStartInstanceTypeVariants (hub_content_document .hosting_instance_type_variants )
1268
+ if hub_content_document .hosting_instance_type_variants
1269
+ else None
1270
+ )
1271
+
1272
+ if self .training_supported :
1273
+ self .training_ecr_uri : Optional [str ] = hub_content_document .training_ecr_uri
1274
+ self ._non_serializable_slots .append ("training_ecr_specs" )
1275
+ training_artifact_bucket , training_artifact_key = parse_s3_url (
1276
+ hub_content_document .training_artifact_uri
1277
+ )
1278
+ self .training_artifact_key : str = training_artifact_key
1279
+ training_script_bucket , training_script_key = parse_s3_url (
1280
+ hub_content_document .training_script_uri
1281
+ )
1282
+ self .training_script_key : str = training_script_key
1283
+
1284
+ self .hyperparameters : List [
1285
+ JumpStartHyperparameter
1286
+ ] = hub_content_document .hyperparameters
1287
+ self .training_volume_size : Optional [int ] = hub_content_document .training_volume_size
1288
+ self .training_enable_network_isolation : bool = (
1289
+ hub_content_document .training_enable_network_isolation
1290
+ )
1291
+ self .training_model_package_artifact_uris : Optional [
1292
+ Dict
1293
+ ] = hub_content_document .training_model_package_artifact_uri
1294
+ self .training_instance_type_variants : Optional [
1295
+ JumpStartInstanceTypeVariants
1296
+ ] = JumpStartInstanceTypeVariants (
1297
+ hub_content_document .training_instance_type_variants
1298
+ if hub_content_document .training_instance_type_variants
1299
+ else None
1300
+ )
1301
+
1159
1302
def supports_prepacked_inference (self ) -> bool :
1160
1303
"""Returns True if the model has a prepacked inference artifact."""
1161
1304
return getattr (self , "hosting_prepacked_artifact_key" , None ) is not None
0 commit comments