change: Enable default model fn for cpu and gpu #107
Merged
Commits (26, all by saimidu):
58b37b0 change: Enable default model fn for cpu and gpu
455ee5a Add tests
b04f6a4 Use resnet18 model for default inference
fb924a9 Fix bugs in path definition
36b0b7d Fix bug
23b6d91 Fix inference request
6f8c71d Change Predictor Class on Model object
bf0e407 Fix flake8 failures
a20826b Rename test file
a93860a Disable EIA tests until new PT EIA image is available
7eb0737 Auto load any model name
ed36950 Add test
c098a5c Run test on GPU
a996d42 Create new folder for any model name test
92def12 Add specific name
414c83a Address comments
94dd4d5 Add unit tests for unknown model names
945c677 Fix unit tests
56be3c2 Fix flake8 failure
4b19f3d Fix test formatting
e8a57ce Add Exception for model load failure
9953787 Add unit test for model load failure
17a262f Fix issues
93f176c Fix lambda
26b9d86 Fix
4729dcc Fix pattern
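Taken together, the commit trail describes a default model_fn that works on both CPU and GPU hosts, loads a TorchScript model regardless of its file name ("Auto load any model name"), and raises an explicit exception when loading fails. A hedged sketch of that behavior, assembled from the commit messages only; the file-scanning logic and error wording are assumptions, not the toolkit's actual code:

# Hedged sketch of the default model_fn this PR describes -- illustrative only,
# not the toolkit's actual implementation.
import os

import torch


def model_fn(model_dir):
    # Works for both CPU and GPU hosts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # "Auto load any model name": pick up whatever single file model_dir contains.
    model_files = [f for f in os.listdir(model_dir)
                   if os.path.isfile(os.path.join(model_dir, f))]
    if len(model_files) != 1:
        raise ValueError("Expected exactly one model file in {}, found {}"
                         .format(model_dir, model_files))
    try:
        # The model is expected to be TorchScript, serialized with torch.jit.save().
        model = torch.jit.load(os.path.join(model_dir, model_files[0]), map_location=device)
    except RuntimeError as e:
        # "Add Exception for model load failure": surface load errors explicitly.
        raise Exception("Failed to load model: {}".format(e))
    return model.to(device)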
Files changed:

New file (+130 lines): integration test for the default inference handlers.

@@ -0,0 +1,130 @@
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import json
import numpy as np
import pytest
import requests
import sagemaker
from sagemaker.predictor import RealTimePredictor
from sagemaker.pytorch import PyTorchModel, PyTorchPredictor

from integration import (
    default_model_script,
    default_model_tar,
    default_traced_resnet_script,
    default_model_traced_resnet18_tar,
    default_model_eia_script,
    default_model_eia_tar,
)
from integration.sagemaker.timeout import timeout_and_delete_endpoint


@pytest.mark.cpu_test
def test_default_inference_cpu(sagemaker_session, image_uri, instance_type):
    instance_type = instance_type or "ml.c4.xlarge"
    # The scripted model is serialized with torch.jit.save(), so the default
    # inference test doesn't need to instantiate a model definition.
    _test_default_inference(
        sagemaker_session, image_uri, instance_type, default_model_tar, default_model_script
    )


@pytest.mark.gpu_test
def test_default_inference_gpu(sagemaker_session, image_uri, instance_type):
    instance_type = instance_type or "ml.p2.xlarge"
    # The scripted model is serialized with torch.jit.save(), so the default
    # inference test doesn't need to instantiate a model definition.
    _test_default_inference(
        sagemaker_session, image_uri, instance_type, default_model_tar, default_model_script
    )


@pytest.mark.skip(
    reason="Latest EIA version - 1.5.1 uses mms. Enable when EIA images use torchserve"
)
@pytest.mark.eia_test
def test_default_inference_eia(sagemaker_session, image_uri, instance_type, accelerator_type):
    instance_type = instance_type or "ml.c4.xlarge"
    # The scripted model is serialized with torch.jit.save(), so the default
    # inference test doesn't need to instantiate a model definition.
    _test_default_inference(
        sagemaker_session,
        image_uri,
        instance_type,
        default_model_eia_tar,
        default_model_eia_script,
        accelerator_type=accelerator_type,
    )


@pytest.mark.gpu_test
def test_default_inference_any_model_name_gpu(sagemaker_session, image_uri, instance_type):
    instance_type = instance_type or "ml.p2.xlarge"
    # The scripted model is serialized with torch.jit.save(), so the default
    # inference test doesn't need to instantiate a model definition.
    _test_default_inference(
        sagemaker_session,
        image_uri,
        instance_type,
        default_model_traced_resnet18_tar,
        default_traced_resnet_script,
    )


def _test_default_inference(
    sagemaker_session, image_uri, instance_type, model_tar, mnist_script, accelerator_type=None
):
    endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-pytorch-serving")

    model_data = sagemaker_session.upload_data(
        path=model_tar,
        key_prefix="sagemaker-pytorch-serving/models",
    )

    pytorch = PyTorchModel(
        model_data=model_data,
        role="SageMakerRole",
        predictor_cls=RealTimePredictor if not accelerator_type else PyTorchPredictor,
        entry_point=mnist_script,
        image=image_uri,
        sagemaker_session=sagemaker_session,
    )
    with timeout_and_delete_endpoint(endpoint_name, sagemaker_session, minutes=30):
        predictor = pytorch.deploy(
            initial_instance_count=1,
            instance_type=instance_type,
            accelerator_type=accelerator_type,
            endpoint_name=endpoint_name,
        )

        if accelerator_type:
            batch_size = 100
            data = np.random.rand(batch_size, 1, 28, 28).astype(np.float32)
            output = predictor.predict(data)
            assert output.shape == (batch_size, 10)

        else:
            image_url = (
                "https://raw.githubusercontent.com/aws/amazon-sagemaker-examples/master/"
                "sagemaker_neo_compilation_jobs/pytorch_torchvision/cat.jpg"
            )
            img_data = requests.get(image_url).content
            with open("cat.jpg", "wb") as file_obj:
                file_obj.write(img_data)
            with open("cat.jpg", "rb") as f:
                payload = f.read()
            payload = bytearray(payload)
            response = predictor.predict(payload)
            result = json.loads(response.decode())
            assert len(result) == 1000
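The tests above upload pre-built tarballs (default_model_tar, default_model_traced_resnet18_tar); per the in-test comments, the scripted models are serialized with torch.jit.save(). A hedged sketch of how such an artifact could be produced; the file names are placeholders, and the repository's checked-in tarballs may differ:

# Hedged sketch: build a TorchScript resnet18 artifact of the kind these tests upload.
# File names are placeholders.
import tarfile

import torch
import torchvision.models as models

model = models.resnet18(pretrained=True).eval()
traced = torch.jit.trace(model, torch.rand(1, 3, 224, 224))
torch.jit.save(traced, "model.pt")

with tarfile.open("model.tar.gz", "w:gz") as tar:
    tar.add("model.pt")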
New file (+35 lines): EIA model script defining predict_fn.

@@ -0,0 +1,35 @@
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import logging
import os
import sys

import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))


def predict_fn(input_data, model):
    logger.info('Performing EIA inference with Torch JIT context with input of size {}'.format(input_data.shape))
    # With EI, the client instance should be CPU for cost-efficiency.
    # Sub-graphs with unsupported arguments run locally. Server runs with CUDA.
    device = torch.device('cpu')
    model = model.to(device)
    input_data = input_data.to(device)
    with torch.no_grad():
        # Set the target device to the accelerator ordinal
        with torch.jit.optimized_execution(True, {'target_device': 'eia:0'}):
            return model(input_data)
Binary file not shown.
New file (+51 lines): image-classification model script defining transform_fn (likely the default_traced_resnet_script used by the any-model-name test).

@@ -0,0 +1,51 @@
import io
import json
import logging

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image  # Training container doesn't have this package

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def transform_fn(model, payload, request_content_type, response_content_type):

    logger.info("Invoking user-defined transform function")

    if request_content_type and request_content_type != "application/octet-stream":
        raise RuntimeError(
            "Content type must be application/octet-stream. Provided: {0}".format(
                request_content_type
            )
        )

    # preprocess
    decoded = Image.open(io.BytesIO(payload))
    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    normalized = preprocess(decoded)
    batchified = normalized.unsqueeze(0)

    # predict
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batchified = batchified.to(device)
    result = model.forward(batchified)

    # Softmax (assumes batch size 1)
    result = np.squeeze(result.cpu().detach().numpy())
    result_exp = np.exp(result - np.max(result))
    result = result_exp / np.sum(result_exp)

    response_body = json.dumps(result.tolist())
    content_type = "application/json"

    return response_body, content_type
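Because transform_fn is self-contained, it can be smoke-tested outside the container. A hedged local harness; the model and image paths are placeholders:

# Hedged local smoke test for transform_fn above; paths are placeholders.
import json

import torch

model = torch.jit.load("model.pt", map_location=torch.device("cpu"))
with open("cat.jpg", "rb") as f:
    payload = f.read()

body, content_type = transform_fn(model, payload, "application/octet-stream", "application/json")
assert content_type == "application/json"
assert len(json.loads(body)) == 1000  # one softmax probability per ImageNet class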
Binary file not shown.