Skip to content

Commit d83cf98

Browse files
author
John Barboza
committed
make changes for pytorch 1.5.1 eia
1 parent 6610a41 commit d83cf98

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import os
1616
import textwrap
1717

18-
import torch
18+
import torch, torcheia
1919
from sagemaker_inference import (
2020
content_types,
2121
decoder,
@@ -28,6 +28,9 @@
2828
INFERENCE_ACCELERATOR_PRESENT_ENV = "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"
2929
DEFAULT_MODEL_FILENAME = "model.pt"
3030

31+
torch._C._jit_set_profiling_executor(False)
32+
device = torch.device("cpu")
33+
3134

3235
class DefaultPytorchInferenceHandler(default_inference_handler.DefaultInferenceHandler):
3336
VALID_CONTENT_TYPES = (content_types.JSON, content_types.NPY)
@@ -47,7 +50,11 @@ def default_model_fn(self, model_dir):
4750
raise FileNotFoundError("Failed to load model with default model_fn: missing file {}."
4851
.format(DEFAULT_MODEL_FILENAME))
4952
# Client-framework is CPU only. But model will run in Elastic Inference server with CUDA.
50-
return torch.jit.load(model_path, map_location=torch.device('cpu'))
53+
model = torch.jit.load(model_path, map_location=torch.device('cpu'))
54+
# attach_eia() is introduced in PyTorch Elastic Inference 1.5.1
55+
# by default attach to the 0th device
56+
model = torcheia.jit.attach_eia(model, 0)
57+
return model
5158
else:
5259
raise NotImplementedError(textwrap.dedent("""
5360
Please provide a model_fn implementation.
@@ -86,8 +93,8 @@ def default_predict_fn(self, data, model):
8693
model = model.to(device)
8794
input_data = data.to(device)
8895
model.eval()
89-
with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
90-
output = model(input_data)
96+
with torch.jit.optimized_execution(True):
97+
output = model.forward(input_data)
9198
else:
9299
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93100
model = model.to(device)

0 commit comments

Comments
 (0)