
Commit b0ca952

doc: Explain why default model_fn loads PyTorch-EI models to CPU by default (#1404)
1 parent 00ab41d commit b0ca952

File tree

1 file changed: +6, -2 lines

doc/using_pytorch.rst

Lines changed: 6 additions & 2 deletions
@@ -290,7 +290,7 @@ Load a Model
 ------------
 
 Before a model can be served, it must be loaded. The SageMaker PyTorch model server loads your model by invoking a
-``model_fn`` function that you must provide in your script. The ``model_fn`` should have the following signature:
+``model_fn`` function that you must provide in your script when you are not using Elastic Inference. The ``model_fn`` should have the following signature:
 
 .. code:: python
 
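The body of that code block falls outside the hunk's context. For reference, ``model_fn`` takes the SageMaker model directory and returns the loaded model; a minimal sketch, with an illustrative ``TinyNet`` standing in for a real model class:

.. code:: python

    import os

    import torch
    import torch.nn as nn

    class TinyNet(nn.Module):  # illustrative stand-in for your model class
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    def model_fn(model_dir):
        # model_dir is where SageMaker unpacks model.tar.gz; load the
        # parameters saved as model.pth and return the ready model.
        model = TinyNet()
        with open(os.path.join(model_dir, "model.pth"), "rb") as f:
            model.load_state_dict(torch.load(f))
        return model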
@@ -316,7 +316,11 @@ It loads the model parameters from a ``model.pth`` file in the SageMaker model d
 However, if you are using PyTorch Elastic Inference, you do not have to provide a ``model_fn`` since the PyTorch serving
 container has a default one for you. But please note that if you are utilizing the default ``model_fn``, please save
 your ScriptModule as ``model.pt``. If you are implementing your own ``model_fn``, please use TorchScript and ``torch.jit.save``
-to save your ScriptModule, then load it in your ``model_fn`` with ``torch.jit.load``. For more information on inference script, please refer to:
+to save your ScriptModule, then load it in your ``model_fn`` with ``torch.jit.load(..., map_location=torch.device('cpu'))``.
+
+The client-side Elastic Inference framework is CPU-only, even though inference still happens in a CUDA context on the server. Thus, the default ``model_fn`` for Elastic Inference loads the model to CPU. Tracing models may lead to tensor creation on a specific device, which may cause device-related errors when loading a model onto a different device. Providing an explicit ``map_location=torch.device('cpu')`` argument forces all tensors to CPU.
+
+For more information on the default inference handler functions, please refer to:
 `SageMaker PyTorch Default Inference Handler <https://github.com/aws/sagemaker-pytorch-serving-container/blob/master/src/sagemaker_pytorch_serving_container/default_inference_handler.py>`_.
 
 Serve a PyTorch Model
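A sketch of the save/load round trip the new text describes: the file name follows the ``model.pt`` convention mentioned above, while ``TinyNet`` and the example input are illustrative, not part of the documentation.

.. code:: python

    import torch
    import torch.nn as nn

    class TinyNet(nn.Module):  # illustrative stand-in for your model
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    # Tracing records ops on concrete tensors, so tensors can end up pinned
    # to whatever device the example input lived on.
    traced = torch.jit.trace(TinyNet().eval(), torch.rand(1, 4))
    torch.jit.save(traced, "model.pt")  # the name the default model_fn expects

    # In a custom model_fn: map_location forces every tensor onto CPU, so a
    # device recorded at trace time cannot cause a device-related load error.
    model = torch.jit.load("model.pt", map_location=torch.device("cpu"))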

0 commit comments