|
12 | 12 | # language governing permissions and limitations under the License.
|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
| 15 | +import json |
15 | 16 | import numpy as np
|
16 | 17 | import pytest
|
| 18 | +import requests |
17 | 19 | import sagemaker
|
18 | 20 | from sagemaker.pytorch import PyTorchModel
|
19 | 21 |
|
@@ -66,6 +68,13 @@ def test_default_inference_eia(sagemaker_session, image_uri, instance_type, acce
|
66 | 68 | def _test_default_inference(
|
67 | 69 | sagemaker_session, image_uri, instance_type, model_tar, mnist_script, accelerator_type=None
|
68 | 70 | ):
|
| 71 | + image_url = ( |
| 72 | + "https://raw.githubusercontent.com/aws/amazon-sagemaker-examples/master/" |
| 73 | + "sagemaker_neo_compilation_jobs/pytorch_torchvision/cat.jpg" |
| 74 | + ) |
| 75 | + img_data = requests.get(image_url).content |
| 76 | + with open('cat.jpg', 'wb') as file_obj: |
| 77 | + file_obj.write(img_data) |
69 | 78 | endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-pytorch-serving")
|
70 | 79 |
|
71 | 80 | model_data = sagemaker_session.upload_data(
|
@@ -94,8 +103,15 @@ def _test_default_inference(
|
94 | 103 | initial_instance_count=1, instance_type=instance_type, endpoint_name=endpoint_name
|
95 | 104 | )
|
96 | 105 |
|
97 |
| - batch_size = 100 |
98 |
| - data = np.random.rand(batch_size, 1, 28, 28).astype(np.float32) |
99 |
| - output = predictor.predict(data) |
100 |
| - |
101 |
| - assert output.shape == (batch_size, 10) |
| 106 | + if accelerator_type: |
| 107 | + batch_size = 100 |
| 108 | + data = np.random.rand(batch_size, 1, 28, 28).astype(np.float32) |
| 109 | + output = predictor.predict(data) |
| 110 | + assert output.shape == (batch_size, 10) |
| 111 | + else: |
| 112 | + with open("cat.jpg", "rb") as f: |
| 113 | + payload = f.read() |
| 114 | + payload = bytearray(payload) |
| 115 | + response = predictor.predict(payload) |
| 116 | + result = json.loads(response.decode()) |
| 117 | + assert len(result) == 1000 |
0 commit comments