@@ -15,7 +15,7 @@
 import os
 import textwrap
 
-import torch
+import torch, torcheia
 from sagemaker_inference import (
     content_types,
     decoder,
@@ -28,6 +28,9 @@
 INFERENCE_ACCELERATOR_PRESENT_ENV = "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"
 DEFAULT_MODEL_FILENAME = "model.pt"
 
+torch._C._jit_set_profiling_executor(False)
+device = torch.device("cpu")
+
 
 class DefaultPytorchInferenceHandler(default_inference_handler.DefaultInferenceHandler):
     VALID_CONTENT_TYPES = (content_types.JSON, content_types.NPY)
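
Note: the two module-level lines added above switch TorchScript to the legacy (non-profiling) JIT executor, which Elastic Inference requires, and pin the client process to CPU; the accelerator itself runs remotely. A minimal sketch of the same setup outside the handler, with a hypothetical local model.pt path standing in for the handler's model_dir lookup:

    import torch

    # EIA requires the legacy JIT executor, so profiling mode is disabled first.
    torch._C._jit_set_profiling_executor(False)

    # The client framework stays CPU-only; the accelerator runs remotely.
    device = torch.device("cpu")

    # Hypothetical path; the handler builds it from model_dir and DEFAULT_MODEL_FILENAME.
    model = torch.jit.load("model.pt", map_location=device)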
@@ -47,7 +50,11 @@ def default_model_fn(self, model_dir):
                 raise FileNotFoundError("Failed to load model with default model_fn: missing file {}."
                                         .format(DEFAULT_MODEL_FILENAME))
             # Client-framework is CPU only. But model will run in Elastic Inference server with CUDA.
-            return torch.jit.load(model_path, map_location=torch.device('cpu'))
+            model = torch.jit.load(model_path, map_location=torch.device('cpu'))
+            # attach_eia() is introduced in PyTorch Elastic Inference 1.5.1
+            # by default attach to the 0th device
+            model = torcheia.jit.attach_eia(model, 0)
+            return model
         else:
             raise NotImplementedError(textwrap.dedent("""
             Please provide a model_fn implementation.
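
For context, torcheia.jit.attach_eia(model, device_ordinal) ships with the PyTorch Elastic Inference 1.5.1 images and binds an already-loaded TorchScript module to the accelerator at the given ordinal (0 here). A minimal sketch of the new default_model_fn path; load_eia_model and the model_dir argument are illustrative stand-ins, not part of the toolkit:

    import os

    import torch
    import torcheia

    torch._C._jit_set_profiling_executor(False)


    def load_eia_model(model_dir):
        # Mirror the handler: load the TorchScript model on CPU, then
        # attach it to the Elastic Inference accelerator at ordinal 0.
        model = torch.jit.load(os.path.join(model_dir, "model.pt"),
                               map_location=torch.device("cpu"))
        return torcheia.jit.attach_eia(model, 0)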
@@ -86,8 +93,8 @@ def default_predict_fn(self, data, model):
             model = model.to(device)
             input_data = data.to(device)
             model.eval()
-            with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
-                output = model(input_data)
+            with torch.jit.optimized_execution(True):
+                output = model.forward(input_data)
         else:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
             model = model.to(device)
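
On the predict side, the commit drops the older two-argument optimized_execution(True, {"target_device": "eia:0"}) form used by earlier EIA releases, presumably because attach_eia() now fixes the target device, and it calls forward() explicitly rather than invoking the module directly. A sketch of the resulting inference call, assuming eia_model comes from the previous snippet and input_data is a CPU tensor:

    eia_model.eval()
    with torch.jit.optimized_execution(True):
        # The EIA-attached module is invoked through an explicit forward() call.
        output = eia_model.forward(input_data)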