From 455f26606f5cb610801bda45a70c98a39adc5574 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 23 May 2024 09:32:50 -0700 Subject: [PATCH 1/6] RoutingConfig --- src/sagemaker/enums.py | 7 +++++++ src/sagemaker/model.py | 14 +++++++++++++- src/sagemaker/serve/builder/model_builder.py | 2 +- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/enums.py b/src/sagemaker/enums.py index 5b4d0d6790..4ffe248768 100644 --- a/src/sagemaker/enums.py +++ b/src/sagemaker/enums.py @@ -28,3 +28,10 @@ class EndpointType(Enum): INFERENCE_COMPONENT_BASED = ( "InferenceComponentBased" # Amazon SageMaker Inference Component Based Endpoint ) + + +class RoutingStrategy(Enum): + """Strategy for routing https traffics.""" + + RANDOM = "RANDOM" + LEAST_OUTSTANDING_REQUESTS = "LEAST_OUTSTANDING_REQUESTS" diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index fd21b6342e..d424ce5627 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -20,7 +20,7 @@ import os import re import copy -from typing import List, Dict, Optional, Union +from typing import List, Dict, Optional, Union, Any import sagemaker from sagemaker import ( @@ -1309,6 +1309,7 @@ def deploy( resources: Optional[ResourceRequirements] = None, endpoint_type: EndpointType = EndpointType.MODEL_BASED, managed_instance_scaling: Optional[str] = None, + routing_config: Optional[Dict[str, Any]] = None, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -1406,6 +1407,15 @@ def deploy( Endpoint. (Default: None). endpoint_type (Optional[EndpointType]): The type of an endpoint used to deploy models. (Default: EndpointType.MODEL_BASED). + routing_config (Optional[Dict[str, Any]): Settings the control how the endpoint routes incoming + traffic to the instances that the endpoint hosts. + Currently, support dictionary key ``RoutingStrategy``. + + .. code:: python + + { + "RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM + } Raises: ValueError: If arguments combination check failed in these circumstances: - If no role is specified or @@ -1543,6 +1553,7 @@ def deploy( model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, managed_instance_scaling=managed_instance_scaling_config, + routing_config=routing_config, ) self.sagemaker_session.endpoint_from_production_variants( @@ -1625,6 +1636,7 @@ def deploy( volume_size=volume_size, model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, + routing_config=routing_config, ) if endpoint_name: self.endpoint_name = endpoint_name diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 1fe75065d5..44bc46b00b 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -169,7 +169,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, in order for model builder to build the artifacts correctly (according to the model server). Possible values for this argument are ``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``, - ``TRITON``,``TGI``, and ``TEI``. + ``TRITON``, ``TGI``, and ``TEI``. model_metadata (Optional[Dict[str, Any]): Dictionary used to override model metadata. Currently, ``HF_TASK`` is overridable for HuggingFace model. HF_TASK should be set for new models without task metadata in the Hub, adding unsupported task types will throw From 4ba3c72d87c9346e343412c65a5298b99741afe6 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 23 May 2024 12:36:10 -0700 Subject: [PATCH 2/6] Refactoring --- src/sagemaker/huggingface/model.py | 1 + src/sagemaker/model.py | 3 +++ src/sagemaker/utils.py | 32 ++++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index f71dca0ac8..662baecae6 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -334,6 +334,7 @@ def deploy( endpoint_type=kwargs.get("endpoint_type", None), resources=kwargs.get("resources", None), managed_instance_scaling=kwargs.get("managed_instance_scaling", None), + routing_config=kwargs.get("routing_config", None), ) def register( diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index d424ce5627..4cac06149c 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -66,6 +66,7 @@ resolve_nested_dict_value_from_config, format_tags, Tags, + _resolve_routing_config, ) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.predictor_async import AsyncPredictor @@ -1468,6 +1469,8 @@ def deploy( if self.role is None: raise ValueError("Role can not be null for deploying a model") + routing_config = _resolve_routing_config(routing_config) + if ( inference_recommendation_id is not None or self.inference_recommender_job_results is not None diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 0436c0afea..5cf89033dc 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -44,6 +44,7 @@ _log_sagemaker_config_single_substitution, _log_sagemaker_config_merge, ) +from sagemaker.enums import RoutingStrategy from sagemaker.session_settings import SessionSettings from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string from sagemaker.workflow.entities import PipelineVariable @@ -1655,3 +1656,34 @@ def deep_override_dict( ) flattened_dict1.update(flattened_dict2) return unflatten_dict(flattened_dict1) if flattened_dict1 else {} + + +def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """Resolve Routing Config + + Args: + routing_config (Optional[Dict[str, Any]]): The routing config. + + Returns: + Optional[Dict[str, Any]]: The resolved routing config. + + Raises: + ValueError: If the RoutingStrategy is invalid. + """ + + if routing_config: + routing_strategy = routing_config.get("RoutingStrategy", None) + if routing_strategy: + if isinstance(routing_strategy, RoutingStrategy): + return {"RoutingStrategy": routing_strategy.name} + if isinstance(routing_strategy, str) and ( + routing_strategy.lower() == RoutingStrategy.RANDOM.name.lower() + or routing_strategy.lower() + == RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.name.lower() + ): + return {"RoutingStrategy": routing_strategy} + raise ValueError( + "RoutingStrategy must be either RoutingStrategy.RANDOM " + "or RoutingStrategy.LEAST_OUTSTANDING_REQUESTS" + ) + return None From b405009a14ed3301eb221f1b38242a98c853ce4e Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 23 May 2024 13:03:12 -0700 Subject: [PATCH 3/6] Docstring --- src/sagemaker/enums.py | 5 +++++ src/sagemaker/jumpstart/factory/model.py | 2 ++ src/sagemaker/jumpstart/model.py | 2 ++ src/sagemaker/jumpstart/types.py | 3 +++ 4 files changed, 12 insertions(+) diff --git a/src/sagemaker/enums.py b/src/sagemaker/enums.py index 4ffe248768..f02b275cbe 100644 --- a/src/sagemaker/enums.py +++ b/src/sagemaker/enums.py @@ -34,4 +34,9 @@ class RoutingStrategy(Enum): """Strategy for routing https traffics.""" RANDOM = "RANDOM" + """The endpoint routes each request to a randomly chosen instance. + """ LEAST_OUTSTANDING_REQUESTS = "LEAST_OUTSTANDING_REQUESTS" + """The endpoint routes requests to the specific instances that have + more capacity to process them. + """ diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 28746990e3..0df9c0fb7a 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -555,6 +555,7 @@ def get_deploy_kwargs( resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, endpoint_type: Optional[EndpointType] = None, + **kwargs, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -586,6 +587,7 @@ def get_deploy_kwargs( accept_eula=accept_eula, endpoint_logging=endpoint_logging, resources=resources, + **kwargs, ) deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 4529bc11b9..f7081a5144 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -496,6 +496,7 @@ def deploy( resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, endpoint_type: EndpointType = EndpointType.MODEL_BASED, + **kwargs, ) -> PredictorBase: """Creates endpoint by calling base ``Model`` class `deploy` method. @@ -625,6 +626,7 @@ def deploy( managed_instance_scaling=managed_instance_scaling, endpoint_type=endpoint_type, model_type=self.model_type, + **kwargs, ) if ( self.model_type == JumpStartModelType.PROPRIETARY diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 05c38da266..adc3ae9d9d 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1614,6 +1614,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "endpoint_logging", "resources", "endpoint_type", + "routing_config", ] SERIALIZATION_EXCLUSION_SET = { @@ -1658,6 +1659,7 @@ def __init__( endpoint_logging: Optional[bool] = None, resources: Optional[ResourceRequirements] = None, endpoint_type: Optional[EndpointType] = None, + **kwargs, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" @@ -1690,6 +1692,7 @@ def __init__( self.endpoint_logging = endpoint_logging self.resources = resources self.endpoint_type = endpoint_type + self.routing_config = kwargs.get("routing_config") class JumpStartEstimatorInitKwargs(JumpStartKwargs): From 69ce32ed8497d5f9082abba77bb1816d2881652d Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 23 May 2024 14:41:09 -0700 Subject: [PATCH 4/6] UT --- .../sagemaker/jumpstart/model/test_model.py | 2 +- tests/unit/sagemaker/model/test_deploy.py | 5 ++++ tests/unit/test_utils.py | 29 +++++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 8b00eb5bcd..d1f31e61a0 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -699,7 +699,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): and reach out to JumpStart team.""" init_args_to_skip: Set[str] = set([]) - deploy_args_to_skip: Set[str] = set(["kwargs"]) + deploy_args_to_skip: Set[str] = set(["routing_config"]) parent_class_init = Model.__init__ parent_class_init_args = set(signature(parent_class_init).parameters.keys()) diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 953cbe775c..69ea2c1f56 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -125,6 +125,7 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem volume_size=None, model_data_download_timeout=None, container_startup_health_check_timeout=None, + routing_config=None, ) sagemaker_session.create_model.assert_called_with( @@ -184,6 +185,7 @@ def test_deploy_accelerator_type( volume_size=None, model_data_download_timeout=None, container_startup_health_check_timeout=None, + routing_config=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( @@ -506,6 +508,7 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model, volume_size=None, model_data_download_timeout=None, container_startup_health_check_timeout=None, + routing_config=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( @@ -938,6 +941,7 @@ def test_deploy_customized_volume_size_and_timeout( volume_size=volume_size_gb, model_data_download_timeout=model_data_download_timeout_sec, container_startup_health_check_timeout=startup_health_check_timeout_sec, + routing_config=None, ) sagemaker_session.create_model.assert_called_with( @@ -987,6 +991,7 @@ def test_deploy_with_resources(sagemaker_session, name_from_base, production_var volume_size=None, model_data_download_timeout=None, container_startup_health_check_timeout=None, + routing_config=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( name=name_from_base(MODEL_NAME), diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 81d8279e6d..d81984e81f 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -30,6 +30,7 @@ from mock import call, patch, Mock, MagicMock, PropertyMock import sagemaker +from sagemaker.enums import RoutingStrategy from sagemaker.experiments._run_context import _RunContext from sagemaker.session_settings import SessionSettings from sagemaker.utils import ( @@ -50,6 +51,7 @@ _is_bad_link, custom_extractall_tarfile, can_model_package_source_uri_autopopulate, + _resolve_routing_config, ) from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -1866,3 +1868,30 @@ def test_deep_override_skip_keys(self): expected_result = {"a": 1, "b": {"x": 20, "y": 3, "z": 30}, "c": [4, 5]} self.assertEqual(deep_override_dict(dict1, dict2, skip_keys=["c", "d"]), expected_result) + + +@pytest.mark.parametrize( + "routing_config, expected", + [ + ({"RoutingStrategy": RoutingStrategy.RANDOM}, {"RoutingStrategy": "RANDOM"}), + ({"RoutingStrategy": "RANDOM"}, {"RoutingStrategy": "RANDOM"}), + ( + {"RoutingStrategy": RoutingStrategy.LEAST_OUTSTANDING_REQUESTS}, + {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"}, + ), + ( + {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"}, + {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"}, + ), + ({"RoutingStrategy": None}, None), + (None, None), + ], +) +def test_resolve_routing_config(routing_config, expected): + res = _resolve_routing_config(routing_config) + + assert res == expected + + +def test_resolve_routing_config_ex(): + pytest.raises(ValueError, lambda: _resolve_routing_config({"RoutingStrategy": "Invalid"})) From 4e98e0f1687ecaadd11fa1962a59621740f792fc Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 23 May 2024 15:23:12 -0700 Subject: [PATCH 5/6] Refactoring --- src/sagemaker/jumpstart/factory/model.py | 4 ++-- src/sagemaker/jumpstart/model.py | 8 +++++--- src/sagemaker/jumpstart/types.py | 4 ++-- src/sagemaker/model.py | 6 +++--- src/sagemaker/utils.py | 5 ++--- tests/unit/sagemaker/jumpstart/model/test_model.py | 2 +- 6 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 0df9c0fb7a..41c1ca4437 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -555,7 +555,7 @@ def get_deploy_kwargs( resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, endpoint_type: Optional[EndpointType] = None, - **kwargs, + routing_config: Optional[Dict[str, Any]] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -587,7 +587,7 @@ def get_deploy_kwargs( accept_eula=accept_eula, endpoint_logging=endpoint_logging, resources=resources, - **kwargs, + routing_config=routing_config, ) deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index f7081a5144..994193de3e 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, Any from botocore.exceptions import ClientError from sagemaker import payloads @@ -496,7 +496,7 @@ def deploy( resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, endpoint_type: EndpointType = EndpointType.MODEL_BASED, - **kwargs, + routing_config: Optional[Dict[str, Any]] = None, ) -> PredictorBase: """Creates endpoint by calling base ``Model`` class `deploy` method. @@ -591,6 +591,8 @@ def deploy( endpoint. endpoint_type (EndpointType): The type of endpoint used to deploy models. (Default: EndpointType.MODEL_BASED). + routing_config (Optional[Dict]): Settings the control how the endpoint routes + incoming traffic to the instances that the endpoint hosts. Raises: MarketplaceModelSubscriptionError: If the caller is not subscribed to the model. @@ -626,7 +628,7 @@ def deploy( managed_instance_scaling=managed_instance_scaling, endpoint_type=endpoint_type, model_type=self.model_type, - **kwargs, + routing_config=routing_config, ) if ( self.model_type == JumpStartModelType.PROPRIETARY diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index adc3ae9d9d..bea125d423 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1659,7 +1659,7 @@ def __init__( endpoint_logging: Optional[bool] = None, resources: Optional[ResourceRequirements] = None, endpoint_type: Optional[EndpointType] = None, - **kwargs, + routing_config: Optional[Dict[str, Any]] = None, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" @@ -1692,7 +1692,7 @@ def __init__( self.endpoint_logging = endpoint_logging self.resources = resources self.endpoint_type = endpoint_type - self.routing_config = kwargs.get("routing_config") + self.routing_config = routing_config class JumpStartEstimatorInitKwargs(JumpStartKwargs): diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 4cac06149c..1bb6cb2e5c 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1414,9 +1414,9 @@ def deploy( .. code:: python - { - "RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM - } + { + "RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM + } Raises: ValueError: If arguments combination check failed in these circumstances: - If no role is specified or diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 5cf89033dc..ad2aec8b8d 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1677,9 +1677,8 @@ def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optiona if isinstance(routing_strategy, RoutingStrategy): return {"RoutingStrategy": routing_strategy.name} if isinstance(routing_strategy, str) and ( - routing_strategy.lower() == RoutingStrategy.RANDOM.name.lower() - or routing_strategy.lower() - == RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.name.lower() + routing_strategy.upper() == RoutingStrategy.RANDOM.name + or routing_strategy.upper() == RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.name ): return {"RoutingStrategy": routing_strategy} raise ValueError( diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index d1f31e61a0..8b00eb5bcd 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -699,7 +699,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): and reach out to JumpStart team.""" init_args_to_skip: Set[str] = set([]) - deploy_args_to_skip: Set[str] = set(["routing_config"]) + deploy_args_to_skip: Set[str] = set(["kwargs"]) parent_class_init = Model.__init__ parent_class_init_args = set(signature(parent_class_init).parameters.keys()) From 0d6184513f3f3e465f7b47698ec99a004aece670 Mon Sep 17 00:00:00 2001 From: Jonathan Makunga Date: Thu, 23 May 2024 15:26:27 -0700 Subject: [PATCH 6/6] Refactoring --- src/sagemaker/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index ad2aec8b8d..430effefa3 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1680,7 +1680,7 @@ def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optiona routing_strategy.upper() == RoutingStrategy.RANDOM.name or routing_strategy.upper() == RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.name ): - return {"RoutingStrategy": routing_strategy} + return {"RoutingStrategy": routing_strategy.upper()} raise ValueError( "RoutingStrategy must be either RoutingStrategy.RANDOM " "or RoutingStrategy.LEAST_OUTSTANDING_REQUESTS"