@@ -15,7 +15,7 @@
 import os
 import textwrap
 
-import torch
+import torch, torcheia
 from sagemaker_inference import (
     content_types,
     decoder,
@@ -28,6 +28,9 @@
 INFERENCE_ACCELERATOR_PRESENT_ENV = "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"
 DEFAULT_MODEL_FILENAME = "model.pt"
 
+torch._C._jit_set_profiling_executor(False)
+device = torch.device("cpu")
+
 
 class DefaultPytorchInferenceHandler(default_inference_handler.DefaultInferenceHandler):
     VALID_CONTENT_TYPES = (content_types.JSON, content_types.NPY)
@@ -47,7 +50,11 @@ def default_model_fn(self, model_dir):
                 raise FileNotFoundError("Failed to load model with default model_fn: missing file {}."
                                         .format(DEFAULT_MODEL_FILENAME))
             # Client-framework is CPU only. But model will run in Elastic Inference server with CUDA.
-            return torch.jit.load(model_path, map_location=torch.device('cpu'))
+            model = torch.jit.load(model_path, map_location=torch.device('cpu'))
+            # attach_eia() is introduced in PyTorch Elastic Inference 1.5.1
+            # by default attach to the 0th device
+            model = torcheia.jit.attach_eia(model, 0)
+            return model
         else:
             raise NotImplementedError(textwrap.dedent("""
             Please provide a model_fn implementation.
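
Note: a user-supplied model_fn following the same pattern might look like the sketch below. This is a minimal, hypothetical example: only torcheia.jit.attach_eia, the SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT environment variable, and the model.pt filename come from the diff; the rest is illustrative.

    import os

    import torch
    import torcheia

    def model_fn(model_dir):
        # The client framework is CPU-only; the model itself executes
        # on the Elastic Inference accelerator server.
        model = torch.jit.load(os.path.join(model_dir, "model.pt"),
                               map_location=torch.device("cpu"))
        if os.getenv("SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT") == "true":
            # attach_eia() ships with PyTorch Elastic Inference 1.5.1;
            # ordinal 0 selects the first attached accelerator.
            model = torcheia.jit.attach_eia(model, 0)
        return model
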
@@ -86,8 +93,8 @@ def default_predict_fn(self, data, model):
             model = model.to(device)
             input_data = data.to(device)
             model.eval()
-            with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
-                output = model(input_data)
+            with torch.jit.optimized_execution(True):
+                output = model.forward(input_data)
         else:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
             model = model.to(device)
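
Note: this predict_fn change drops the target_device argument, since with torcheia the accelerator is selected when attach_eia() is called, so optimized_execution only needs the boolean flag. Below is a minimal sketch of an equivalent user-supplied predict_fn, assuming the same environment variable and module-level executor setup as the diff; the tensor handling is illustrative.

    import os

    import torch

    # Disable the TorchScript profiling executor before EIA inference
    # (mirrors the module-level call added at the top of the file).
    torch._C._jit_set_profiling_executor(False)

    def predict_fn(data, model):
        device = torch.device("cpu")  # client side stays on CPU
        input_data = data.to(device)
        model.eval()
        if os.getenv("SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT") == "true":
            with torch.jit.optimized_execution(True):
                return model.forward(input_data)
        # Plain CPU fallback when no accelerator is attached.
        with torch.no_grad():
            return model(input_data)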