diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 348de7adeb..458da694d5 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -74,9 +74,12 @@ class Frameworks(str, Enum): JUMPSTART_REGION = "eu-west-2" SDK_MANIFEST_FILE = "models_manifest.json" +PROPRIETARY_SDK_MANIFEST_FILE = "proprietary-sdk-manifest.json" JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format( JUMPSTART_REGION, JUMPSTART_REGION ) +PROPRIETARY_DOC_BUCKET = "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com" + TASK_MAP = { Tasks.IC: ProblemTypes.IMAGE_CLASSIFICATION, Tasks.IC_EMBEDDING: ProblemTypes.IMAGE_EMBEDDING, @@ -152,18 +155,26 @@ class Frameworks(str, Enum): } -def get_jumpstart_sdk_manifest(): - url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, SDK_MANIFEST_FILE) +def get_public_s3_json_object(url): with request.urlopen(url) as f: models_manifest = f.read().decode("utf-8") return json.loads(models_manifest) -def get_jumpstart_sdk_spec(key): - url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, key) - with request.urlopen(url) as f: - model_spec = f.read().decode("utf-8") - return json.loads(model_spec) +def get_jumpstart_sdk_manifest(): + return get_public_s3_json_object(f"{JUMPSTART_BUCKET_BASE_URL}/{SDK_MANIFEST_FILE}") + + +def get_proprietary_sdk_manifest(): + return get_public_s3_json_object(f"{PROPRIETARY_DOC_BUCKET}/{PROPRIETARY_SDK_MANIFEST_FILE}") + + +def get_jumpstart_sdk_spec(s3_key: str): + return get_public_s3_json_object(f"{JUMPSTART_BUCKET_BASE_URL}/{s3_key}") + + +def get_proprietary_sdk_spec(s3_key: str): + return get_public_s3_json_object(f"{PROPRIETARY_DOC_BUCKET}/{s3_key}") def get_model_task(id): @@ -196,6 +207,45 @@ def get_model_source(url): return "Source" +def create_proprietary_model_table(): + proprietary_content_intro = [] + proprietary_content_intro.append("\n") + proprietary_content_intro.append(".. list-table:: Available Proprietary Models\n") + proprietary_content_intro.append(" :widths: 50 20 20 20 20\n") + proprietary_content_intro.append(" :header-rows: 1\n") + proprietary_content_intro.append(" :class: datatable\n") + proprietary_content_intro.append("\n") + proprietary_content_intro.append(" * - Model ID\n") + proprietary_content_intro.append(" - Fine Tunable?\n") + proprietary_content_intro.append(" - Supported Version\n") + proprietary_content_intro.append(" - Min SDK Version\n") + proprietary_content_intro.append(" - Source\n") + + sdk_manifest = get_proprietary_sdk_manifest() + sdk_manifest_top_versions_for_models = {} + + for model in sdk_manifest: + if model["model_id"] not in sdk_manifest_top_versions_for_models: + sdk_manifest_top_versions_for_models[model["model_id"]] = model + else: + if str(sdk_manifest_top_versions_for_models[model["model_id"]]["version"]) < str( + model["version"] + ): + sdk_manifest_top_versions_for_models[model["model_id"]] = model + + proprietary_content_entries = [] + for model in sdk_manifest_top_versions_for_models.values(): + model_spec = get_proprietary_sdk_spec(model["spec_key"]) + proprietary_content_entries.append(" * - {}\n".format(model_spec["model_id"])) + proprietary_content_entries.append(" - {}\n".format(False)) # TODO: support training + proprietary_content_entries.append(" - {}\n".format(model["version"])) + proprietary_content_entries.append(" - {}\n".format(model["min_version"])) + proprietary_content_entries.append( + " - `{} <{}>`__ |external-link|\n".format("Source", model_spec.get("url")) + ) + return proprietary_content_intro + proprietary_content_entries + ["\n"] + + def create_jumpstart_model_table(): sdk_manifest = get_jumpstart_sdk_manifest() sdk_manifest_top_versions_for_models = {} @@ -249,19 +299,19 @@ def create_jumpstart_model_table(): file_content_intro.append(" - Source\n") dynamic_table_files = [] - file_content_entries = [] + open_weight_content_entries = [] for model in sdk_manifest_top_versions_for_models.values(): model_spec = get_jumpstart_sdk_spec(model["spec_key"]) model_task = get_model_task(model_spec["model_id"]) string_model_task = get_string_model_task(model_spec["model_id"]) model_source = get_model_source(model_spec["url"]) - file_content_entries.append(" * - {}\n".format(model_spec["model_id"])) - file_content_entries.append(" - {}\n".format(model_spec["training_supported"])) - file_content_entries.append(" - {}\n".format(model["version"])) - file_content_entries.append(" - {}\n".format(model["min_version"])) - file_content_entries.append(" - {}\n".format(model_task)) - file_content_entries.append( + open_weight_content_entries.append(" * - {}\n".format(model_spec["model_id"])) + open_weight_content_entries.append(" - {}\n".format(model_spec["training_supported"])) + open_weight_content_entries.append(" - {}\n".format(model["version"])) + open_weight_content_entries.append(" - {}\n".format(model["min_version"])) + open_weight_content_entries.append(" - {}\n".format(model_task)) + open_weight_content_entries.append( " - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"]) ) @@ -299,7 +349,10 @@ def create_jumpstart_model_table(): f.writelines(file_content_single_entry) f.close() + proprietary_content_entries = create_proprietary_model_table() + f = open("doc_utils/pretrainedmodels.rst", "a") f.writelines(file_content_intro) - f.writelines(file_content_entries) + f.writelines(open_weight_content_entries) + f.writelines(proprietary_content_entries) f.close() diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index bf081365ab..78aa655e04 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -16,6 +16,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -75,6 +76,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default accept type for the model matching the given arguments. @@ -114,4 +116,5 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/base_predictor.py b/src/sagemaker/base_predictor.py index 882cfafc39..76b83c25cd 100644 --- a/src/sagemaker/base_predictor.py +++ b/src/sagemaker/base_predictor.py @@ -58,7 +58,9 @@ from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME from sagemaker.lineage.context import EndpointContext -from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements +from sagemaker.compute_resource_requirements.resource_requirements import ( + ResourceRequirements, +) LOGGER = logging.getLogger("sagemaker") diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index e43e96be17..46d0361f67 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -16,6 +16,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -75,6 +76,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default content type for the model matching the given arguments. @@ -114,6 +116,7 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 706ae56bda..1a4be43897 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -35,6 +35,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -95,6 +96,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. @@ -135,4 +137,5 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 0471f374ae..48aaab0ac8 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -20,6 +20,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -34,6 +35,7 @@ def retrieve_default( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default instance type for the model matching the given arguments. @@ -85,6 +87,7 @@ def retrieve_default( tolerate_deprecated_model, sagemaker_session=sagemaker_session, training_instance_type=training_instance_type, + model_type=model_type, ) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index e03a13a7a3..35df030ddc 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -18,6 +18,7 @@ from sagemaker.deprecations import deprecated from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import cache from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME @@ -197,7 +198,9 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None: @staticmethod def _get_manifest( - region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None + region: str = JUMPSTART_DEFAULT_REGION_NAME, + s3_client: Optional[boto3.client] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest. @@ -215,13 +218,19 @@ def _get_manifest( additional_kwargs.update({"s3_client": s3_client}) cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( - {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, region + {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, + region, ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) - return JumpStartModelsAccessor._cache.get_manifest() # type: ignore + return JumpStartModelsAccessor._cache.get_manifest(model_type) # type: ignore @staticmethod - def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader: + def get_model_header( + region: str, + model_id: str, + version: str, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + ) -> JumpStartModelHeader: """Returns model header from JumpStart models cache. Args: @@ -234,12 +243,18 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) return JumpStartModelsAccessor._cache.get_header( # type: ignore - model_id=model_id, semantic_version_str=version + model_id=model_id, + semantic_version_str=version, + model_type=model_type, ) @staticmethod def get_model_specs( - region: str, model_id: str, version: str, s3_client: Optional[boto3.client] = None + region: str, + model_id: str, + version: str, + s3_client: Optional[boto3.client] = None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. @@ -260,7 +275,7 @@ def get_model_specs( ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) return JumpStartModelsAccessor._cache.get_specs( # type: ignore - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version, model_type=model_type ) @staticmethod diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 38e02e3ebd..608303c5e6 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -22,6 +22,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -38,6 +39,7 @@ def _retrieve_default_instance_type( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default instance type for the model. @@ -84,6 +86,7 @@ def _retrieve_default_instance_type( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 7acad9b793..c15f686805 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -22,6 +22,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -35,6 +36,7 @@ def _retrieve_model_init_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> dict: """Retrieves kwargs for `Model`. @@ -71,6 +73,7 @@ def _retrieve_model_init_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) kwargs = deepcopy(model_specs.model_kwargs) @@ -89,6 +92,7 @@ def _retrieve_model_deploy_kwargs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> dict: """Retrieves kwargs for `Model.deploy`. @@ -128,6 +132,7 @@ def _retrieve_model_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None: diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index bd0ae365d9..5c8a2488c6 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -22,6 +22,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.session import Session @@ -35,6 +36,7 @@ def _retrieve_model_package_arn( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[str]: """Retrieves associated model pacakge arn for the model. @@ -74,6 +76,7 @@ def _retrieve_model_package_arn( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 3ea2c16f80..0424145119 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -20,6 +20,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( @@ -35,6 +36,7 @@ def _retrieve_example_payloads( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[Dict[str, JumpStartSerializablePayload]]: """Returns example payloads. @@ -72,6 +74,7 @@ def _retrieve_example_payloads( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) default_payloads = model_specs.default_payloads diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 8d599c89cc..e9e0e8dfde 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -26,6 +26,7 @@ from sagemaker.jumpstart.enums import ( JumpStartScriptScope, MIMEType, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -76,6 +77,7 @@ def _retrieve_default_deserializer( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> BaseDeserializer: """Retrieves the default deserializer for the model. @@ -108,6 +110,7 @@ def _retrieve_default_deserializer( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) return _retrieve_deserializer_from_accept_type(MIMEType.from_suffixed_type(default_accept_type)) @@ -120,6 +123,7 @@ def _retrieve_default_serializer( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> BaseSerializer: """Retrieves the default serializer for the model. @@ -151,6 +155,7 @@ def _retrieve_default_serializer( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) return _retrieve_serializer_from_content_type(MIMEType.from_suffixed_type(default_content_type)) @@ -163,6 +168,7 @@ def _retrieve_deserializer_options( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model. @@ -194,6 +200,7 @@ def _retrieve_deserializer_options( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) seen_classes: Set[Type] = set() @@ -276,6 +283,7 @@ def _retrieve_default_content_type( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default content type for the model. @@ -312,6 +320,7 @@ def _retrieve_default_content_type( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) default_content_type = model_specs.predictor_specs.default_content_type @@ -325,6 +334,7 @@ def _retrieve_default_accept_type( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the default accept type for the model. @@ -360,6 +370,7 @@ def _retrieve_default_accept_type( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) default_accept_type = model_specs.predictor_specs.default_accept_type @@ -374,6 +385,7 @@ def _retrieve_supported_accept_types( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[str]: """Retrieves the supported accept types for the model. @@ -409,6 +421,7 @@ def _retrieve_supported_accept_types( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) supported_accept_types = model_specs.predictor_specs.supported_accept_types @@ -423,6 +436,7 @@ def _retrieve_supported_content_types( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[str]: """Retrieves the supported content types for the model. @@ -458,6 +472,7 @@ def _retrieve_supported_content_types( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) supported_content_types = model_specs.predictor_specs.supported_content_types diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index 6b05f07b15..60af520a6e 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -19,6 +19,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -32,6 +33,7 @@ def _retrieve_resource_name_base( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns default resource name. @@ -68,6 +70,7 @@ def _retrieve_resource_name_base( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 6ee4f31c56..9f01a7af77 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -21,6 +21,7 @@ ) from sagemaker.jumpstart.enums import ( JumpStartScriptScope, + JumpStartModelType, ) from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, @@ -50,6 +51,7 @@ def _retrieve_default_resources( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, ) -> ResourceRequirements: @@ -97,6 +99,7 @@ def _retrieve_default_resources( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index e26d588167..7682ab3817 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -25,11 +25,17 @@ ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, MODEL_ID_LIST_WEB_URL, + MODEL_TYPE_TO_MANIFEST_MAP, + MODEL_TYPE_TO_SPECS_MAP, +) +from sagemaker.jumpstart.exceptions import ( + get_wildcard_model_version_msg, + get_wildcard_proprietary_model_version_msg, ) -from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, @@ -44,6 +50,7 @@ JumpStartS3FileType, JumpStartVersionedModelId, ) +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import utils from sagemaker.utilities.cache import LRUCache @@ -68,6 +75,7 @@ def __init__( JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, manifest_file_s3_key: str = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + proprietary_manifest_s3_key: str = JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, s3_client: Optional[boto3.client] = None, @@ -100,14 +108,26 @@ def __init__( expiration_horizon=s3_cache_expiration_horizon, retrieval_function=self._retrieval_function, ) - self._model_id_semantic_version_manifest_key_cache = LRUCache[ + self._open_weight_model_id_manifest_key_cache = LRUCache[ JumpStartVersionedModelId, JumpStartVersionedModelId ]( max_cache_items=max_semantic_version_cache_items, expiration_horizon=semantic_version_cache_expiration_horizon, - retrieval_function=self._get_manifest_key_from_model_id_semantic_version, + retrieval_function=self._get_open_weight_manifest_key_from_model_id, + ) + self._proprietary_model_id_manifest_key_cache = LRUCache[ + JumpStartVersionedModelId, JumpStartVersionedModelId + ]( + max_cache_items=max_semantic_version_cache_items, + expiration_horizon=semantic_version_cache_expiration_horizon, + retrieval_function=self._get_proprietary_manifest_key_from_model_id, ) self._manifest_file_s3_key = manifest_file_s3_key + self._proprietary_manifest_s3_key = proprietary_manifest_s3_key + self._manifest_file_s3_map = { + JumpStartModelType.OPEN_WEIGHTS: self._manifest_file_s3_key, + JumpStartModelType.PROPRIETARY: self._proprietary_manifest_s3_key, + } self.s3_bucket_name = ( utils.get_jumpstart_content_bucket(self._region) if s3_bucket_name is None @@ -129,15 +149,40 @@ def get_region(self) -> str: """Return region for cache.""" return self._region - def set_manifest_file_s3_key(self, key: str) -> None: - """Set manifest file s3 key. Clears cache after new key is set.""" - if key != self._manifest_file_s3_key: - self._manifest_file_s3_key = key + def set_manifest_file_s3_key( + self, + key: str, + file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, + ) -> None: + """Set manifest file s3 key, clear cache after new key is set. + + Raises: + ValueError: if the file type is not recognized + """ + file_mapping = { + JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: self._manifest_file_s3_key, + JumpStartS3FileType.PROPRIETARY_MANIFEST: self._proprietary_manifest_s3_key, + } + property_name = file_mapping.get(file_type) + if not property_name: + raise ValueError( + self._file_type_error_msg(file_type, manifest_only=True) + ) + if key != property_name: + setattr(self, property_name, key) self.clear() - def get_manifest_file_s3_key(self) -> str: + def get_manifest_file_s3_key( + self, file_type: JumpStartS3FileType = JumpStartS3FileType.OPEN_WEIGHT_MANIFEST + ) -> str: """Return manifest file s3 key for cache.""" - return self._manifest_file_s3_key + if file_type == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: + return self._manifest_file_s3_key + if file_type == JumpStartS3FileType.PROPRIETARY_MANIFEST: + return self._proprietary_manifest_s3_key + raise ValueError( + self._file_type_error_msg(file_type, manifest_only=True) + ) def set_s3_bucket_name(self, s3_bucket_name: str) -> None: """Set s3 bucket used for cache.""" @@ -149,10 +194,24 @@ def get_bucket(self) -> str: """Return bucket used for cache.""" return self.s3_bucket_name - def _get_manifest_key_from_model_id_semantic_version( + def _file_type_error_msg(self, file_type: str, manifest_only: bool = False) -> str: + """Return error message for bad model type.""" + if manifest_only: + return ( + f"Bad value when getting manifest '{file_type}': " + f"must be in {JumpStartS3FileType.OPEN_WEIGHT_MANIFEST} " + f"{JumpStartS3FileType.PROPRIETARY_MANIFEST}." + ) + return ( + f"Bad value when getting manifest '{file_type}': " + f"must be in '{' '.join([e.name for e in JumpStartS3FileType])}'." + ) + + def _model_id_retrieval_function( self, key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 + model_type: JumpStartModelType, ) -> JumpStartVersionedModelId: """Return model ID and version in manifest that matches semantic version/id. @@ -164,6 +223,8 @@ def _get_manifest_key_from_model_id_semantic_version( key (JumpStartVersionedModelId): Key for which to fetch versioned model ID. value (Optional[JumpStartVersionedModelId]): Unused variable for current value of old cached model ID/version. + model_type (JumpStartModelType): JumpStart model type to indicate whether it is + open weights model or proprietary (Marketplace) model. Raises: KeyError: If the semantic version is not found in the manifest, or is found but @@ -171,21 +232,20 @@ def _get_manifest_key_from_model_id_semantic_version( """ model_id, version = key.model_id, key.version - + sm_version = utils.get_sagemaker_version() manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + JumpStartCachedS3ContentKey( + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) )[0].formatted_content - sm_version = utils.get_sagemaker_version() - versions_compatible_with_sagemaker = [ - Version(header.version) + header.version for header in manifest.values() # type: ignore if header.model_id == model_id and Version(header.min_version) <= Version(sm_version) ] sm_compatible_model_version = self._select_version( - version, versions_compatible_with_sagemaker + model_id, version, versions_compatible_with_sagemaker, model_type ) if sm_compatible_model_version is not None: @@ -196,7 +256,7 @@ def _get_manifest_key_from_model_id_semantic_version( if header.model_id == model_id ] sm_incompatible_model_version = self._select_version( - version, versions_incompatible_with_sagemaker + model_id, version, versions_incompatible_with_sagemaker, model_type ) if sm_incompatible_model_version is not None: @@ -226,15 +286,27 @@ def _get_manifest_key_from_model_id_semantic_version( f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. " ) - other_model_id_version = self._select_version( - "*", versions_incompatible_with_sagemaker - ) # all versions here are incompatible with sagemaker + other_model_id_version = None + if model_type == JumpStartModelType.OPEN_WEIGHTS: + other_model_id_version = self._select_version( + model_id, "*", versions_incompatible_with_sagemaker, model_type + ) # all versions here are incompatible with sagemaker + elif model_type == JumpStartModelType.PROPRIETARY: + all_possible_model_id_version = [ + header.version for header in manifest.values() # type: ignore + if header.model_id == model_id + ] + other_model_id_version = ( + None + if not all_possible_model_id_version + else all_possible_model_id_version[0] + ) + if other_model_id_version is not None: error_msg += ( f"Consider using model ID '{model_id}' with version " f"'{other_model_id_version}'." ) - else: possible_model_ids = [header.model_id for header in manifest.values()] # type: ignore closest_model_id = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)[0] @@ -242,6 +314,32 @@ def _get_manifest_key_from_model_id_semantic_version( raise KeyError(error_msg) + def _get_open_weight_manifest_key_from_model_id( + self, + key: JumpStartVersionedModelId, + value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 + ) -> JumpStartVersionedModelId: + """For open weights models, retrieve model manifest key for open weight model. + + Filters models list by supported versions. + """ + return self._model_id_retrieval_function( + key, value, model_type=JumpStartModelType.OPEN_WEIGHTS + ) + + def _get_proprietary_manifest_key_from_model_id( + self, + key: JumpStartVersionedModelId, + value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 + ) -> JumpStartVersionedModelId: + """For proprietary models, retrieve model manifest key for proprietary model. + + Filters models list by supported versions. + """ + return self._model_id_retrieval_function( + key, value, model_type=JumpStartModelType.PROPRIETARY + ) + def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], str]: """Returns json file from s3, along with its etag.""" response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=key) @@ -286,11 +384,11 @@ def _get_json_file_from_local_override( filetype: JumpStartS3FileType ) -> Union[dict, list]: """Reads json file from local filesystem and returns data.""" - if filetype == JumpStartS3FileType.MANIFEST: + if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: metadata_local_root = ( os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE] ) - elif filetype == JumpStartS3FileType.SPECS: + elif filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS: metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE] else: raise ValueError(f"Unsupported file type for local override: {filetype}") @@ -318,8 +416,10 @@ def _retrieval_function( """ file_type, s3_key = key.file_type, key.s3_key - - if file_type == JumpStartS3FileType.MANIFEST: + if file_type in { + JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, + JumpStartS3FileType.PROPRIETARY_MANIFEST, + }: if value is not None and not self._is_local_metadata_mode(): etag = self._get_json_md5_hash(s3_key) if etag == value.md5_hash: @@ -329,27 +429,36 @@ def _retrieval_function( formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) - if file_type == JumpStartS3FileType.SPECS: + if file_type in { + JumpStartS3FileType.OPEN_WEIGHT_SPECS, + JumpStartS3FileType.PROPRIETARY_SPECS, + }: formatted_body, _ = self._get_json_file(s3_key, file_type) model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) - return JumpStartCachedS3ContentValue( - formatted_content=model_specs - ) + return JumpStartCachedS3ContentValue(formatted_content=model_specs) raise ValueError( - f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}" + self._file_type_error_msg(file_type) ) - def get_manifest(self) -> List[JumpStartModelHeader]: + def get_manifest( + self, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest.""" - manifest_dict = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + JumpStartCachedS3ContentKey( + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) )[0].formatted_content manifest = list(manifest_dict.values()) # type: ignore return manifest - def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader: + def get_header( + self, + model_id: str, + semantic_version_str: str, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + ) -> JumpStartModelHeader: """Return header for a given JumpStart model ID and semantic version. Args: @@ -358,29 +467,43 @@ def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModel header. """ - return self._get_header_impl(model_id, semantic_version_str=semantic_version_str) + return self._get_header_impl( + model_id, semantic_version_str=semantic_version_str, model_type=model_type + ) def _select_version( self, - semantic_version_str: str, - available_versions: List[Version], + model_id: str, + version_str: str, + available_versions: List[str], + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[str]: """Perform semantic version search on available versions. Args: - semantic_version_str (str): the semantic version for which to filter + version_str (str): the semantic version for which to filter available versions. available_versions (List[Version]): list of available versions. """ - if semantic_version_str == "*": + + if version_str == "*": if len(available_versions) == 0: return None return str(max(available_versions)) + if model_type == JumpStartModelType.PROPRIETARY: + if "*" in version_str: + raise KeyError( + get_wildcard_proprietary_model_version_msg( + model_id, version_str, available_versions + ) + ) + return version_str if version_str in available_versions else None + try: - spec = SpecifierSet(f"=={semantic_version_str}") + spec = SpecifierSet(f"=={version_str}") except InvalidSpecifier: - raise KeyError(f"Bad semantic version: {semantic_version_str}") + raise KeyError(f"Bad semantic version: {version_str}") available_versions_filtered = list(spec.filter(available_versions)) return ( str(max(available_versions_filtered)) if available_versions_filtered != [] else None @@ -391,6 +514,7 @@ def _get_header_impl( model_id: str, semantic_version_str: str, attempt: int = 0, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS ) -> JumpStartModelHeader: """Lower-level function to return header. @@ -402,14 +526,20 @@ def _get_header_impl( header. attempt (int): attempt number at retrieving a header. """ - - versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get( - JumpStartVersionedModelId(model_id, semantic_version_str) - )[0] + if model_type == JumpStartModelType.OPEN_WEIGHTS: + versioned_model_id = self._open_weight_model_id_manifest_key_cache.get( + JumpStartVersionedModelId(model_id, semantic_version_str) + )[0] + elif model_type == JumpStartModelType.PROPRIETARY: + versioned_model_id = self._proprietary_model_id_manifest_key_cache.get( + JumpStartVersionedModelId(model_id, semantic_version_str) + )[0] manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + JumpStartCachedS3ContentKey( + MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]) )[0].formatted_content + try: header = manifest[versioned_model_id] # type: ignore return header @@ -417,28 +547,34 @@ def _get_header_impl( if attempt > 0: raise self.clear() - return self._get_header_impl(model_id, semantic_version_str, attempt + 1) + return self._get_header_impl(model_id, semantic_version_str, attempt + 1, model_type) - def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelSpecs: + def get_specs( + self, + model_id: str, + version_str: str, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS + ) -> JumpStartModelSpecs: """Return specs for a given JumpStart model ID and semantic version. Args: model_id (str): model ID for which to get specs. semantic_version_str (str): The semantic version for which to get specs. + model_type (JumpStartModelType): The type of the model of interest. """ - - header = self.get_header(model_id, semantic_version_str) + header = self.get_header(model_id, version_str, model_type) spec_key = header.spec_key specs, cache_hit = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key) + JumpStartCachedS3ContentKey( + MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key + ) ) - if not cache_hit and "*" in semantic_version_str: + + if not cache_hit and "*" in version_str: JUMPSTART_LOGGER.warning( get_wildcard_model_version_msg( - header.model_id, - semantic_version_str, - header.version + header.model_id, version_str, header.version ) ) return specs.formatted_content @@ -446,4 +582,5 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS def clear(self) -> None: """Clears the model ID/version and s3 cache.""" self._s3_cache.clear() - self._model_id_semantic_version_manifest_key_cache.clear() + self._open_weight_model_id_manifest_key_cache.clear() + self._proprietary_model_id_manifest_key_cache.clear() diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 2e655ac285..be66e8968e 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -22,8 +22,9 @@ SerializerType, DeserializerType, MIMEType, + JumpStartModelType, ) -from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo +from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo, JumpStartS3FileType from sagemaker.base_serializers import ( BaseSerializer, CSVSerializer, @@ -169,6 +170,7 @@ JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" +JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py" @@ -188,6 +190,9 @@ SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY = "SageMakerGatedModelS3Uri" +PROPRIETARY_MODEL_SPEC_PREFIX = "proprietary-models" +PROPRIETARY_MODEL_FILTER_NAME = "marketplace" + CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP: Dict[MIMEType, SerializerType] = { MIMEType.X_IMAGE: SerializerType.RAW_BYTES, MIMEType.LIST_TEXT: SerializerType.JSON, @@ -213,6 +218,16 @@ DeserializerType.JSON: JSONDeserializer, } +MODEL_TYPE_TO_MANIFEST_MAP: Dict[Type[JumpStartModelType], Type[JumpStartS3FileType]] = { + JumpStartModelType.OPEN_WEIGHTS: JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, + JumpStartModelType.PROPRIETARY: JumpStartS3FileType.PROPRIETARY_MANIFEST, +} + +MODEL_TYPE_TO_SPECS_MAP: Dict[Type[JumpStartModelType], Type[JumpStartS3FileType]] = { + JumpStartModelType.OPEN_WEIGHTS: JumpStartS3FileType.OPEN_WEIGHT_SPECS, + JumpStartModelType.PROPRIETARY: JumpStartS3FileType.PROPRIETARY_SPECS, +} + MODEL_ID_LIST_WEB_URL = "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html" JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart") diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index e33daca046..afe0df0dfe 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -34,6 +34,19 @@ class ModelFramework(str, Enum): SKLEARN = "sklearn" +class JumpStartModelType(str, Enum): + """Enum class for JumpStart model type. + + OPEN_WEIGHTS: Publicly available models have open weights + and are onboarded and maintained by JumpStart. + PROPRIETARY: Proprietary models from third-party providers do not have open weights. + You must subscribe to proprietary models in AWS Marketplace before use. + """ + + OPEN_WEIGHTS = "open_weights" + PROPRIETARY = "proprietary" + + class VariableScope(str, Enum): """Possible value of the ``scope`` attribute for a hyperparameter or environment variable. @@ -78,6 +91,7 @@ class JumpStartTag(str, Enum): MODEL_ID = "sagemaker-sdk:jumpstart-model-id" MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" + MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" class SerializerType(str, Enum): diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 24105c4369..4dada409f5 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -35,7 +35,7 @@ from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job from sagemaker.jumpstart.utils import ( - is_valid_model_id, + validate_model_id_and_get_type, resolve_model_sagemaker_config_field, ) from sagemaker.utils import stringify_object, format_tags, Tags @@ -504,8 +504,8 @@ def __init__( ValueError: If the model ID is not recognized by JumpStart. """ - def _is_valid_model_id_hook(): - return is_valid_model_id( + def _validate_model_id_and_get_type_hook(): + return validate_model_id_and_get_type( model_id=model_id, model_version=model_version, region=region, @@ -513,14 +513,17 @@ def _is_valid_model_id_hook(): sagemaker_session=sagemaker_session, ) - if not _is_valid_model_id_hook(): + self.model_type = _validate_model_id_and_get_type_hook() + if not self.model_type: JumpStartModelsAccessor.reset_cache() - if not _is_valid_model_id_hook(): + self.model_type = _validate_model_id_and_get_type_hook() + if not self.model_type: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) estimator_init_kwargs = get_init_kwargs( model_id=model_id, model_version=model_version, + model_type=self.model_type, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, role=role, diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index c55c9081cb..742a6b8d3f 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -14,7 +14,12 @@ from __future__ import absolute_import from typing import List, Optional -from sagemaker.jumpstart.constants import MODEL_ID_LIST_WEB_URL, JumpStartScriptScope +from botocore.exceptions import ClientError + +from sagemaker.jumpstart.constants import ( + MODEL_ID_LIST_WEB_URL, + JumpStartScriptScope, +) NO_AVAILABLE_INSTANCES_ERROR_MSG = ( "No instances available in {region} that can support model ID '{model_id}'. " @@ -28,7 +33,7 @@ INVALID_MODEL_ID_ERROR_MSG = ( "Invalid model ID: '{model_id}'. Please visit " - f"{MODEL_ID_LIST_WEB_URL} for list of supported model IDs. " + f"{MODEL_ID_LIST_WEB_URL} for a list of valid model IDs. " "The module `sagemaker.jumpstart.notebook_utils` contains utilities for " "fetching model IDs. We recommend upgrading to the latest version of sagemaker " "to get access to the most models." @@ -58,6 +63,32 @@ def get_wildcard_model_version_msg( ) +def get_proprietary_model_subscription_msg( + model_id: str, + subscription_link: str, +) -> str: + """Returns customer-facing message for using a proprietary model.""" + + return ( + f"INFO: Using proprietary model '{model_id}'. " + f"To subscribe to this model in AWS Marketplace, see {subscription_link}" + ) + + +def get_wildcard_proprietary_model_version_msg( + model_id: str, wildcard_model_version: str, available_versions: List[str] +) -> str: + """Returns customer-facing message for passing wildcard version to proprietary models.""" + msg = ( + f"Proprietary model '{model_id}' does not support " + f"wildcard version identifier '{wildcard_model_version}'. " + ) + if len(available_versions) > 0: + msg += f"You can pin to version '{available_versions[0]}'. " + msg += f"{MODEL_ID_LIST_WEB_URL} for a list of valid model IDs. " + return msg + + def get_old_model_version_msg( model_id: str, current_model_version: str, latest_model_version: str ) -> str: @@ -70,6 +101,15 @@ def get_old_model_version_msg( ) +def get_proprietary_model_subscription_error(error: ClientError, subscription_link: str) -> None: + """Returns customer-facing message associated with a Marketplace subscription error.""" + + error_code = error.response["Error"]["Code"] + error_message = error.response["Error"]["Message"] + if error_code == "ValidationException" and "not subscribed" in error_message: + raise MarketplaceModelSubscriptionError(subscription_link) + + class JumpStartHyperparametersError(ValueError): """Exception raised for bad hyperparameters of a JumpStart model.""" @@ -169,3 +209,29 @@ def __init__( ) super().__init__(self.message) + + +class MarketplaceModelSubscriptionError(ValueError): + """Exception raised when trying to deploy a JumpStart Marketplace model. + + A caller is required to subscribe to the Marketplace product in order to deploy. + This exception is raised when a caller tries to deploy a JumpStart Marketplace model + but the caller is not subscribed to the model. + """ + + def __init__( + self, + model_subscription_link: Optional[str] = None, + message: Optional[str] = None, + ): + if message: + self.message = message + else: + self.message = ( + "To use a proprietary JumpStart model, " + "you must first subscribe to the model in AWS Marketplace. " + ) + if model_subscription_link: + self.message += f"To subscribe to this model, see {model_subscription_link}" + + super().__init__(self.message) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 7ccf57983b..7c20c281f5 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -50,7 +50,7 @@ TRAINING_ENTRY_POINT_SCRIPT_NAME, SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, ) -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.factory import model from sagemaker.jumpstart.types import ( JumpStartEstimatorDeployKwargs, @@ -77,6 +77,7 @@ def get_init_kwargs( model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -134,6 +135,7 @@ def get_init_kwargs( estimator_init_kwargs: JumpStartEstimatorInitKwargs = JumpStartEstimatorInitKwargs( model_id=model_id, model_version=model_version, + model_type=model_type, role=role, region=region, instance_count=instance_count, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 1b41cad714..c12a14c4a4 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -37,7 +37,7 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.types import ( JumpStartModelDeployKwargs, JumpStartModelInitKwargs, @@ -71,6 +71,7 @@ def get_default_predictor( tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, sagemaker_session: Session, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. @@ -92,6 +93,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.deserializer = deserializers.retrieve_default( model_id=model_id, @@ -100,6 +102,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.accept = accept_types.retrieve_default( model_id=model_id, @@ -108,6 +111,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) predictor.content_type = content_types.retrieve_default( model_id=model_id, @@ -116,6 +120,7 @@ def get_default_predictor( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) return predictor @@ -187,6 +192,7 @@ def _add_instance_type_to_kwargs( tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, training_instance_type=kwargs.training_instance_type, + model_type=kwargs.model_type, ) if not disable_instance_type_logging and orig_instance_type is None: @@ -199,7 +205,14 @@ def _add_instance_type_to_kwargs( def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: - """Sets image uri based on default or override, returns full kwargs.""" + """Sets image uri based on default or override, returns full kwargs. + + Uses placeholder image uri for JumpStart proprietary models that uses ModelPackages + """ + + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.image_uri = None + return kwargs kwargs.image_uri = kwargs.image_uri or image_uris.retrieve( region=kwargs.region, @@ -219,6 +232,10 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets model data based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.model_data = None + return kwargs + model_data: Union[str, dict] = kwargs.model_data or model_uris.retrieve( model_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, @@ -255,6 +272,10 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets source dir based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.source_dir = None + return kwargs + source_dir = kwargs.source_dir if _model_supports_inference_script_uri( @@ -283,6 +304,10 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets entry point based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.entry_point = None + return kwargs + entry_point = kwargs.entry_point if _model_supports_inference_script_uri( @@ -304,6 +329,10 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets env based on default or override, returns full kwargs.""" + if kwargs.model_type == JumpStartModelType.PROPRIETARY: + kwargs.env = None + return kwargs + env = kwargs.env if env is None: @@ -348,6 +377,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) kwargs.model_package_arn = model_package_arn @@ -364,6 +394,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) for key, value in model_kwargs_to_add.items(): @@ -399,6 +430,7 @@ def _add_endpoint_name_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) kwargs.endpoint_name = kwargs.endpoint_name or ( @@ -420,6 +452,7 @@ def _add_model_name_to_kwargs( tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) kwargs.name = kwargs.name or ( @@ -440,11 +473,12 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ).version if kwargs.sagemaker_session.settings.include_jumpstart_tags: kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version + kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type ) return kwargs @@ -461,6 +495,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, ) for key, value in deploy_kwargs_to_add.items(): @@ -481,6 +516,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, instance_type=kwargs.instance_type, ) @@ -490,6 +526,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -522,6 +559,7 @@ def get_deploy_kwargs( deploy_kwargs: JumpStartModelDeployKwargs = JumpStartModelDeployKwargs( model_id=model_id, model_version=model_version, + model_type=model_type, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, @@ -655,6 +693,7 @@ def get_init_kwargs( model_id: str, model_from_estimator: bool = False, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, instance_type: Optional[str] = None, @@ -686,6 +725,7 @@ def get_init_kwargs( model_init_kwargs: JumpStartModelInitKwargs = JumpStartModelInitKwargs( model_id=model_id, model_version=model_version, + model_type=model_type, instance_type=instance_type, region=region, image_uri=image_uri, @@ -730,14 +770,12 @@ def get_init_kwargs( # we use the model artifact from the training job output if not model_from_estimator: model_init_kwargs = _add_model_data_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_source_dir_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_entry_point_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_env_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_predictor_cls_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_extra_model_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_role_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs) diff --git a/src/sagemaker/jumpstart/filters.py b/src/sagemaker/jumpstart/filters.py index b045435ed0..fc5113315d 100644 --- a/src/sagemaker/jumpstart/filters.py +++ b/src/sagemaker/jumpstart/filters.py @@ -49,6 +49,14 @@ class SpecialSupportedFilterKeys(str, Enum): TASK = "task" FRAMEWORK = "framework" + MODEL_TYPE = "model_type" + + +class ProprietaryModelFilterIdentifiers(str, Enum): + """Enum class for proprietary model filter keys.""" + + PROPRIETARY = "proprietary" + MARKETPLACE = "marketplace" FILTER_OPERATOR_STRING_MAPPINGS = { @@ -429,6 +437,22 @@ def __init__(self, key: str, value: str, operator: str): self.value = value self.operator = operator + def set_key(self, key: str) -> None: + """Sets the key for the model filter. + + Args: + key (str): The key to be set. + """ + self.key = key + + def set_value(self, value: str) -> None: + """Sets the value for the model filter. + + Args: + value (str): The value to be set. + """ + self.value = value + def parse_filter_string(filter_string: str) -> ModelFilter: """Parse filter string and return a serialized ``ModelFilter`` object. diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 1742f860e4..99c09b2375 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -15,14 +15,21 @@ from __future__ import absolute_import from typing import Dict, List, Optional, Union +from botocore.exceptions import ClientError + from sagemaker import payloads from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer from sagemaker.base_serializers import BaseSerializer +from sagemaker.enums import EndpointType from sagemaker.explainer.explainer_config import ExplainerConfig from sagemaker.jumpstart.accessors import JumpStartModelsAccessor from sagemaker.jumpstart.enums import JumpStartScriptScope -from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG +from sagemaker.jumpstart.exceptions import ( + INVALID_MODEL_ID_ERROR_MSG, + get_proprietary_model_subscription_error, + get_proprietary_model_subscription_msg, +) from sagemaker.jumpstart.factory.model import ( get_default_predictor, get_deploy_kwargs, @@ -30,7 +37,12 @@ get_register_kwargs, ) from sagemaker.jumpstart.types import JumpStartSerializablePayload -from sagemaker.jumpstart.utils import is_valid_model_id +from sagemaker.jumpstart.utils import ( + validate_model_id_and_get_type, + verify_model_region_and_return_specs, +) +from sagemaker.jumpstart.constants import JUMPSTART_LOGGER +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model import ( Model, @@ -45,7 +57,6 @@ from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements -from sagemaker.enums import EndpointType class JumpStartModel(Model): @@ -270,8 +281,8 @@ def __init__( ValueError: If the model ID is not recognized by JumpStart. """ - def _is_valid_model_id_hook(): - return is_valid_model_id( + def _validate_model_id_and_type(): + return validate_model_id_and_get_type( model_id=model_id, model_version=model_version, region=region, @@ -279,16 +290,18 @@ def _is_valid_model_id_hook(): sagemaker_session=sagemaker_session, ) - if not _is_valid_model_id_hook(): + self.model_type = _validate_model_id_and_type() + if not self.model_type: JumpStartModelsAccessor.reset_cache() - if not _is_valid_model_id_hook(): + self.model_type = _validate_model_id_and_type() + if not self.model_type: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) self._model_data_is_set = model_data is not None - model_init_kwargs = get_init_kwargs( model_id=model_id, model_from_estimator=False, + model_type=self.model_type, model_version=model_version, instance_type=instance_type, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -326,10 +339,27 @@ def _is_valid_model_id_hook(): self.region = model_init_kwargs.region self.sagemaker_session = model_init_kwargs.sagemaker_session + if self.model_type == JumpStartModelType.PROPRIETARY: + self.log_subscription_warning() + super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) self.model_package_arn = model_init_kwargs.model_package_arn + def log_subscription_warning(self) -> None: + """Log message prompting the customer to subscribe to the proprietary model.""" + subscription_link = verify_model_region_and_return_specs( + region=self.region, + model_id=self.model_id, + version=self.model_version, + model_type=self.model_type, + scope=JumpStartScriptScope.INFERENCE, + sagemaker_session=self.sagemaker_session, + ).model_subscription_link + JUMPSTART_LOGGER.warning( + get_proprietary_model_subscription_msg(self.model_id, subscription_link) + ) + def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: """Returns all example payloads associated with the model. @@ -347,6 +377,7 @@ def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, + model_type=self.model_type, ) def retrieve_example_payload(self) -> JumpStartSerializablePayload: @@ -364,6 +395,7 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: return payloads.retrieve_example( model_id=self.model_id, model_version=self.model_version, + model_type=self.model_type, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, @@ -558,6 +590,9 @@ def deploy( endpoint. endpoint_type (EndpointType): The type of endpoint used to deploy models. (Default: EndpointType.MODEL_BASED). + + Raises: + MarketplaceModelSubscriptionError: If the caller is not subscribed to the model. """ deploy_kwargs = get_deploy_kwargs( @@ -589,9 +624,29 @@ def deploy( resources=resources, managed_instance_scaling=managed_instance_scaling, endpoint_type=endpoint_type, + model_type=self.model_type, ) + if ( + self.model_type == JumpStartModelType.PROPRIETARY + and endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED + ): + raise ValueError( + f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models." + ) - predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict()) + try: + predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict()) + except ClientError as e: + subscription_link = verify_model_region_and_return_specs( + region=self.region, + model_id=self.model_id, + version=self.model_version, + model_type=self.model_type, + scope=JumpStartScriptScope.INFERENCE, + sagemaker_session=self.sagemaker_session, + ).model_subscription_link + get_proprietary_model_subscription_error(e, subscription_link) + raise # If no predictor class was passed, add defaults to predictor if self.orig_predictor_cls is None and async_inference_config is None: @@ -603,6 +658,7 @@ def deploy( tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, + model_type=self.model_type, ) # If a predictor class was passed, do not mutate predictor diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 1554025995..485354e802 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -24,10 +24,12 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, + PROPRIETARY_MODEL_SPEC_PREFIX, ) -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.filters import ( SPECIAL_SUPPORTED_FILTER_KEYS, + ProprietaryModelFilterIdentifiers, BooleanValues, Identity, SpecialSupportedFilterKeys, @@ -38,6 +40,7 @@ get_jumpstart_content_bucket, get_sagemaker_version, verify_model_region_and_return_specs, + validate_model_id_and_get_type, ) from sagemaker.session import Session @@ -124,15 +127,11 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]: Args: model_id (str): The model ID for which to extract the framework/task/model. - - Raises: - ValueError: If the model ID cannot be parsed into at least 3 components seperated by - "-" character. """ _id_parts = model_id.split("-") if len(_id_parts) < 3: - raise ValueError(f"incorrect model ID: {model_id}.") + return "", "", "" framework = _id_parts[0] task = _id_parts[1] @@ -141,6 +140,20 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]: return framework, task, name +def extract_model_type_filter_representation(spec_key: str) -> str: + """Parses model spec key, determine if the model is proprietary or open weight. + + Args: + spek_key (str): The model spec key for which to extract the model type. + """ + model_spec_prefix = spec_key.split("/")[0] + + if model_spec_prefix == PROPRIETARY_MODEL_SPEC_PREFIX: + return JumpStartModelType.PROPRIETARY.value + + return JumpStartModelType.OPEN_WEIGHTS.value + + def list_jumpstart_tasks( # pylint: disable=redefined-builtin filter: Union[Operator, str] = Constant(BooleanValues.TRUE), region: str = JUMPSTART_DEFAULT_REGION_NAME, @@ -321,14 +334,22 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). """ - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( - region=region, s3_client=sagemaker_session.s3_client + prop_models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, + s3_client=sagemaker_session.s3_client, + model_type=JumpStartModelType.PROPRIETARY, ) + open_weight_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, + s3_client=sagemaker_session.s3_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, + ) + models_manifest_list = open_weight_manifest_list + prop_models_manifest_list if isinstance(filter, str): filter = Identity(filter) - manifest_keys = set(models_manifest_list[0].__slots__) + manifest_keys = set(models_manifest_list[0].__slots__ + prop_models_manifest_list[0].__slots__) all_keys: Set[str] = set() @@ -338,6 +359,10 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin model_filter = operator.unresolved_value key = model_filter.key all_keys.add(key) + if model_filter.key == SpecialSupportedFilterKeys.MODEL_TYPE and model_filter.value in { + identifier.value for identifier in ProprietaryModelFilterIdentifiers + }: + model_filter.set_value(JumpStartModelType.PROPRIETARY.value) model_filters.add(model_filter) for key in all_keys: @@ -351,6 +376,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys + is_model_type_filter = SpecialSupportedFilterKeys.MODEL_TYPE in all_keys def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, str]]: @@ -373,6 +399,11 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, SpecialSupportedFilterKeys.FRAMEWORK ] = extract_framework_task_model(model_manifest.model_id)[0] + if is_model_type_filter: + manifest_specs_cached_values[ + SpecialSupportedFilterKeys.MODEL_TYPE + ] = extract_model_type_filter_representation(model_manifest.spec_key) + if Version(model_manifest.min_version) > Version(get_sagemaker_version()): return None @@ -466,6 +497,12 @@ def get_model_url( sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use to retrieve the model url. """ + model_type = validate_model_id_and_get_type( + model_id=model_id, + model_version=model_version, + region=region, + sagemaker_session=sagemaker_session, + ) model_specs = verify_model_region_and_return_specs( region=region, @@ -473,5 +510,6 @@ def get_model_url( version=model_version, sagemaker_session=sagemaker_session, scope=JumpStartScriptScope.INFERENCE, + model_type=model_type, ) return model_specs.url diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 810d1c4cd3..9af44779f3 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -19,6 +19,7 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session from sagemaker.workflow.entities import PipelineVariable @@ -102,8 +103,10 @@ def __repr__(self) -> str: class JumpStartS3FileType(str, Enum): """Type of files published in JumpStart S3 distribution buckets.""" - MANIFEST = "manifest" - SPECS = "specs" + OPEN_WEIGHT_MANIFEST = "manifest" + OPEN_WEIGHT_SPECS = "specs" + PROPRIETARY_MANIFEST = "proptietary_manifest" + PROPRIETARY_SPECS = "proprietary_specs" class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): @@ -788,6 +791,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "training_instance_type_variants", "default_payloads", "gated_bucket", + "model_subscription_link", ] def __init__(self, spec: Dict[str, Any]): @@ -805,29 +809,31 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of spec. """ self.model_id: str = json_obj["model_id"] - self.url: str = json_obj["url"] + self.url: str = json_obj.get("url", "") self.version: str = json_obj["version"] self.min_sdk_version: str = json_obj["min_sdk_version"] - self.incremental_training_supported: bool = bool(json_obj["incremental_training_supported"]) + self.incremental_training_supported: bool = bool( + json_obj.get("incremental_training_supported", False) + ) self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) if "hosting_ecr_specs" in json_obj else None ) - self.hosting_artifact_key: str = json_obj["hosting_artifact_key"] - self.hosting_script_key: str = json_obj["hosting_script_key"] - self.training_supported: bool = bool(json_obj["training_supported"]) + self.hosting_artifact_key: Optional[str] = json_obj.get("hosting_artifact_key") + self.hosting_script_key: Optional[str] = json_obj.get("hosting_script_key") + self.training_supported: Optional[bool] = bool(json_obj.get("training_supported", False)) self.inference_environment_variables = [ JumpStartEnvironmentVariable(env_variable) - for env_variable in json_obj["inference_environment_variables"] + for env_variable in json_obj.get("inference_environment_variables", []) ] - self.inference_vulnerable: bool = bool(json_obj["inference_vulnerable"]) - self.inference_dependencies: List[str] = json_obj["inference_dependencies"] - self.inference_vulnerabilities: List[str] = json_obj["inference_vulnerabilities"] - self.training_vulnerable: bool = bool(json_obj["training_vulnerable"]) - self.training_dependencies: List[str] = json_obj["training_dependencies"] - self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"] - self.deprecated: bool = bool(json_obj["deprecated"]) + self.inference_vulnerable: bool = bool(json_obj.get("inference_vulnerable", False)) + self.inference_dependencies: List[str] = json_obj.get("inference_dependencies", []) + self.inference_vulnerabilities: List[str] = json_obj.get("inference_vulnerabilities", []) + self.training_vulnerable: bool = bool(json_obj.get("training_vulnerable", False)) + self.training_dependencies: List[str] = json_obj.get("training_dependencies", []) + self.training_vulnerabilities: List[str] = json_obj.get("training_vulnerabilities", []) + self.deprecated: bool = bool(json_obj.get("deprecated", False)) self.deprecated_message: Optional[str] = json_obj.get("deprecated_message") self.deprecate_warn_message: Optional[str] = json_obj.get("deprecate_warn_message") self.usage_info_message: Optional[str] = json_obj.get("usage_info_message") @@ -917,6 +923,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if json_obj.get("training_instance_type_variants") else None ) + self.model_subscription_link = json_obj.get("model_subscription_link") def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartModelSpecs object.""" @@ -1049,6 +1056,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "model_type", "instance_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", @@ -1079,6 +1087,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "instance_type", "model_id", "model_version", + "model_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", "region", @@ -1090,6 +1099,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, instance_type: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, @@ -1119,6 +1129,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.model_type = model_type self.instance_type = instance_type self.region = region self.image_uri = image_uri @@ -1151,6 +1162,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "model_type", "initial_instance_count", "instance_type", "region", @@ -1182,6 +1194,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): SERIALIZATION_EXCLUSION_SET = { "model_id", "model_version", + "model_type", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1193,6 +1206,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -1224,6 +1238,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.model_type = model_type self.initial_instance_count = initial_instance_count self.instance_type = instance_type self.region = region @@ -1258,6 +1273,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "model_type", "instance_type", "instance_count", "region", @@ -1317,12 +1333,14 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "model_id", "model_version", + "model_type", } def __init__( self, model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, role: Optional[str] = None, @@ -1379,6 +1397,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.model_type = (model_type,) self.instance_type = instance_type self.instance_count = instance_count self.region = region @@ -1440,6 +1459,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "model_type", "region", "inputs", "wait", @@ -1454,6 +1474,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): SERIALIZATION_EXCLUSION_SET = { "model_id", "model_version", + "model_type", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1464,6 +1485,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, inputs: Optional[Union[str, Dict, Any, Any]] = None, wait: Optional[bool] = None, @@ -1478,6 +1500,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.model_type = model_type self.region = region self.inputs = inputs self.wait = wait diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 2621422811..71a8067a6f 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging import os -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Set, Optional, Tuple, Union from urllib.parse import urlparse import boto3 from packaging.version import Version @@ -26,6 +26,7 @@ TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, TRAINING_JOB_ROLE_ARN_PATH, ) + from sagemaker.jumpstart import constants, enums from sagemaker.jumpstart import accessors from sagemaker.s3 import parse_s3_url @@ -314,6 +315,7 @@ def add_single_jumpstart_tag( ( tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags) + or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags) ) if is_uri else False @@ -348,6 +350,7 @@ def add_jumpstart_model_id_version_tags( tags: Optional[List[TagsDict]], model_id: str, model_version: str, + model_type: Optional[enums.JumpStartModelType] = None, ) -> List[TagsDict]: """Add custom model ID and version tags to JumpStart related resources.""" if model_id is None or model_version is None: @@ -364,6 +367,13 @@ def add_jumpstart_model_id_version_tags( tags, is_uri=False, ) + if model_type == enums.JumpStartModelType.PROPRIETARY: + tags = add_single_jumpstart_tag( + enums.JumpStartModelType.PROPRIETARY.value, + enums.JumpStartTag.MODEL_TYPE, + tags, + is_uri=False, + ) return tags @@ -530,6 +540,7 @@ def verify_model_region_and_return_specs( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. @@ -578,6 +589,7 @@ def verify_model_region_and_return_specs( model_id=model_id, version=version, s3_client=sagemaker_session.s3_client, + model_type=model_type, ) if ( @@ -732,36 +744,52 @@ def resolve_estimator_sagemaker_config_field( return field_val -def is_valid_model_id( +def validate_model_id_and_get_type( model_id: Optional[str], region: Optional[str] = None, model_version: Optional[str] = None, script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> bool: - """Returns True if the model ID is supported for the given script. +) -> Optional[enums.JumpStartModelType]: + """Returns model type if the model ID is supported for the given script. Raises: ValueError: If the script is not supported by JumpStart. """ + + def _get_model_type( + model_id: str, + open_weights_model_ids: Set[str], + proprietary_model_ids: Set[str], + script: enums.JumpStartScriptScope, + ) -> Optional[enums.JumpStartModelType]: + if model_id in open_weights_model_ids: + return enums.JumpStartModelType.OPEN_WEIGHTS + if model_id in proprietary_model_ids: + if script == enums.JumpStartScriptScope.INFERENCE: + return enums.JumpStartModelType.PROPRIETARY + raise ValueError(f"Unsupported script for Marketplace models: {script}") + return None + if model_id in {None, ""}: - return False + return None if not isinstance(model_id, str): - return False + return None s3_client = sagemaker_session.s3_client if sagemaker_session else None region = region or constants.JUMPSTART_DEFAULT_REGION_NAME model_version = model_version or "*" - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( - region=region, s3_client=s3_client + region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.OPEN_WEIGHTS + ) + open_weight_model_id_set = {model.model_id for model in models_manifest_list} + + proprietary_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.PROPRIETARY ) - model_id_set = {model.model_id for model in models_manifest_list} - if script == enums.JumpStartScriptScope.INFERENCE: - return model_id in model_id_set - if script == enums.JumpStartScriptScope.TRAINING: - return model_id in model_id_set - raise ValueError(f"Unsupported script: {script}") + + proprietary_model_id_set = {model.model_id for model in proprietary_manifest_list} + return _get_model_type(model_id, open_weight_model_id_set, proprietary_model_id_set, script) def get_jumpstart_model_id_version_from_resource_arn( diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 5a2b27c54d..b26d201247 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -138,7 +138,7 @@ class Model(ModelBase, InferenceRecommenderMixin): def __init__( self, - image_uri: Union[str, PipelineVariable], + image_uri: Optional[Union[str, PipelineVariable]] = None, model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, predictor_cls: Optional[callable] = None, diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index 52d633ed4e..de33f61b82 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -21,6 +21,7 @@ from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.payload_utils import PayloadSerializer from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -31,6 +32,7 @@ def retrieve_all_examples( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -83,6 +85,7 @@ def retrieve_all_examples( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if unserialized_payload_dict is None: @@ -120,6 +123,7 @@ def retrieve_example( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -133,6 +137,8 @@ def retrieve_example( the model payload. model_version (str): The version of the JumpStart model for which to retrieve the model payload. + model_type (str): The model type of the JumpStart model, either is open weight + or proprietary. serialize (bool): Whether to serialize byte-stream valued payloads by downloading binary files from s3 and applying encoding, or to keep payload in pre-serialized state. Set this option to False if you want to avoid s3 downloads or if you @@ -162,6 +168,7 @@ def retrieve_example( region=region, model_id=model_id, model_version=model_version, + model_type=model_type, serialize=serialize, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 42c2af0917..6f846bba65 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -15,6 +15,7 @@ from typing import Optional from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint @@ -41,6 +42,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Predictor: """Retrieves the default predictor for the model matching the given arguments. @@ -107,4 +109,5 @@ def retrieve_default( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index 93b2833a35..df14ac558f 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -21,6 +21,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session LOGGER = logging.getLogger("sagemaker") @@ -33,6 +34,7 @@ def retrieve_default( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, ) -> ResourceRequirements: @@ -82,6 +84,7 @@ def retrieve_default( region, tolerate_vulnerable_model, tolerate_deprecated_model, + model_type=model_type, sagemaker_session=sagemaker_session, instance_type=instance_type, ) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index fc76c0fa76..aefb52bd97 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -34,6 +34,7 @@ from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -93,6 +94,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> BaseSerializer: """Retrieves the default serializer for the model matching the given arguments. @@ -134,4 +136,5 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 24050807cc..5205765e2f 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -256,9 +256,33 @@ def test_jumpstart_model_register(setup): predictor = model_package.deploy( instance_type="ml.p3.2xlarge", initial_instance_count=1, - tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], ) response = predictor.predict("hello world!") assert response is not None + + +@pytest.mark.skipif( + True, + reason="Only enable if test account is subscribed to the proprietary model", +) +def test_proprietary_jumpstart_model(setup): + + model_id = "ai21-jurassic-2-light" + + model = JumpStartModel( + model_id=model_id, + model_version="2.0.004", + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + predictor = model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}] + ) + payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1} + + response = predictor.predict(payload) + + assert response is not None diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index 28211d06f1..49c18beec2 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -17,21 +17,27 @@ from sagemaker import accept_types from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec + mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_accept_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -45,18 +51,26 @@ def test_jumpstart_default_accept_types( assert default_accept_type == "application/json" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_supported_accept_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -73,5 +87,9 @@ def test_jumpstart_supported_accept_types( ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index 4b2db7d7f4..7765d6eaad 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -17,6 +17,7 @@ from sagemaker import content_types from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec @@ -24,14 +25,18 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_content_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -45,18 +50,26 @@ def test_jumpstart_default_content_types( assert default_content_type == "application/x-text" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_supported_content_types( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -72,5 +85,9 @@ def test_jumpstart_supported_content_types( ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index 9d6e2f21de..5328533da5 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -18,6 +18,7 @@ from sagemaker import base_deserializers, deserializers from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec @@ -26,14 +27,18 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_deserializers( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -47,18 +52,26 @@ def test_jumpstart_default_deserializers( assert isinstance(default_deserializer, base_deserializers.JSONDeserializer) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_deserializer_options( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -79,5 +92,9 @@ def test_jumpstart_deserializer_options( ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index acd8d19923..38cc5ebbf3 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -18,17 +18,23 @@ import pytest from sagemaker import environment_variables +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec + mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_default_environment_variables(patched_get_model_specs): +def test_jumpstart_default_environment_variables( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -48,7 +54,11 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -68,7 +78,11 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -98,10 +112,14 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_sdk_environment_variables(patched_get_model_specs): +def test_jumpstart_sdk_environment_variables( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -122,7 +140,11 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -143,7 +165,11 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index eebc079164..a13fba87ae 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -18,6 +18,7 @@ import pytest from sagemaker import hyperparameters +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec @@ -26,10 +27,14 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_default_hyperparameters(patched_get_model_specs): +def test_jumpstart_default_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "pytorch-eqa-bert-base-cased" region = "us-west-2" @@ -47,6 +52,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): model_id=model_id, version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -64,6 +70,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): model_id=model_id, version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -89,6 +96,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): model_id=model_id, version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 0054ed9dbd..7a5df4ac93 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -17,7 +17,7 @@ import pytest import boto3 from sagemaker import hyperparameters -from sagemaker.jumpstart.enums import HyperparameterValidationMode +from sagemaker.jumpstart.enums import HyperparameterValidationMode, JumpStartModelType from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError from sagemaker.jumpstart.types import JumpStartHyperparameter @@ -27,8 +27,11 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_validate_provided_hyperparameters(patched_get_model_specs): +def test_jumpstart_validate_provided_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): def add_options_to_hyperparameter(*largs, **kwargs): spec = get_spec_from_base_spec(*largs, **kwargs) spec.hyperparameters.extend( @@ -109,6 +112,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): return spec patched_get_model_specs.side_effect = add_options_to_hyperparameter + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -140,6 +144,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -398,8 +403,11 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_validate_algorithm_hyperparameters(patched_get_model_specs): +def test_jumpstart_validate_algorithm_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): def add_options_to_hyperparameter(*largs, **kwargs): spec = get_spec_from_base_spec(*largs, **kwargs) spec.hyperparameters.append( @@ -416,6 +424,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): return spec patched_get_model_specs.side_effect = add_options_to_hyperparameter + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -437,7 +446,11 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -464,10 +477,14 @@ def add_options_to_hyperparameter(*largs, **kwargs): assert str(e.value) == "Cannot find algorithm hyperparameter for 'adam-learning-rate'." +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): +def test_jumpstart_validate_all_hyperparameters( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "pytorch-eqa-bert-base-cased", "*" region = "us-west-2" @@ -491,7 +508,11 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 8a41891280..6c80c97f33 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -18,19 +18,24 @@ from sagemaker import image_uris from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.image_uris.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_common_image_uri( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -49,6 +54,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -69,6 +75,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -89,6 +96,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -109,6 +117,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index bed2e50674..982c7f1702 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -18,14 +18,17 @@ import pytest from sagemaker import instance_types +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_instance_types(patched_get_model_specs): +def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_model_id_and_get_type): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "huggingface-eqa-bert-base-cased", "*" region = "us-west-2" @@ -47,6 +50,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -65,6 +69,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -89,6 +94,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -111,7 +117,11 @@ def test_jumpstart_instance_types(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 605253466a..ce8cc4ddfa 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -6113,6 +6113,7 @@ "deprecated_message": None, "hosting_model_package_arns": None, "hosting_eula_key": None, + "model_subscription_link": None, "hyperparameters": [ { "name": "epochs", @@ -6309,3 +6310,83 @@ "imagenet-inception-v3-classification-4/specs_v3.0.0.json", }, ] + +BASE_PROPRIETARY_HEADER = { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], +} + +BASE_PROPRIETARY_MANIFEST = [ + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "lighton-mini-instruct40b", + "version": "v1.0", + "min_version": "2.0.0", + "spec_key": "proprietary-models/lighton-mini-instruct40b/proprietary_specs_v1.0.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "ai21-paraphrase", + "version": "1.0.005", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", + "search_keywords": ["Text2Text", "Generation"], + }, +] + +BASE_PROPRIETARY_SPEC = { + "model_id": "ai21-jurassic-2-light", + "version": "2.0.004", + "min_sdk_version": "2.999.0", + "listing_id": "prodview-roz6zicyvi666", + "product_id": "1bd680a0-f29b-479d-91c3-9899743021cf", + "model_subscription_link": "https://aws.amazon.com/marketplace/ai/procurement?productId=1bd680a0", + "hosting_notebook_key": "pmm-notebooks/pmm-notebook-ai21-jurassic-2-light.ipynb", + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 600, + }, + "default_payloads": { + "Shakespeare": { + "content_type": "application/json", + "prompt_key": "prompt", + "output_keys": {"generated_text": "[0].completions[0].data.text"}, + "body": {"prompt": "To be, or", "maxTokens": 1, "temperature": 0}, + } + }, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge"], + "hosting_model_package_arns": { + "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/j2-light-v2-0-004", + "us-east-2": "arn:aws:sagemaker:us-east-2:057799348421:model-package/j2-light-v2-0-004", + "us-west-1": "arn:aws:sagemaker:us-west-1:382657785993:model-package/j2-light-v2-0-004", + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/j2-light-v2-0-004", + "ca-central-1": "arn:aws:sagemaker:ca-central-1:470592106596:model-package/j2-light-v2-0-004", + "eu-central-1": "arn:aws:sagemaker:eu-central-1:446921602837:model-package/j2-light-v2-0-004", + "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/j2-light-v2-0-004", + "eu-west-2": "arn:aws:sagemaker:eu-west-2:856760150666:model-package/j2-light-v2-0-004", + "eu-west-3": "arn:aws:sagemaker:eu-west-3:843114510376:model-package/j2-light-v2-0-004", + "eu-north-1": "arn:aws:sagemaker:eu-north-1:136758871317:model-package/j2-light-v2-0-004", + "ap-southeast-1": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/j2-light-v2-0-004", + "ap-southeast-2": "arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/j2-light-v2-0-004", + "ap-northeast-2": "arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/j2-light-v2-0-004", + "ap-northeast-1": "arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/j2-light-v2-0-004", + "ap-south-1": "arn:aws:sagemaker:ap-south-1:077584701553:model-package/j2-light-v2-0-004", + "sa-east-1": "arn:aws:sagemaker:sa-east-1:270155090741:model-package/j2-light-v2-0-004", + }, +} diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 4dc35b65ca..4fa18f31aa 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -31,7 +31,7 @@ _retrieve_default_training_metric_definitions, ) from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType from sagemaker.jumpstart.estimator import JumpStartEstimator @@ -60,9 +60,10 @@ class EstimatorTest(unittest.TestCase): @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -75,14 +76,15 @@ def test_non_prepacked( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, + mock_get_model_type: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_jumpstart_model_factory_logger: mock.Mock, mock_jumpstart_estimator_factory_logger: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "9876" @@ -92,6 +94,8 @@ def test_non_prepacked( mock_get_model_specs.side_effect = get_special_model_spec + mock_get_model_type.return_value = JumpStartModelType.OPEN_WEIGHTS + mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session @@ -182,7 +186,7 @@ def test_non_prepacked( ], ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -199,11 +203,11 @@ def test_prepacked( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -282,7 +286,7 @@ def test_prepacked( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -299,14 +303,14 @@ def test_gated_model_s3_uri( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_timestamp.return_value = "8675309" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-gated-artifact-trainable-model", "*" @@ -418,7 +422,7 @@ def test_gated_model_s3_uri( "sagemaker.jumpstart.artifacts.environment_variables.get_jumpstart_gated_content_bucket" ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -435,7 +439,7 @@ def test_gated_model_non_model_package_s3_uri( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, mock_get_jumpstart_gated_content_bucket: mock.Mock, ): @@ -444,7 +448,7 @@ def test_gated_model_non_model_package_s3_uri( mock_get_jumpstart_gated_content_bucket.return_value = "top-secret-private-models-bucket" mock_timestamp.return_value = "8675309" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-gated-artifact-non-model-package-trainable-model", "*" @@ -566,7 +570,7 @@ def test_gated_model_non_model_package_s3_uri( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -583,15 +587,13 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_timestamp.return_value = "8675309" - mock_is_valid_model_id.return_value = True - model_id, _ = "js-gated-artifact-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -599,6 +601,8 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + with pytest.raises(ValueError) as e: JumpStartEstimator(model_id=model_id, region="eu-north-1") @@ -608,7 +612,7 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( "us-west-2, us-east-1, eu-west-1, ap-southeast-1." ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -621,10 +625,10 @@ def test_deprecated( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "deprecated_model", "*" @@ -642,7 +646,7 @@ def test_deprecated( JumpStartEstimator(model_id=model_id, tolerate_deprecated_model=True).fit(channels).deploy() - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -655,9 +659,9 @@ def test_vulnerable( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "vulnerable_model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -758,7 +762,7 @@ def test_estimator_use_kwargs(self): @mock.patch("sagemaker.jumpstart.factory.estimator.metric_definitions_utils.retrieve_default") @mock.patch("sagemaker.jumpstart.factory.estimator.environment_variables.retrieve_default") @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -775,7 +779,7 @@ def evaluate_estimator_workflow_with_kwargs( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_timestamp: mock.Mock, mock_retrieve_default_environment_variables: mock.Mock, mock_retrieve_metric_definitions: mock.Mock, @@ -806,7 +810,7 @@ def evaluate_estimator_workflow_with_kwargs( mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "js-trainable-model", "*" @@ -908,16 +912,16 @@ def evaluate_estimator_workflow_with_kwargs( mock_estimator_deploy.assert_called_once_with(**expected_deploy_kwargs) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_tags_disabled( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -947,16 +951,16 @@ def test_jumpstart_estimator_tags_disabled( [{"Key": "blah", "Value": "blahagain"}], ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_tags( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -989,18 +993,18 @@ def test_jumpstart_estimator_tags( @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_attach_no_model_id_happy_case( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, get_model_id_version_from_training_job: mock.Mock, mock_attach: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS get_model_id_version_from_training_job.return_value = ( "js-trainable-model-prepacked", @@ -1032,18 +1036,18 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_estimator_attach_no_model_id_sad_case( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, get_model_id_version_from_training_job: mock.Mock, mock_attach: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS get_model_id_version_from_training_job.side_effect = ValueError() @@ -1115,22 +1119,22 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): @mock.patch("sagemaker.jumpstart.estimator.get_init_kwargs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") - def test_is_valid_model_id( + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") + def test_validate_model_id_and_get_type( self, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS JumpStartEstimator(model_id="valid_model_id") - mock_is_valid_model_id.return_value = False + mock_validate_model_id_and_get_type.return_value = False with pytest.raises(ValueError): JumpStartEstimator(model_id="invalid_model_id") @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1147,14 +1151,14 @@ def test_no_predictor_returns_default_predictor( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_get_default_predictor.return_value = default_predictor_with_presets - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1189,7 +1193,7 @@ def test_no_predictor_returns_default_predictor( self.assertEqual(predictor, default_predictor_with_presets) @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1206,14 +1210,14 @@ def test_no_predictor_yes_async_inference_config( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_get_default_predictor.return_value = default_predictor_with_presets - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1239,7 +1243,7 @@ def test_no_predictor_yes_async_inference_config( self.assertEqual(type(predictor), Predictor) @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1256,14 +1260,14 @@ def test_yes_predictor_returns_unmodified_predictor( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor mock_get_default_predictor.return_value = default_predictor_with_presets - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model-prepacked", "*" @@ -1289,7 +1293,7 @@ def test_yes_predictor_returns_unmodified_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -1310,9 +1314,9 @@ def test_incremental_training_with_unsupported_model_logs_warning( mock_session_model: mock.Mock, mock_logger_warning: mock.Mock, mock_supports_incremental_training: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_estimator_deploy.return_value = default_predictor @@ -1343,7 +1347,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( sagemaker_session=sagemaker_session, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -1364,9 +1368,9 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( mock_session_model: mock.Mock, mock_logger_warning: mock.Mock, mock_supports_incremental_training: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_estimator_deploy.return_value = default_predictor @@ -1395,7 +1399,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1412,10 +1416,10 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "3456" @@ -1456,7 +1460,7 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.p4de.24xlarge" @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1473,10 +1477,10 @@ def test_training_passes_role_to_deploy( mock_get_model_specs: mock.Mock, mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "3456" @@ -1533,7 +1537,7 @@ def test_training_passes_role_to_deploy( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch( "sagemaker.jumpstart.factory.model.DEFAULT_JUMPSTART_SAGEMAKER_SESSION", sagemaker_session ) @@ -1553,10 +1557,10 @@ def test_training_passes_session_to_deploy( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_sagemaker_timestamp.return_value = "3456" @@ -1611,7 +1615,7 @@ def test_training_passes_session_to_deploy( ], ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_init_kwargs") @@ -1627,11 +1631,11 @@ def test_model_id_not_found_refeshes_cache_training( mock_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.side_effect = [False, False] + mock_validate_model_id_and_get_type.side_effect = [False, False] model_id, _ = "js-trainable-model", "*" @@ -1647,7 +1651,7 @@ def test_model_id_not_found_refeshes_cache_training( ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -1666,16 +1670,16 @@ def test_model_id_not_found_refeshes_cache_training( ] ) - mock_is_valid_model_id.reset_mock() + mock_validate_model_id_and_get_type.reset_mock() mock_reset_cache.reset_mock() - mock_is_valid_model_id.side_effect = [False, True] + mock_validate_model_id_and_get_type.side_effect = [False, True] JumpStartEstimator( model_id=model_id, ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -1694,7 +1698,7 @@ def test_model_id_not_found_refeshes_cache_training( ] ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @@ -1704,10 +1708,10 @@ def test_model_artifact_variant_estimator( mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model-artifact-variant-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py index d22e910a00..073921d5ba 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py @@ -25,6 +25,7 @@ from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.estimator import JumpStartEstimator +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session from sagemaker.utils import resolve_value_from_config @@ -78,7 +79,7 @@ def config_value_impl(sagemaker_session: Session, config_path: str, sagemaker_co class IntelligentDefaultsEstimatorTest(unittest.TestCase): - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -98,12 +99,12 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_retrieve_model_init_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -135,7 +136,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( assert "enable_network_isolation" not in mock_estimator_deploy.call_args[1] - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -155,12 +156,12 @@ def test_without_arg_overwrites_with_kwarg_collisions_with_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -208,7 +209,7 @@ def test_without_arg_overwrites_with_kwarg_collisions_with_config( config_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -228,12 +229,12 @@ def test_with_arg_overwrites_with_kwarg_collisions_with_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -290,7 +291,7 @@ def test_with_arg_overwrites_with_kwarg_collisions_with_config( override_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -310,12 +311,12 @@ def test_with_arg_overwrites_without_kwarg_collisions_with_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -365,7 +366,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_with_config( override_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.session.Session.get_caller_identity_arn") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @@ -387,12 +388,12 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, mock_get_caller_identity_arn: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -426,7 +427,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( assert "enable_network_isolation" not in mock_estimator_deploy.call_args[1] - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.session.Session.get_caller_identity_arn") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @@ -448,12 +449,12 @@ def test_without_arg_overwrites_with_kwarg_collisions_without_config( mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, mock_get_caller_identity_arn: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_get_caller_identity_arn.return_value = execution_role model_id, _ = "js-trainable-model", "*" @@ -500,7 +501,7 @@ def test_without_arg_overwrites_with_kwarg_collisions_without_config( metadata_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @@ -520,11 +521,11 @@ def test_with_arg_overwrites_with_kwarg_collisions_without_config( mock_model_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -576,7 +577,7 @@ def test_with_arg_overwrites_with_kwarg_collisions_without_config( override_inference_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_init_kwargs") @@ -594,11 +595,11 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( mock_retrieve_kwargs: mock.Mock, mock_estimator_init: mock.Mock, mock_estimator_deploy: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_estimator_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index f45283935b..ba4ba0bb13 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -22,7 +22,7 @@ _retrieve_default_environment_variables, ) from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType from sagemaker.jumpstart.model import JumpStartModel from sagemaker.model import Model @@ -32,9 +32,11 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from tests.unit.sagemaker.jumpstart.utils import ( + get_spec_from_base_spec, get_special_model_spec, overwrite_dictionary, get_special_model_spec_for_inference_component_based_endpoint, + get_prototype_manifest, ) execution_role = "fake role! do not use!" @@ -53,7 +55,7 @@ class ModelTest(unittest.TestCase): @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -65,7 +67,7 @@ def test_non_prepacked( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_jumpstart_model_factory_logger: mock.Mock, ): @@ -73,7 +75,7 @@ def test_non_prepacked( mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -128,7 +130,7 @@ def test_non_prepacked( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -140,14 +142,14 @@ def test_non_prepacked_inference_component_based_endpoint( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): mock_model_deploy.return_value = default_predictor mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = ( @@ -208,7 +210,7 @@ def test_non_prepacked_inference_component_based_endpoint( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -220,14 +222,14 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): mock_model_deploy.return_value = default_predictor mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -282,7 +284,7 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -294,11 +296,11 @@ def test_prepacked( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -345,7 +347,7 @@ def test_prepacked( @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -353,7 +355,7 @@ def test_no_compiled_model_warning_log_js_models( self, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_create_model: mock.Mock, mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, @@ -362,7 +364,7 @@ def test_no_compiled_model_warning_log_js_models( mock_timestamp.return_value = "1234" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "gated_llama_neuron_model", "*" @@ -381,7 +383,7 @@ def test_no_compiled_model_warning_log_js_models( @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -389,15 +391,14 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( self, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_create_model: mock.Mock, mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, ): - mock_timestamp.return_value = "1234" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "gated_variant-model", "*" @@ -441,7 +442,64 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( ], ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.__init__") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_proprietary_model_endpoint( + self, + mock_model_deploy: mock.Mock, + mock_model_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + mock_sagemaker_timestamp.return_value = "7777" + + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.PROPRIETARY + model_id, _ = "ai21-summarization", "*" + + mock_get_model_specs.side_effect = get_spec_from_base_spec + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id, model_version="2.0.004") + + mock_model_init.assert_called_once_with( + predictor_cls=Predictor, + role=execution_role, + sagemaker_session=sagemaker_session, + enable_network_isolation=False, + ) + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p4de.24xlarge", + wait=True, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "ai21-summarization"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "2.0.004"}, + {"Key": JumpStartTag.MODEL_TYPE, "Value": "proprietary"}, + ], + endpoint_logging=False, + model_data_download_timeout=3600, + container_startup_health_check_timeout=600, + ) + + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -451,11 +509,11 @@ def test_deprecated( mock_model_deploy: mock.Mock, mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "deprecated_model", "*" @@ -468,7 +526,7 @@ def test_deprecated( JumpStartModel(model_id=model_id, tolerate_deprecated_model=True).deploy() - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -478,9 +536,9 @@ def test_vulnerable( mock_model_deploy: mock.Mock, mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_model_deploy.return_value = default_predictor @@ -543,7 +601,7 @@ def test_model_use_kwargs(self): ) @mock.patch("sagemaker.jumpstart.factory.model.environment_variables.retrieve_default") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -555,7 +613,7 @@ def evaluate_model_workflow_with_kwargs( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_retrieve_environment_variables: mock.Mock, init_kwargs: Optional[dict] = None, deploy_kwargs: Optional[dict] = None, @@ -565,7 +623,7 @@ def evaluate_model_workflow_with_kwargs( mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_session.return_value = sagemaker_session @@ -661,22 +719,22 @@ def test_jumpstart_model_kwargs_match_parent_class(self): @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") - def test_is_valid_model_id( + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") + def test_validate_model_id_and_get_type( self, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS JumpStartModel(model_id="valid_model_id") - mock_is_valid_model_id.return_value = False + mock_validate_model_id_and_get_type.return_value = False with pytest.raises(ValueError): JumpStartModel(model_id="invalid_model_id") @mock.patch("sagemaker.jumpstart.model.get_default_predictor") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -688,14 +746,14 @@ def test_no_predictor_returns_default_predictor( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -717,12 +775,13 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -734,14 +793,14 @@ def test_no_predictor_yes_async_inference_config( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -758,7 +817,7 @@ def test_no_predictor_yes_async_inference_config( mock_get_default_predictor.assert_not_called() @mock.patch("sagemaker.jumpstart.model.get_default_predictor") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -770,14 +829,14 @@ def test_yes_predictor_returns_default_predictor( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-class-model-prepacked", "*" @@ -793,24 +852,24 @@ def test_yes_predictor_returns_default_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.JumpStartModelsAccessor.reset_cache") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) - def test_model_id_not_found_refeshes_cach_inference( + def test_model_id_not_found_refeshes_cache_inference( self, mock_reset_cache: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.side_effect = [False, False] + mock_validate_model_id_and_get_type.side_effect = [False, False] model_id, _ = "js-trainable-model", "*" @@ -826,7 +885,7 @@ def test_model_id_not_found_refeshes_cach_inference( ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -845,16 +904,19 @@ def test_model_id_not_found_refeshes_cach_inference( ] ) - mock_is_valid_model_id.reset_mock() + mock_validate_model_id_and_get_type.reset_mock() mock_reset_cache.reset_mock() - mock_is_valid_model_id.side_effect = [False, True] + mock_validate_model_id_and_get_type.side_effect = [ + False, + JumpStartModelType.OPEN_WEIGHTS, + ] JumpStartModel( model_id=model_id, ) mock_reset_cache.assert_called_once_with() - mock_is_valid_model_id.assert_has_calls( + mock_validate_model_id_and_get_type.assert_has_calls( calls=[ mock.call( model_id="js-trainable-model", @@ -873,16 +935,16 @@ def test_model_id_not_found_refeshes_cach_inference( ] ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_tags( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "env-var-variant-model", "*" @@ -909,16 +971,16 @@ def test_jumpstart_model_tags( [{"Key": "blah", "Value": "blahagain"}] + js_tags, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_tags_disabled( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "env-var-variant-model", "*" @@ -941,16 +1003,16 @@ def test_jumpstart_model_tags_disabled( [{"Key": "blah", "Value": "blahagain"}], ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_package_arn( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-package-arn", "*" @@ -975,16 +1037,16 @@ def test_jumpstart_model_package_arn( self.assertIn(tag, mock_session.create_model.call_args[1]["tags"]) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_package_arn_override( self, mock_get_model_specs: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS # arbitrary model without model packarn arn model_id, _ = "js-trainable-model", "*" @@ -1017,7 +1079,7 @@ def test_jumpstart_model_package_arn_override( }, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1025,10 +1087,10 @@ def test_jumpstart_model_package_arn_unsupported_region( self, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-model-package-arn", "*" @@ -1044,7 +1106,7 @@ def test_jumpstart_model_package_arn_unsupported_region( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -1058,14 +1120,14 @@ def test_model_data_s3_prefix_override( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): mock_model_deploy.return_value = default_predictor mock_sagemaker_timestamp.return_value = "7777" - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1109,7 +1171,7 @@ def test_model_data_s3_prefix_override( '"S3DataType": "S3Prefix", "CompressionType": "None"}}', ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -1123,11 +1185,11 @@ def test_model_data_s3_prefix_model( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model_data_s3_prefix_model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1153,7 +1215,7 @@ def test_model_data_s3_prefix_model( mock_js_info_logger.assert_not_called() - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -1167,11 +1229,11 @@ def test_model_artifact_variant_model( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model-artifact-variant-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -1218,7 +1280,7 @@ def test_model_artifact_variant_model( enable_network_isolation=True, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -1230,11 +1292,11 @@ def test_model_registry_accept_and_response_types( mock_model_deploy: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): mock_model_deploy.return_value = default_predictor - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "model_data_s3_prefix_model", "*" mock_get_model_specs.side_effect = get_special_model_spec diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py index 727f3060b3..70409704e6 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py @@ -21,6 +21,7 @@ from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.model import JumpStartModel +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session from sagemaker.utils import resolve_value_from_config @@ -59,7 +60,7 @@ class IntelligentDefaultsModelTest(unittest.TestCase): region = "us-west-2" sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -75,10 +76,10 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" mock_retrieve_kwargs.return_value = {} @@ -100,7 +101,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( assert "enable_network_isolation" not in mock_model_init.call_args[1] - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -116,10 +117,10 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -146,7 +147,7 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config( override_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -162,10 +163,10 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -192,7 +193,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config( config_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -208,10 +209,10 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -240,7 +241,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config( override_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -256,10 +257,10 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -286,7 +287,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config( metadata_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -302,9 +303,9 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -333,7 +334,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config( override_enable_network_isolation, ) - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -349,10 +350,10 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" @@ -374,7 +375,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( self.assertEquals(mock_model_init.call_args[1].get("role"), execution_role) assert "enable_network_isolation" not in mock_model_init.call_args[1] - @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.utils.get_sagemaker_config_value") @@ -390,10 +391,10 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( mock_get_sagemaker_config_value: mock.Mock, mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, - mock_is_valid_model_id: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, ): - mock_is_valid_model_id.return_value = True + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, _ = "js-trainable-model", "*" diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 97427be1ae..c57d2a958b 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -18,6 +18,7 @@ import pytest from sagemaker.jumpstart import accessors +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.constants import BASE_MANIFEST from tests.unit.sagemaker.jumpstart.utils import ( get_header_from_base_header, @@ -63,8 +64,51 @@ def test_jumpstart_models_cache_get_fxs(mock_cache): reload(accessors) +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") +def test_jumpstart_proprietary_models_cache_get(mock_cache): + + mock_cache.get_manifest = Mock(return_value=BASE_MANIFEST) + mock_cache.get_header = Mock(side_effect=get_header_from_base_header) + mock_cache.get_specs = Mock(side_effect=get_spec_from_base_spec) + + assert get_header_from_base_header( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) == accessors.JumpStartModelsAccessor.get_model_header( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert get_spec_from_base_spec( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) == accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + assert ( + len( + accessors.JumpStartModelsAccessor._get_manifest( + model_type=JumpStartModelType.PROPRIETARY + ) + ) + > 0 + ) + + # necessary because accessors is a static module + reload(accessors) + + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache") -def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): +def test_jumpstart_models_cache_set_reset(mock_model_cache: Mock): # test change of region resets cache accessors.JumpStartModelsAccessor.get_model_header( @@ -138,6 +182,50 @@ def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): reload(accessors) +@patch("sagemaker.jumpstart.cache.JumpStartModelsCache") +def test_jumpstart_proprietary_models_cache_set_reset(mock_model_cache: Mock): + + # test change of region resets cache + accessors.JumpStartModelsAccessor.get_model_header( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + accessors.JumpStartModelsAccessor.get_model_header( + region="us-east-2", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-1", + model_id="ai21-summarization", + version="*", + model_type=JumpStartModelType.PROPRIETARY, + ) + mock_model_cache.assert_called_once() + mock_model_cache.reset_mock() + + # necessary because accessors is a static module + reload(accessors) + + class TestS3Accessor(TestCase): bucket = "bucket" diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index 1a770f785f..21112926a5 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -32,7 +32,7 @@ from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn from sagemaker.jumpstart.artifacts.model_uris import _retrieve_model_uri -from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec from tests.unit.sagemaker.workflow.conftest import mock_client @@ -331,9 +331,13 @@ class RetrieveModelPackageArnTest(unittest.TestCase): mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - def test_retrieve_model_package_arn(self, patched_get_model_specs): + def test_retrieve_model_package_arn( + self, patched_get_model_specs: Mock, patched_validate_model_id_and_get_type: Mock + ): patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "variant-model" region = "us-west-2" @@ -437,9 +441,13 @@ class PrivateJumpStartBucketTest(unittest.TestCase): mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - def test_retrieve_uri_from_gated_bucket(self, patched_get_model_specs): + def test_retrieve_uri_from_gated_bucket( + self, patched_get_model_specs, patched_validate_model_id_and_get_type + ): patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id = "private-model" region = "us-west-2" diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 6633ecdc23..50fe6da0a6 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -23,7 +23,11 @@ import pytest from mock import patch -from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache +from sagemaker.jumpstart.cache import ( + JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, + JumpStartModelsCache, +) from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, @@ -32,7 +36,9 @@ JumpStartModelHeader, JumpStartModelSpecs, JumpStartVersionedModelId, + JumpStartS3FileType, ) +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import ( get_spec_from_base_spec, patched_retrieval_function, @@ -41,6 +47,8 @@ from tests.unit.sagemaker.jumpstart.constants import ( BASE_MANIFEST, BASE_SPEC, + BASE_PROPRIETARY_SPEC, + BASE_PROPRIETARY_MANIFEST, ) from sagemaker.jumpstart.utils import get_jumpstart_content_bucket @@ -152,6 +160,33 @@ def test_jumpstart_cache_get_header(): semantic_version_str="1.0.*", ) + assert JumpStartModelHeader( + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + } + ) == cache.get_header( + model_id="ai21-summarization", + semantic_version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError) as e: + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="3.*", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert ( + "Proprietary model 'ai21-summarization' does not support wildcard version identifier '3.*'. " + "You can pin to version '1.1.003'. " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " + "for a list of valid model IDs. " in str(e.value) + ) + with pytest.raises(KeyError) as e: cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", @@ -194,6 +229,19 @@ def test_jumpstart_cache_get_header(): "v3-classification-4'?" ) in str(e.value) + with pytest.raises(KeyError) as e: + cache.get_header( + model_id="ai21-summarize", + semantic_version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert ( + "Unable to find model manifest for 'ai21-summarize' with version '1.1.003'. " + "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " + "for updated list of models. " + "Did you mean to use model ID 'ai21-summarization'?" + ) in str(e.value) + with pytest.raises(KeyError): cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", @@ -224,6 +272,27 @@ def test_jumpstart_cache_get_header(): semantic_version_str="*", ) + with pytest.raises(KeyError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="1.1.004", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError): + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="2.*", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError): + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="v*", + model_type=JumpStartModelType.PROPRIETARY, + ) + @patch("boto3.client") def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client): @@ -276,6 +345,12 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client): cache.set_manifest_file_s3_key("some_key1") cache.clear.assert_called_once() + cache.clear.reset_mock() + cache.set_manifest_file_s3_key("some_key1", file_type=JumpStartS3FileType.OPEN_WEIGHT_MANIFEST) + cache.clear.assert_called_once() + with pytest.raises(ValueError): + cache.set_manifest_file_s3_key("some_key1", file_type="unknown_type") + def test_jumpstart_cache_handles_boto3_client_errors(): # Testing get_object @@ -423,15 +498,80 @@ def test_jumpstart_cache_accepts_input_parameters(): assert cache._s3_cache._max_cache_items == max_s3_cache_items assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon assert ( - cache._model_id_semantic_version_manifest_key_cache._max_cache_items + cache._open_weight_model_id_manifest_key_cache._max_cache_items == max_semantic_version_cache_items ) assert ( - cache._model_id_semantic_version_manifest_key_cache._expiration_horizon + cache._open_weight_model_id_manifest_key_cache._expiration_horizon == semantic_version_cache_expiration_horizon ) +def test_jumpstart_proprietary_cache_accepts_input_parameters(): + + region = "us-east-1" + max_s3_cache_items = 1 + s3_cache_expiration_horizon = datetime.timedelta(weeks=2) + max_semantic_version_cache_items = 3 + semantic_version_cache_expiration_horizon = datetime.timedelta(microseconds=4) + bucket = "my-amazing-bucket" + manifest_file_key = "some_s3_key" + proprietary_manifest_file_key = "some_proprietary_s3_key" + + cache = JumpStartModelsCache( + region=region, + max_s3_cache_items=max_s3_cache_items, + s3_cache_expiration_horizon=s3_cache_expiration_horizon, + max_semantic_version_cache_items=max_semantic_version_cache_items, + semantic_version_cache_expiration_horizon=semantic_version_cache_expiration_horizon, + s3_bucket_name=bucket, + manifest_file_s3_key=manifest_file_key, + proprietary_manifest_s3_key=proprietary_manifest_file_key, + ) + + assert ( + cache.get_manifest_file_s3_key(file_type=JumpStartS3FileType.PROPRIETARY_MANIFEST) + == proprietary_manifest_file_key + ) + assert cache.get_region() == region + assert cache.get_bucket() == bucket + assert cache._s3_cache._max_cache_items == max_s3_cache_items + assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert ( + cache._proprietary_model_id_manifest_key_cache._max_cache_items + == max_semantic_version_cache_items + ) + assert ( + cache._proprietary_model_id_manifest_key_cache._expiration_horizon + == semantic_version_cache_expiration_horizon + ) + + +def test_jumpstart_cache_raise_unknown_file_type_exception(): + + region = "us-east-1" + max_s3_cache_items = 1 + s3_cache_expiration_horizon = datetime.timedelta(weeks=2) + max_semantic_version_cache_items = 3 + semantic_version_cache_expiration_horizon = datetime.timedelta(microseconds=4) + bucket = "my-amazing-bucket" + manifest_file_key = "some_s3_key" + proprietary_manifest_file_key = "some_proprietary_s3_key" + + cache = JumpStartModelsCache( + region=region, + max_s3_cache_items=max_s3_cache_items, + s3_cache_expiration_horizon=s3_cache_expiration_horizon, + max_semantic_version_cache_items=max_semantic_version_cache_items, + semantic_version_cache_expiration_horizon=semantic_version_cache_expiration_horizon, + s3_bucket_name=bucket, + manifest_file_s3_key=manifest_file_key, + proprietary_manifest_s3_key=proprietary_manifest_file_key, + ) + with pytest.raises(ValueError): + cache.get_manifest_file_s3_key(file_type="unknown_type") + + @patch("boto3.client") def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): @@ -583,7 +723,7 @@ def test_jumpstart_cache_makes_correct_s3_calls( with patch("logging.Logger.warning") as mocked_warning_log: cache.get_specs( - model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + model_id="pytorch-ic-imagenet-inception-v3-classification-4", version_str="*" ) mocked_warning_log.assert_called_once_with( "Using model 'pytorch-ic-imagenet-inception-v3-classification-4' with wildcard " @@ -593,7 +733,7 @@ def test_jumpstart_cache_makes_correct_s3_calls( ) mocked_warning_log.reset_mock() cache.get_specs( - model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + model_id="pytorch-ic-imagenet-inception-v3-classification-4", version_str="*" ) mocked_warning_log.assert_not_called() @@ -605,13 +745,85 @@ def test_jumpstart_cache_makes_correct_s3_calls( mock_boto3_client.return_value.head_object.assert_not_called() +@patch("sagemaker.jumpstart.cache.utils.emit_logs_based_on_model_specs") +@patch("boto3.client") +def test_jumpstart_cache_proprietary_manifest_makes_correct_s3_calls( + mock_boto3_client, mock_emit_logs_based_on_model_specs +): + + # test get_header + mock_manifest_json = json.dumps( + [ + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + }, + ] + ) + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_manifest_json, "utf-8")), content_length=len(mock_manifest_json) + ), + "ETag": "etag", + } + + mock_boto3_client.return_value.head_object.return_value = {"ETag": "some-hash"} + + bucket_name = get_jumpstart_content_bucket("us-west-2") + client_config = botocore.config.Config(signature_version="my_signature_version") + cache = JumpStartModelsCache( + s3_bucket_name=bucket_name, s3_client_config=client_config, region="us-west-2" + ) + cache.get_header( + model_id="ai21-summarization", + semantic_version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY + ) + mock_boto3_client.return_value.head_object.assert_not_called() + + mock_boto3_client.assert_called_with("s3", region_name="us-west-2", config=client_config) + + # test get_specs. manifest already in cache, so only s3 call will be to get specs. + mock_json = json.dumps(BASE_PROPRIETARY_SPEC) + + mock_boto3_client.return_value.reset_mock() + + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) + ), + "ETag": "etag", + } + + with patch("logging.Logger.warning") as mocked_warning_log: + cache.get_specs( + model_id="ai21-summarization", + version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + mocked_warning_log.assert_not_called() + + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, + Key="proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + ) + mock_boto3_client.return_value.head_object.assert_not_called() + + @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") cache.clear = MagicMock() - cache._model_id_semantic_version_manifest_key_cache = MagicMock() - cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [ + cache._open_weight_model_id_manifest_key_cache = MagicMock() + cache._open_weight_model_id_manifest_key_cache.get.side_effect = [ ( JumpStartVersionedModelId( "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" @@ -640,7 +852,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache.clear.assert_called_once() cache.clear.reset_mock() - cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [ + cache._open_weight_model_id_manifest_key_cache.get.side_effect = [ ( JumpStartVersionedModelId( "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" @@ -668,7 +880,18 @@ def test_jumpstart_get_full_manifest(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") raw_manifest = [header.to_json() for header in cache.get_manifest()] - raw_manifest == BASE_MANIFEST + assert raw_manifest == BASE_MANIFEST + + +@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) +@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") +def test_jumpstart_get_full_proprietary_manifest(): + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + raw_manifest = [ + header.to_json() for header in cache.get_manifest(model_type=JumpStartModelType.PROPRIETARY) + ] + + assert raw_manifest == BASE_PROPRIETARY_MANIFEST @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @@ -678,54 +901,92 @@ def test_jumpstart_cache_get_specs(): model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version ) model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" assert get_spec_from_base_spec(model_id=model_id, version="2.0.0") == cache.get_specs( - model_id=model_id, semantic_version_str="2.0.*" + model_id=model_id, version_str="2.0.*" ) model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0" assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version ) model_id = "pytorch-ic-imagenet-inception-v3-classification-4" assert get_spec_from_base_spec(model_id=model_id, version="1.0.0") == cache.get_specs( - model_id=model_id, semantic_version_str="1.*" + model_id=model_id, version_str="1.*" ) model_id = "pytorch-ic-imagenet-inception-v3-classification-4" assert get_spec_from_base_spec(model_id=model_id, version="1.0.0") == cache.get_specs( - model_id=model_id, semantic_version_str="1.0.*" + model_id=model_id, version_str="1.0.*" + ) + + assert get_spec_from_base_spec( + model_id="ai21-summarization", + version="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) == cache.get_specs( + model_id="ai21-summarization", + version_str="1.1.003", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError) as e: + cache.get_specs( + model_id="ai21-summarization", + version_str="3.*", + model_type=JumpStartModelType.PROPRIETARY, + ) + assert ( + "Proprietary model 'ai21-summarization' does not support wildcard version identifier '3.*'. " + "You can pin to version '1.1.003'. " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " + "for a list of valid model IDs. " in str(e.value) ) with pytest.raises(KeyError): - cache.get_specs(model_id=model_id + "bak", semantic_version_str="*") + cache.get_specs(model_id=model_id + "bak", version_str="*") with pytest.raises(KeyError): - cache.get_specs(model_id=model_id, semantic_version_str="9.*") + cache.get_specs(model_id=model_id, version_str="9.*") with pytest.raises(KeyError): - cache.get_specs(model_id=model_id, semantic_version_str="BAD") + cache.get_specs(model_id=model_id, version_str="BAD") with pytest.raises(KeyError): cache.get_specs( model_id=model_id, - semantic_version_str="2.1.*", + version_str="2.1.*", ) with pytest.raises(KeyError): cache.get_specs( model_id=model_id, - semantic_version_str="3.9.*", + version_str="3.9.*", ) with pytest.raises(KeyError): cache.get_specs( model_id=model_id, - semantic_version_str="5.*", + version_str="5.*", + ) + + model_id, version = "ai21-summarization", "2.0.0" + with pytest.raises(KeyError): + cache.get_specs( + model_id=model_id, + version_str="BAD", + model_type=JumpStartModelType.PROPRIETARY, + ) + + with pytest.raises(KeyError): + cache.get_specs( + model_id=model_id, + version_str="9.*", + model_type=JumpStartModelType.PROPRIETARY, ) @@ -794,9 +1055,7 @@ def test_jumpstart_local_metadata_override_specs( cache = JumpStartModelsCache(s3_bucket_name="some_bucket") model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" - assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs( - model_id=model_id, semantic_version_str=version - ) + assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs(model_id=model_id, version_str=version) mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root") mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") @@ -840,7 +1099,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( cache = JumpStartModelsCache(s3_bucket_name="some_bucket") assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( - model_id=model_id, semantic_version_str=version + model_id=model_id, version_str=version ) mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") diff --git a/tests/unit/sagemaker/jumpstart/test_exceptions.py b/tests/unit/sagemaker/jumpstart/test_exceptions.py index 555099a753..2307d22474 100644 --- a/tests/unit/sagemaker/jumpstart/test_exceptions.py +++ b/tests/unit/sagemaker/jumpstart/test_exceptions.py @@ -11,10 +11,15 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +import pytest + +from botocore.exceptions import ClientError from sagemaker.jumpstart.exceptions import ( get_wildcard_model_version_msg, get_old_model_version_msg, + get_proprietary_model_subscription_error, + MarketplaceModelSubscriptionError, ) @@ -35,3 +40,32 @@ def test_get_old_model_version_msg(): "Note that models may have different input/output signatures after a major " "version upgrade." == get_old_model_version_msg("mother_of_all_models", "1.0.0", "1.2.3") ) + + +def test_get_marketplace_subscription_error(): + error = ClientError( + error_response={ + "Error": { + "Code": "ValidationException", + "Message": "Caller is not subscribed to the Marketplace listing.", + }, + }, + operation_name="mock-operation", + ) + with pytest.raises(MarketplaceModelSubscriptionError): + get_proprietary_model_subscription_error(error, subscription_link="mock-link") + + error = ClientError( + error_response={ + "Error": { + "Code": "UnknownException", + "Message": "Unknown error raised.", + }, + }, + operation_name="mock-operation", + ) + + try: + get_proprietary_model_subscription_error(error, subscription_link="mock-link") + except MarketplaceModelSubscriptionError: + pytest.fail("MarketplaceModelSubscriptionError should not be raised for unknown error.") diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 1a7108579c..059cd7ccad 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -17,6 +17,7 @@ get_prototype_manifest, get_prototype_model_spec, ) +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.notebook_utils import ( _generate_jumpstart_model_versions, get_model_url, @@ -63,7 +64,7 @@ def test_list_jumpstart_scripts( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 assert patched_get_model_specs.call_count == 1 patched_get_model_specs.reset_mock() @@ -78,8 +79,8 @@ def test_list_jumpstart_scripts( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() - assert patched_read_s3_file.call_count == len(PROTOTYPICAL_MODEL_SPECS_DICT) + assert patched_get_manifest.call_count == 2 + assert patched_read_s3_file.call_count == 2 * len(PROTOTYPICAL_MODEL_SPECS_DICT) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -107,7 +108,7 @@ def test_list_jumpstart_tasks( ) # incomplete list, based on mocked metadata patched_generate_jumpstart_models.assert_called_once() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_model_specs.assert_not_called() patched_get_model_specs.reset_mock() @@ -122,7 +123,7 @@ def test_list_jumpstart_tasks( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_model_specs.assert_not_called() @@ -154,11 +155,11 @@ def test_list_jumpstart_frameworks( ) patched_generate_jumpstart_models.assert_called_once() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_model_specs.assert_not_called() patched_get_model_specs.reset_mock() - patched_get_manifest.reset_mock() + assert patched_get_manifest.call_count == 2 patched_generate_jumpstart_models.reset_mock() kwargs = { @@ -180,7 +181,7 @@ def test_list_jumpstart_frameworks( patched_generate_jumpstart_models.assert_called_once_with( **kwargs, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION ) - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 4 patched_get_model_specs.assert_not_called() @@ -238,7 +239,7 @@ def test_list_jumpstart_models_script_filter( kwargs = {"filter": f"training_supported != {val}"} list_jumpstart_models(**kwargs) assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -255,7 +256,7 @@ def test_list_jumpstart_models_script_filter( ("xgboost-classification-model", "1.0.0"), ] assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -264,7 +265,7 @@ def test_list_jumpstart_models_script_filter( models = list_jumpstart_models(**kwargs) assert [] == models assert patched_read_s3_file.call_count == manifest_length - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -287,7 +288,7 @@ def test_list_jumpstart_models_task_filter( kwargs = {"filter": f"task == {val}"} list_jumpstart_models(**kwargs) patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -295,7 +296,7 @@ def test_list_jumpstart_models_task_filter( kwargs = {"filter": f"task != {val}"} list_jumpstart_models(**kwargs) patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -312,7 +313,7 @@ def test_list_jumpstart_models_task_filter( ("xgboost-classification-model", "1.0.0"), ] patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -321,7 +322,7 @@ def test_list_jumpstart_models_task_filter( models = list_jumpstart_models(**kwargs) assert [] == models patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") @@ -348,7 +349,7 @@ def test_list_jumpstart_models_framework_filter( kwargs = {"filter": f"framework == {val}"} list_jumpstart_models(**kwargs) patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -356,7 +357,7 @@ def test_list_jumpstart_models_framework_filter( kwargs = {"filter": f"framework != {val}"} list_jumpstart_models(**kwargs) patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -372,7 +373,7 @@ def test_list_jumpstart_models_framework_filter( ("xgboost-classification-model", "1.0.0"), ] patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -390,8 +391,8 @@ def test_list_jumpstart_models_framework_filter( } models = list_jumpstart_models(**kwargs) assert [("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0")] == models - patched_read_s3_file.assert_called_once() - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -405,7 +406,7 @@ def test_list_jumpstart_models_framework_filter( models = list_jumpstart_models(**kwargs) assert [] == models patched_read_s3_file.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -420,8 +421,10 @@ def test_list_jumpstart_models_region( list_jumpstart_models(region="some-region") - patched_get_manifest.assert_called_once_with( - region="some-region", s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client + patched_get_manifest.assert_called_with( + region="some-region", + s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @@ -478,7 +481,7 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME): ] == list_jumpstart_models(list_old_models=True, list_versions=True) patched_get_model_specs.assert_not_called() - patched_get_manifest.assert_called_once() + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_get_model_specs.reset_mock() @@ -526,8 +529,8 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): And("inference_vulnerable is false", "training_vulnerable is false") ) - assert patched_read_s3_file.call_count == num_specs - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 * num_specs + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -538,8 +541,8 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): And("inference_vulnerable is false", "training_vulnerable is false") ) - assert patched_read_s3_file.call_count == num_specs - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 * num_specs + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -570,8 +573,8 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str: num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) assert [] == list_jumpstart_models("deprecated equals false") - assert patched_read_s3_file.call_count == num_specs - patched_get_manifest.assert_called_once() + assert patched_read_s3_file.call_count == 2 * num_specs + assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() @@ -607,6 +610,43 @@ def test_list_jumpstart_models_no_versions( assert list_jumpstart_models(list_versions=False) == all_model_ids + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + def test_list_jumpstart_proprietary_models( + self, + patched_get_model_specs: Mock, + patched_get_manifest: Mock, + ): + patched_get_model_specs.side_effect = get_prototype_model_spec + patched_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + + all_prop_model_ids = [ + "ai21-paraphrase", + "ai21-summarization", + "lighton-mini-instruct40b", + ] + + all_open_weight_model_ids = [ + "catboost-classification-model", + "huggingface-spc-bert-base-cased", + "lightgbm-classification-model", + "mxnet-semseg-fcn-resnet50-ade", + "pytorch-eqa-bert-base-cased", + "sklearn-classification-linear", + "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", + "xgboost-classification-model", + ] + + assert list_jumpstart_models("model_type == proprietary") == all_prop_model_ids + assert list_jumpstart_models("model_type == marketplace") == all_prop_model_ids + assert list_jumpstart_models("model_type == open_weights") == all_open_weight_model_ids + + assert list_jumpstart_models(list_versions=False) == sorted( + all_prop_model_ids + all_open_weight_model_ids + ) + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_complex_queries( @@ -670,12 +710,20 @@ def test_list_jumpstart_models_multiple_level_index( list_jumpstart_models("hosting_ecr_specs.py_version == py3") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_get_model_url( patched_get_model_specs: Mock, + patched_validate_model_id_and_get_type: Mock, + patched_get_manifest: Mock, ): patched_get_model_specs.side_effect = get_prototype_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest( + region + ) model_id, version = "xgboost-classification-model", "1.0.0" assert "https://xgboost.readthedocs.io/en/latest/" == get_model_url(model_id, version) @@ -686,7 +734,6 @@ def test_get_model_url( ) model_id, version = "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0" - region = "fake-region" patched_get_model_specs.reset_mock() patched_get_model_specs.side_effect = lambda *largs, **kwargs: get_prototype_model_spec( @@ -695,11 +742,12 @@ def test_get_model_url( **{key: value for key, value in kwargs.items() if key != "region"}, ) - get_model_url(model_id, version, region=region) + get_model_url(model_id, version, region="us-west-2") patched_get_model_specs.assert_called_once_with( model_id=model_id, version=version, - region=region, + region="us-west-2", s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 7ab9cdd1cc..52f28f2da1 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -7,17 +7,15 @@ import pytest from sagemaker.deserializers import JSONDeserializer from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import MIMEType +from sagemaker.jumpstart.enums import MIMEType, JumpStartModelType from sagemaker import predictor from sagemaker.jumpstart.model import JumpStartModel from sagemaker.jumpstart.utils import verify_model_region_and_return_specs -from sagemaker.serializers import IdentitySerializer -from tests.unit.sagemaker.jumpstart.utils import ( - get_special_model_spec, -) +from sagemaker.serializers import IdentitySerializer, JSONSerializer +from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec, get_spec_from_base_spec @patch("sagemaker.predictor.get_model_id_version_from_endpoint") @@ -54,6 +52,43 @@ def test_jumpstart_predictor_support( assert js_predictor.accept == MIMEType.JSON +@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_proprietary_predictor_support( + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_get_jumpstart_model_id_version_from_endpoint, +): + + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs + patched_get_model_specs.side_effect = get_spec_from_base_spec + + # version not needed for JumpStart predictor + model_id, model_version = "ai21-summarization", "*" + + patched_get_jumpstart_model_id_version_from_endpoint.return_value = ( + model_id, + model_version, + None, + ) + + js_predictor = predictor.retrieve_default( + endpoint_name="blah", + model_id=model_id, + model_version=model_version, + model_type=JumpStartModelType.PROPRIETARY, + ) + + patched_get_jumpstart_model_id_version_from_endpoint.assert_not_called() + + assert js_predictor.content_type == MIMEType.JSON + assert isinstance(js_predictor.serializer, JSONSerializer) + + assert isinstance(js_predictor.deserializer, JSONDeserializer) + assert js_predictor.accept == MIMEType.JSON + + @patch("sagemaker.predictor.Predictor") @patch("sagemaker.predictor.get_default_predictor") @patch("sagemaker.predictor.get_model_id_version_from_endpoint") @@ -92,6 +127,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=mock_session, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @@ -125,19 +161,19 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( @patch("sagemaker.predictor.get_model_id_version_from_endpoint") @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") -@patch("sagemaker.jumpstart.model.is_valid_model_id") +@patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_serializable_payload_with_predictor( patched_get_model_specs, patched_verify_model_region_and_return_specs, - patched_is_valid_model_id, + patched_validate_model_id_and_get_type, patched_get_object_cached, patched_get_model_id_version_from_endpoint, ): patched_get_object_cached.return_value = base64.b64decode("encodedimage") - patched_is_valid_model_id.return_value = True + patched_validate_model_id_and_get_type.return_value = True patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 556b99bc9c..3ec6f8aec3 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -32,7 +32,7 @@ JumpStartScriptScope, ) from functools import partial -from sagemaker.jumpstart.enums import JumpStartTag, MIMEType +from sagemaker.jumpstart.enums import JumpStartTag, MIMEType, JumpStartModelType from sagemaker.jumpstart.exceptions import ( DeprecatedJumpStartModelError, VulnerableJumpStartModelError, @@ -68,7 +68,7 @@ def test_get_jumpstart_content_bucket_override(): with patch("logging.Logger.info") as mocked_info_log: random_region = "random_region" assert "some-val" == utils.get_jumpstart_content_bucket(random_region) - mocked_info_log.assert_called_once_with("Using JumpStart bucket override: 'some-val'") + mocked_info_log.assert_called_with("Using JumpStart bucket override: 'some-val'") def test_get_jumpstart_gated_content_bucket(): @@ -1180,7 +1180,7 @@ def test_mime_type_enum_from_str(): class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") - def test_is_valid_model_id_true( + def test_validate_model_id_and_get_type_true( self, mock_get_model_specs: Mock, mock_get_manifest: Mock, @@ -1194,12 +1194,16 @@ def test_is_valid_model_id_true( mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION mock_s3_client_value = mock_session_value.s3_client - patched = partial(utils.is_valid_model_id, sagemaker_session=mock_session_value) + patched = partial( + utils.validate_model_id_and_get_type, sagemaker_session=mock_session_value + ) - with patch("sagemaker.jumpstart.utils.is_valid_model_id", patched): - self.assertTrue(utils.is_valid_model_id("bee")) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched): + self.assertTrue(utils.validate_model_id_and_get_type("bee")) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) mock_get_model_specs.assert_not_called() @@ -1213,14 +1217,20 @@ def test_is_valid_model_id_true( ] mock_get_model_specs.return_value = Mock(training_supported=True) - self.assertTrue(utils.is_valid_model_id("bee", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + self.assertTrue( + utils.validate_model_id_and_get_type("bee", script=JumpStartScriptScope.TRAINING) + ) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") - def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_manifest: Mock): + def test_validate_model_id_and_get_type_false( + self, mock_get_model_specs: Mock, mock_get_manifest: Mock + ): mock_get_manifest.return_value = [ Mock(model_id="ay"), Mock(model_id="bee"), @@ -1230,18 +1240,18 @@ def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_mani mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION mock_s3_client_value = mock_session_value.s3_client - patched = partial(utils.is_valid_model_id, sagemaker_session=mock_session_value) + patched = partial( + utils.validate_model_id_and_get_type, sagemaker_session=mock_session_value + ) - with patch("sagemaker.jumpstart.utils.is_valid_model_id", patched): + with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched): - self.assertFalse(utils.is_valid_model_id("dee")) - self.assertFalse(utils.is_valid_model_id("")) - self.assertFalse(utils.is_valid_model_id(None)) - self.assertFalse(utils.is_valid_model_id(set())) + self.assertFalse(utils.validate_model_id_and_get_type("dee")) + self.assertFalse(utils.validate_model_id_and_get_type("")) + self.assertFalse(utils.validate_model_id_and_get_type(None)) + self.assertFalse(utils.validate_model_id_and_get_type(set())) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value - ) + mock_get_manifest.assert_called() mock_get_model_specs.assert_not_called() @@ -1253,30 +1263,48 @@ def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_mani Mock(model_id="bee"), Mock(model_id="see"), ] - self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + self.assertFalse( + utils.validate_model_id_and_get_type("dee", script=JumpStartScriptScope.TRAINING) + ) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) mock_get_manifest.reset_mock() - self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id("", script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id(None, script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id(set(), script=JumpStartScriptScope.TRAINING)) + self.assertFalse( + utils.validate_model_id_and_get_type("dee", script=JumpStartScriptScope.TRAINING) + ) + self.assertFalse( + utils.validate_model_id_and_get_type("", script=JumpStartScriptScope.TRAINING) + ) + self.assertFalse( + utils.validate_model_id_and_get_type(None, script=JumpStartScriptScope.TRAINING) + ) + self.assertFalse( + utils.validate_model_id_and_get_type(set(), script=JumpStartScriptScope.TRAINING) + ) mock_get_model_specs.assert_not_called() - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) mock_get_manifest.reset_mock() mock_get_model_specs.reset_mock() mock_get_model_specs.return_value = Mock(training_supported=False) - self.assertTrue(utils.is_valid_model_id("ay", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + self.assertTrue( + utils.validate_model_id_and_get_type("ay", script=JumpStartScriptScope.TRAINING) + ) + mock_get_manifest.assert_called_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + s3_client=mock_s3_client_value, + model_type=JumpStartModelType.PROPRIETARY, ) diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 146c6fd1f7..65fe10f7a7 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -28,12 +28,16 @@ JumpStartS3FileType, JumpStartModelHeader, ) +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import get_formatted_manifest from tests.unit.sagemaker.jumpstart.constants import ( PROTOTYPICAL_MODEL_SPECS_DICT, BASE_MANIFEST, BASE_SPEC, + BASE_PROPRIETARY_MANIFEST, + BASE_PROPRIETARY_SPEC, BASE_HEADER, + BASE_PROPRIETARY_HEADER, SPECIAL_MODEL_SPECS_DICT, ) @@ -44,11 +48,16 @@ def get_header_from_base_header( model_id: str = None, semantic_version_str: str = None, version: str = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelHeader: if version and semantic_version_str: raise ValueError("Cannot specify both `version` and `semantic_version_str` fields.") + if model_type == JumpStartModelType.PROPRIETARY: + spec = copy.deepcopy(BASE_PROPRIETARY_HEADER) + return JumpStartModelHeader(spec) + if all( [ "pytorch" not in model_id, @@ -79,7 +88,10 @@ def get_header_from_base_header( def get_prototype_manifest( region: str = JUMPSTART_DEFAULT_REGION_NAME, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[JumpStartModelHeader]: + if model_type == JumpStartModelType.PROPRIETARY: + return [JumpStartModelHeader(spec) for spec in BASE_PROPRIETARY_MANIFEST] return [ get_header_from_base_header(region=region, model_id=model_id, version=version) for model_id in PROTOTYPICAL_MODEL_SPECS_DICT.keys() @@ -92,6 +104,7 @@ def get_prototype_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. @@ -107,6 +120,7 @@ def get_special_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. This is reserved @@ -122,6 +136,7 @@ def get_special_model_spec_for_inference_component_based_endpoint( model_id: str = None, version: str = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID and adding @@ -142,14 +157,22 @@ def get_spec_from_base_spec( _obj: JumpStartModelsCache = None, region: str = None, model_id: str = None, - semantic_version_str: str = None, + version_str: str = None, version: str = None, s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> JumpStartModelSpecs: - if version and semantic_version_str: + if version and version_str: raise ValueError("Cannot specify both `version` and `semantic_version_str` fields.") + if model_type == JumpStartModelType.PROPRIETARY: + spec = copy.deepcopy(BASE_PROPRIETARY_SPEC) + spec["version"] = version or version_str + spec["model_id"] = model_id + + return JumpStartModelSpecs(spec) + if all( [ "pytorch" not in model_id, @@ -172,7 +195,7 @@ def get_spec_from_base_spec( spec = copy.deepcopy(BASE_SPEC) - spec["version"] = version or semantic_version_str + spec["version"] = version or version_str spec["model_id"] = model_id return JumpStartModelSpecs(spec) @@ -185,19 +208,35 @@ def patched_retrieval_function( ) -> JumpStartCachedS3ContentValue: filetype, s3_key = key.file_type, key.s3_key - if filetype == JumpStartS3FileType.MANIFEST: + if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: return JumpStartCachedS3ContentValue( formatted_content=get_formatted_manifest(BASE_MANIFEST) ) - if filetype == JumpStartS3FileType.SPECS: + if filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS: _, model_id, specs_version = s3_key.split("/") version = specs_version.replace("specs_v", "").replace(".json", "") return JumpStartCachedS3ContentValue( formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) ) + if filetype == JumpStartS3FileType.PROPRIETARY_MANIFEST: + return JumpStartCachedS3ContentValue( + formatted_content=get_formatted_manifest(BASE_PROPRIETARY_MANIFEST) + ) + + if filetype == JumpStartS3FileType.PROPRIETARY_SPECS: + _, model_id, specs_version = s3_key.split("/") + version = specs_version.replace("proprietary_specs_", "").replace(".json", "") + return JumpStartCachedS3ContentValue( + formatted_content=get_spec_from_base_spec( + model_id=model_id, + version=version, + model_type=JumpStartModelType.PROPRIETARY, + ) + ) + raise ValueError(f"Bad value for filetype: {filetype}") diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index ffc6000c91..608a32a005 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -18,6 +18,7 @@ import pytest from sagemaker import metric_definitions +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec @@ -25,10 +26,14 @@ mock_session = Mock(s3_client=mock_client) +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_default_metric_definitions(patched_get_model_specs): +def test_jumpstart_default_metric_definitions( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -47,7 +52,11 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -63,7 +72,11 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 000540e12e..8d75731b06 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -19,19 +19,24 @@ from sagemaker import model_uris from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.model_uris.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_common_model_uri( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -47,6 +52,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -64,6 +70,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -82,6 +89,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -100,6 +108,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index b0cef0e3d4..7b5e7a598d 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -18,6 +18,7 @@ import pytest from sagemaker import resource_requirements +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.jumpstart.artifacts.resource_requirements import ( REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP, @@ -26,10 +27,14 @@ from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_resource_requirements(patched_get_model_specs): +def test_jumpstart_resource_requirements( + patched_get_model_specs, patched_validate_model_id_and_get_type +): patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS region = "us-west-2" mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -50,6 +55,7 @@ def test_jumpstart_resource_requirements(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() @@ -103,9 +109,14 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode } +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): +def test_jumpstart_no_supported_resource_requirements( + patched_get_model_specs, patched_validate_model_id_and_get_type +): + patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "no-supported-instance-types-model", "*" region = "us-west-2" @@ -126,6 +137,7 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 3f38326608..c797ba3559 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -19,19 +19,24 @@ from sagemaker import script_uris from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.script_uris.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_common_script_uri( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -47,6 +52,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -64,6 +70,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -78,7 +85,11 @@ def test_jumpstart_common_script_uri( sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="*", + s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -97,6 +108,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index b22b61dc40..c2253726bf 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -19,18 +19,23 @@ from sagemaker import base_serializers, serializers from sagemaker.jumpstart.utils import verify_model_region_and_return_specs +from sagemaker.jumpstart.enums import JumpStartModelType from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_serializers( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -50,19 +55,24 @@ def test_jumpstart_default_serializers( model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_serializer_options( - patched_get_model_specs, patched_verify_model_region_and_return_specs + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, ): patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS mock_client = boto3.client("s3") mock_session = Mock(s3_client=mock_client) @@ -89,4 +99,5 @@ def test_jumpstart_serializer_options( model_id=model_id, version=model_version, s3_client=mock_client, + model_type=JumpStartModelType.OPEN_WEIGHTS, )