Skip to content

Commit 58b37b0

Browse files
committed
change: Enable default model_fn for CPU and GPU
1 parent f498c2f commit 58b37b0

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,14 @@ def default_model_fn(self, model_dir):
4949
# Client-framework is CPU only. But model will run in Elastic Inference server with CUDA.
5050
return torch.jit.load(model_path, map_location=torch.device('cpu'))
5151
else:
52-
raise NotImplementedError(textwrap.dedent("""
53-
Please provide a model_fn implementation.
54-
See documentation for model_fn at https://github.com/aws/sagemaker-python-sdk
55-
"""))
52+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53+
model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME)
54+
if not os.path.exists(model_path):
55+
raise FileNotFoundError("Failed to load model with default model_fn: missing file {}."
56+
.format(DEFAULT_MODEL_FILENAME))
57+
model = torch.jit.load(model_path, map_location=device)
58+
model = model.to(device)
59+
return model
5660

5761
def default_input_fn(self, input_data, content_type):
5862
"""A default input_fn that can handle JSON, CSV and NPZ formats.

0 commit comments

Comments (0)