@@ -95,7 +95,9 @@ def __init__(
95
95
self ._model_names = self ._get_model_names ()
96
96
self ._context = None
97
97
98
- def predict (self , data , initial_args = None , target_model = None , target_variant = None ):
98
+ def predict (
99
+ self , data , initial_args = None , target_model = None , target_variant = None , inference_id = None
100
+ ):
99
101
"""Return the inference from the specified endpoint.
100
102
101
103
Args:
@@ -111,8 +113,10 @@ def predict(self, data, initial_args=None, target_model=None, target_variant=Non
111
113
in case of a multi model endpoint. Does not apply to endpoints hosting
112
114
single model (Default: None)
113
115
target_variant (str): The name of the production variant to run an inference
114
- request on (Default: None). Note that the ProductionVariant identifies the model
115
- you want to host and the resources you want to deploy for hosting it.
116
+ request on (Default: None). Note that the ProductionVariant identifies the
117
+ model you want to host and the resources you want to deploy for hosting it.
118
+ inference_id (str): If you provide a value, it is added to the captured data
119
+ when you enable data capture on the endpoint (Default: None).
116
120
117
121
Returns:
118
122
object: Inference for the given input. If a deserializer was specified when creating
@@ -121,7 +125,9 @@ def predict(self, data, initial_args=None, target_model=None, target_variant=Non
121
125
as is.
122
126
"""
123
127
124
- request_args = self ._create_request_args (data , initial_args , target_model , target_variant )
128
+ request_args = self ._create_request_args (
129
+ data , initial_args , target_model , target_variant , inference_id
130
+ )
125
131
response = self .sagemaker_session .sagemaker_runtime_client .invoke_endpoint (** request_args )
126
132
return self ._handle_response (response )
127
133
@@ -131,7 +137,9 @@ def _handle_response(self, response):
131
137
content_type = response .get ("ContentType" , "application/octet-stream" )
132
138
return self .deserializer .deserialize (response_body , content_type )
133
139
134
- def _create_request_args (self , data , initial_args = None , target_model = None , target_variant = None ):
140
+ def _create_request_args (
141
+ self , data , initial_args = None , target_model = None , target_variant = None , inference_id = None
142
+ ):
135
143
"""Placeholder docstring"""
136
144
args = dict (initial_args ) if initial_args else {}
137
145
@@ -150,6 +158,9 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
150
158
if target_variant :
151
159
args ["TargetVariant" ] = target_variant
152
160
161
+ if inference_id :
162
+ args ["InferenceId" ] = inference_id
163
+
153
164
data = self .serializer .serialize (data )
154
165
155
166
args ["Body" ] = data
0 commit comments