Skip to content

Commit a996d42

Browse files
committed
Create new folder for any model name test
1 parent c098a5c commit a996d42

File tree

4 files changed

+60
-9
lines changed

4 files changed

+60
-9
lines changed

test/integration/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
code_sub_dir = 'code'
2828
default_sub_dir = 'default_model'
2929
default_sub_eia_dir = 'default_model_eia'
30+
default_sub_traced_resnet_dir = 'default_traced_resnet'
3031

3132
model_cpu_dir = os.path.join(mnist_path, cpu_sub_dir)
3233
mnist_cpu_script = os.path.join(model_cpu_dir, code_sub_dir, 'mnist.py')
@@ -67,11 +68,13 @@
6768
default_model_tar = file_utils.make_tarfile(
6869
default_model_script, os.path.join(default_model_dir, "model.pt"), default_model_dir, script_path="code"
6970
)
71+
72+
default_traced_resnet_dir = os.path.join(resnet18_path, default_sub_traced_resnet_dir)
73+
default_traced_resnet_script = os.path.join(default_traced_resnet_dir, code_sub_dir, "resnet18.py")
7074
default_model_traced_resnet18_tar = file_utils.make_tarfile(
71-
default_model_script,
72-
os.path.join(default_model_dir, "traced_resnet18.pt"),
73-
default_model_dir,
74-
filename="traced_resnet18.tar.gz",
75+
default_traced_resnet_script,
76+
os.path.join(default_traced_resnet_dir, "traced_resnet18.pt"),
77+
default_traced_resnet_dir,
7578
script_path="code",
7679
)
7780

test/integration/sagemaker/test_default_inference.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from integration import (
2424
default_model_script,
2525
default_model_tar,
26+
default_traced_resnet_script,
2627
default_model_traced_resnet18_tar,
2728
default_model_eia_script,
2829
default_model_eia_tar,
@@ -74,11 +75,7 @@ def test_default_inference_any_model_name_gpu(sagemaker_session, image_uri, inst
7475
# Scripted model is serialized with torch.jit.save().
7576
# Default inference test doesn't need to instantiate model definition
7677
_test_default_inference(
77-
sagemaker_session,
78-
image_uri,
79-
instance_type,
80-
default_model_traced_resnet18_tar,
81-
default_model_script,
78+
sagemaker_session, image_uri, instance_type, default_model_traced_resnet18_tar, default_traced_resnet_script
8279
)
8380

8481

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import io
2+
import json
3+
import logging
4+
5+
import numpy as np
6+
import torch
7+
import torchvision.transforms as transforms
8+
from PIL import Image # Training container doesn't have this package
9+
10+
logger = logging.getLogger(__name__)
11+
logger.setLevel(logging.DEBUG)
12+
13+
14+
def transform_fn(model, payload, request_content_type, response_content_type):
15+
16+
logger.info("Invoking user-defined transform function")
17+
18+
if request_content_type and request_content_type != "application/octet-stream":
19+
raise RuntimeError(
20+
"Content type must be application/octet-stream. Provided: {0}".format(
21+
request_content_type
22+
)
23+
)
24+
25+
# preprocess
26+
decoded = Image.open(io.BytesIO(payload))
27+
preprocess = transforms.Compose(
28+
[
29+
transforms.Resize(256),
30+
transforms.CenterCrop(224),
31+
transforms.ToTensor(),
32+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
33+
]
34+
)
35+
normalized = preprocess(decoded)
36+
batchified = normalized.unsqueeze(0)
37+
38+
# predict
39+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40+
batchified = batchified.to(device)
41+
result = model.forward(batchified)
42+
43+
# Softmax (assumes batch size 1)
44+
result = np.squeeze(result.cpu().detach().numpy())
45+
result_exp = np.exp(result - np.max(result))
46+
result = result_exp / np.sum(result_exp)
47+
48+
response_body = json.dumps(result.tolist())
49+
content_type = "application/json"
50+
51+
return response_body, content_type

0 commit comments

Comments
 (0)