Skip to content

Commit 44372d4

Browse files
committed
feature: add inference_id to predict
1 parent 9f813e2 commit 44372d4

File tree

3 files changed

+58
-10
lines changed

3 files changed

+58
-10
lines changed

src/sagemaker/local/local_session.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -426,17 +426,27 @@ def invoke_endpoint(
426426
CustomAttributes=None,
427427
TargetModel=None,
428428
TargetVariant=None,
429+
InferenceId=None,
429430
):
430431
"""Invoke the endpoint.
431432
432433
Args:
433-
Body:
434-
EndpointName:
435-
Accept: (Default value = None)
436-
CustomAttributes: (Default value = None)
434+
Body: Input data for which you want the model to provide inference.
435+
EndpointName: The name of the endpoint that you specified when you
436+
created the endpoint using the CreateEndpoint API.
437+
ContentType: The MIME type of the input data in the request body (Default value = None)
438+
Accept: The desired MIME type of the inference in the response (Default value = None)
439+
CustomAttributes: Provides additional information about a request for an inference
440+
submitted to a model hosted at an Amazon SageMaker endpoint (Default value = None)
441+
TargetModel: The model to request for inference when invoking a multi-model endpoint
442+
(Default value = None)
443+
TargetVariant: Specify the production variant to send the inference request to when
444+
invoking an endpoint that is running two or more variants (Default value = None)
445+
InferenceId: If you provide a value, it is added to the captured data when you enable
446+
data capture on the endpoint (Default value = None)
437447
438448
Returns:
439-
449+
object: Inference for the given input.
440450
"""
441451
url = "http://localhost:%s/invocations" % self.serving_port
442452
headers = {}
@@ -456,6 +466,9 @@ def invoke_endpoint(
456466
if TargetVariant is not None:
457467
headers["X-Amzn-SageMaker-Target-Variant"] = TargetVariant
458468

469+
if InferenceId is not None:
470+
headers["X-Amzn-SageMaker-Inference-Id"] = InferenceId
471+
459472
r = self.http.request("POST", url, body=Body, preload_content=False, headers=headers)
460473

461474
return {"Body": r, "ContentType": Accept}

src/sagemaker/predictor.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ def __init__(
9595
self._model_names = self._get_model_names()
9696
self._context = None
9797

98-
def predict(self, data, initial_args=None, target_model=None, target_variant=None):
98+
def predict(
99+
self, data, initial_args=None, target_model=None, target_variant=None, inference_id=None
100+
):
99101
"""Return the inference from the specified endpoint.
100102
101103
Args:
@@ -111,8 +113,10 @@ def predict(self, data, initial_args=None, target_model=None, target_variant=Non
111113
in case of a multi model endpoint. Does not apply to endpoints hosting
112114
single model (Default: None)
113115
target_variant (str): The name of the production variant to run an inference
114-
request on (Default: None). Note that the ProductionVariant identifies the model
115-
you want to host and the resources you want to deploy for hosting it.
116+
request on (Default: None). Note that the ProductionVariant identifies the
117+
model you want to host and the resources you want to deploy for hosting it.
118+
inference_id (str): If you provide a value, it is added to the captured data
119+
when you enable data capture on the endpoint (Default: None).
116120
117121
Returns:
118122
object: Inference for the given input. If a deserializer was specified when creating
@@ -121,7 +125,9 @@ def predict(self, data, initial_args=None, target_model=None, target_variant=Non
121125
as is.
122126
"""
123127

124-
request_args = self._create_request_args(data, initial_args, target_model, target_variant)
128+
request_args = self._create_request_args(
129+
data, initial_args, target_model, target_variant, inference_id
130+
)
125131
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
126132
return self._handle_response(response)
127133

@@ -131,7 +137,9 @@ def _handle_response(self, response):
131137
content_type = response.get("ContentType", "application/octet-stream")
132138
return self.deserializer.deserialize(response_body, content_type)
133139

134-
def _create_request_args(self, data, initial_args=None, target_model=None, target_variant=None):
140+
def _create_request_args(
141+
self, data, initial_args=None, target_model=None, target_variant=None, inference_id=None
142+
):
135143
"""Placeholder docstring"""
136144
args = dict(initial_args) if initial_args else {}
137145

@@ -150,6 +158,9 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
150158
if target_variant:
151159
args["TargetVariant"] = target_variant
152160

161+
if inference_id:
162+
args["InferenceId"] = inference_id
163+
153164
data = self.serializer.serialize(data)
154165

155166
args["Body"] = data

tests/unit/test_predictor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
RETURN_VALUE = 0
3232
CSV_RETURN_VALUE = "1,2,3\r\n"
3333
PRODUCTION_VARIANT_1 = "PRODUCTION_VARIANT_1"
34+
INFERENCE_ID = "inference-id"
3435

3536
ENDPOINT_DESC = {"EndpointArn": "foo", "EndpointConfigName": ENDPOINT}
3637

@@ -98,6 +99,29 @@ def test_predict_call_with_target_variant():
9899
assert result == RETURN_VALUE
99100

100101

102+
def test_predict_call_with_inference_id():
103+
sagemaker_session = empty_sagemaker_session()
104+
predictor = Predictor(ENDPOINT, sagemaker_session)
105+
106+
data = "untouched"
107+
result = predictor.predict(data, inference_id=INFERENCE_ID)
108+
109+
assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called
110+
111+
expected_request_args = {
112+
"Accept": DEFAULT_ACCEPT,
113+
"Body": data,
114+
"ContentType": DEFAULT_CONTENT_TYPE,
115+
"EndpointName": ENDPOINT,
116+
"InferenceId": INFERENCE_ID,
117+
}
118+
119+
call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
120+
assert kwargs == expected_request_args
121+
122+
assert result == RETURN_VALUE
123+
124+
101125
def test_multi_model_predict_call():
102126
sagemaker_session = empty_sagemaker_session()
103127
predictor = Predictor(ENDPOINT, sagemaker_session)

0 commit comments

Comments (0)