Skip to content

Commit 23b6d91

Browse files
committed
Fix inference request
1 parent 36b0b7d commit 23b6d91

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

test/integration/sagemaker/test_mnist_default_inference.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import json
1516
import numpy as np
1617
import pytest
18+
import requests
1719
import sagemaker
1820
from sagemaker.pytorch import PyTorchModel
1921

@@ -66,6 +68,13 @@ def test_default_inference_eia(sagemaker_session, image_uri, instance_type, acce
6668
def _test_default_inference(
6769
sagemaker_session, image_uri, instance_type, model_tar, mnist_script, accelerator_type=None
6870
):
71+
image_url = (
72+
"https://raw.githubusercontent.com/aws/amazon-sagemaker-examples/master/"
73+
"sagemaker_neo_compilation_jobs/pytorch_torchvision/cat.jpg"
74+
)
75+
img_data = requests.get(image_url).content
76+
with open('cat.jpg', 'wb') as file_obj:
77+
file_obj.write(img_data)
6978
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-pytorch-serving")
7079

7180
model_data = sagemaker_session.upload_data(
@@ -94,8 +103,15 @@ def _test_default_inference(
94103
initial_instance_count=1, instance_type=instance_type, endpoint_name=endpoint_name
95104
)
96105

97-
batch_size = 100
98-
data = np.random.rand(batch_size, 1, 28, 28).astype(np.float32)
99-
output = predictor.predict(data)
100-
101-
assert output.shape == (batch_size, 10)
106+
if accelerator_type:
107+
batch_size = 100
108+
data = np.random.rand(batch_size, 1, 28, 28).astype(np.float32)
109+
output = predictor.predict(data)
110+
assert output.shape == (batch_size, 10)
111+
else:
112+
with open("cat.jpg", "rb") as f:
113+
payload = f.read()
114+
payload = bytearray(payload)
115+
response = predictor.predict(payload)
116+
result = json.loads(response.decode())
117+
assert len(result) == 1000

0 commit comments

Comments
 (0)