
Commit ddca516

Merge branch 'zwei' into require-framework-version
2 parents: 13c9d66 + 7a7c658


doc/frameworks/mxnet/using_mxnet.rst

Lines changed: 23 additions & 185 deletions
@@ -30,15 +30,6 @@ To train an MXNet model by using the SageMaker Python SDK:
Prepare an MXNet Training Script
================================

-.. warning::
-    The structure for training scripts changed starting at MXNet version 1.3.
-    Make sure you refer to the correct section of this README when you prepare your script.
-    For information on how to upgrade an old script to the new format, see `"Updating your MXNet training script" <#updating-your-mxnet-training-script>`__.
-
-For versions 1.3 and higher
----------------------------
-Your MXNet training script must be compatible with Python 2.7 or 3.6.

The training script is very similar to a training script you might run outside of Amazon SageMaker, but you can access useful properties about the training environment through various environment variables, including the following:

* ``SM_MODEL_DIR``: A string that represents the path where the training job writes the model artifacts to.
@@ -89,119 +80,8 @@ If you want to use, for example, boolean hyperparameters, you need to specify ``

For more on training environment variables, please visit `SageMaker Containers <https://github.com/aws/sagemaker-containers>`_.
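
For illustration only, here is a minimal sketch of how a training script might read these values at runtime. ``SM_NUM_GPUS`` and ``SM_HOSTS`` are further variables documented by SageMaker Containers, and the local-testing defaults are assumptions, not container behavior:

.. code:: python

    import json
    import os

    # Locations and settings injected by the container as environment variables.
    # The fallback defaults are assumptions so the script can also run locally.
    model_dir = os.environ.get('SM_MODEL_DIR', '/tmp/model')
    num_gpus = int(os.environ.get('SM_NUM_GPUS', '0'))
    hosts = json.loads(os.environ.get('SM_HOSTS', '["localhost"]'))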

-For versions 1.2 and lower
---------------------------
-
-Your MXNet training script must be compatible with Python 2.7 or 3.5.
-The script must contain a function named ``train``, which Amazon SageMaker invokes to run training.
-You can include other functions as well, but the script must contain a ``train`` function.
-
-When you run your script on Amazon SageMaker via the ``MXNet`` estimator, Amazon SageMaker injects information about the training environment into your training function via Python keyword arguments.
-You can choose to take advantage of these by including them as keyword arguments in your ``train`` function. The full list of arguments is:
-
-- ``hyperparameters (dict[string,string])``: The hyperparameters passed
-  to an Amazon SageMaker TrainingJob that runs your MXNet training script. You
-  can use this to pass hyperparameters to your training script.
-- ``input_data_config (dict[string,dict])``: The Amazon SageMaker TrainingJob
-  InputDataConfig object, which is set when the Amazon SageMaker TrainingJob is
-  created. This is discussed in more detail below.
-- ``channel_input_dirs (dict[string,string])``: A collection of
-  directories containing training data. When you run training, you can
-  partition your training data into different logical "channels".
-  Depending on your problem, some common channel ideas are: "train",
-  "test", "evaluation", or "images", "labels".
-- ``output_data_dir (str)``: A directory where your training script can
-  write data that is moved to Amazon S3 after training is complete.
-- ``num_gpus (int)``: The number of GPU devices available on your
-  training instance.
-- ``num_cpus (int)``: The number of CPU devices available on your training instance.
-- ``hosts (list[str])``: The list of host names running in the
-  Amazon SageMaker Training Job cluster.
-- ``current_host (str)``: The name of the host executing the script.
-  When you use Amazon SageMaker for MXNet training, the script is run on each
-  host in the cluster.
-
-A training script that takes advantage of all arguments would have the following definition:
-
-.. code:: python
-
-    def train(hyperparameters, input_data_config, channel_input_dirs, output_data_dir,
-              num_gpus, num_cpus, hosts, current_host):
-        ...
-
-You don't have to use all the arguments.
-Arguments you don't care about can be ignored by including ``**kwargs``.
-
-.. code:: python
-
-    # Work with only hyperparameters and num_gpus, and ignore all other arguments
-    def train(hyperparameters, num_gpus, **kwargs):
-        ...
-
-.. note::
-    **Writing a training script that imports correctly:**
-    When Amazon SageMaker runs your training script, it imports it as a Python module and then invokes ``train`` on the imported module.
-    Consequently, you should not include any statements that won't execute successfully in Amazon SageMaker when your module is imported.
-    For example, don't attempt to open any local files in top-level statements in your training script.
-
-If you want to run your training script locally by using the Python interpreter, use an ``if __name__ == '__main__':`` guard.
-For more information, see https://stackoverflow.com/questions/419163/what-does-if-name-main-do.
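
As a rough sketch of that pattern (not from the original doc; the stand-in hyperparameter values are assumptions for local testing):

.. code:: python

    def train(hyperparameters, num_gpus, **kwargs):
        # Training logic for the old (<= 1.2) script format goes here.
        pass

    if __name__ == '__main__':
        # Local smoke test; on SageMaker, the container imports the module
        # and calls train() with the real training environment instead.
        train(hyperparameters={'epochs': '1'}, num_gpus=0)
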
-Save the Model
-^^^^^^^^^^^^^^
-
-Just as you enable training by defining a ``train`` function in your training script, you enable model saving by defining a ``save`` function in your script.
-If your script includes a ``save`` function, Amazon SageMaker invokes it with the return value of ``train``.
-Model saving is a two-step process.
-First, return the model you want to save from ``train``.
-Then, define your model-serialization logic in ``save``.
-
-Amazon SageMaker provides a default implementation of ``save`` that works with MXNet Module API ``Module`` objects.
-If your training script does not define a ``save`` function, then the default ``save`` function is invoked on the return value of your ``train`` function.
-
-The default serialization system generates three files:
-
-- ``model-shapes.json``: A JSON list, containing a serialization of the
-  ``Module`` ``data_shapes`` property. Each object in the list contains
-  the serialization of one ``DataShape`` in the returned ``Module``.
-  Each object has a ``name`` property, containing the ``DataShape``
-  name, and a ``shape`` property, which is a list of the dimensions of
-  that ``DataShape``. For example:
-
-  .. code:: javascript
-
-      [
-          {"name":"images", "shape":[100, 1, 28, 28]},
-          {"name":"labels", "shape":[100, 1]}
-      ]
-
-- ``model-symbol.json``: The MXNet ``Module`` ``Symbol`` serialization,
-  produced by invoking ``save`` on the ``symbol`` property of the
-  ``Module`` being saved.
-- ``model.params``: The MXNet ``Module`` parameters, produced by
-  invoking ``save_params`` on the ``Module`` being saved.
-
-You can provide your own save function. This is useful if you are not working with the ``Module`` API or you need special processing.
-
-To provide your own save function, define a ``save`` function in your training script:
-
-.. code:: python
-
-    def save(model, model_dir):
-        ...
-
-The function should take two arguments:
-
-- ``model``: This is the object that is returned from your ``train`` function.
-  You may return an object of any type from ``train``;
-  you do not have to return ``Module`` or ``Gluon`` API-specific objects.
-  If your ``train`` function does not return an object, ``model`` is set to ``None``.
-- ``model_dir``: This is the string path on the Amazon SageMaker training host where you save your model.
-  Files created in this directory are accessible in Amazon S3 after your Amazon SageMaker Training Job completes.
-
-After your ``train`` function completes, Amazon SageMaker invokes ``save`` with the object returned from ``train``.
-
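For example, a minimal custom ``save`` could pickle whatever ``train`` returned; using ``pickle`` and the file name ``model.pkl`` are illustrative assumptions, not the container's defaults:

.. code:: python

    import os
    import pickle

    def save(model, model_dir):
        # Persist the object returned by train(); files written to model_dir
        # are uploaded to Amazon S3 after the training job completes.
        if model is not None:
            with open(os.path.join(model_dir, 'model.pkl'), 'wb') as f:
                pickle.dump(model, f)
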
.. note::
-    **How to save Gluon models with Amazon SageMaker:**
-    If your ``train`` function returns a Gluon API ``net`` object as its model, you need to write your own ``save`` function and serialize the ``net`` parameters.
-    Saving ``net`` parameters is covered in the `Serialization section <http://gluon.mxnet.io/chapter03_deep-neural-networks/serialization.html>`__ of the collaborative Gluon deep-learning book `"The Straight Dope" <http://gluon.mxnet.io/index.html>`__.
+    If you want to use MXNet 1.2 or lower, see `an older version of this page <https://sagemaker.readthedocs.io/en/v1.61.0/frameworks/mxnet/using_mxnet.html>`_.
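
A minimal sketch of such a ``save`` function, assuming a Gluon ``net`` on MXNet 1.2 or lower (where ``save_params`` is the Gluon serialization call; the file name is an assumption):

.. code:: python

    import os

    def save(net, model_dir):
        # Serialize only the network parameters; the architecture is
        # reconstructed in code at load time.
        net.save_params(os.path.join(model_dir, 'model.params'))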

Save a Checkpoint
-----------------
@@ -233,86 +113,44 @@ To save MXNet model checkpoints, do the following in your training script:

For a complete example of an MXNet training script that implements checkpointing, see https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_gluon_cifar10/cifar10.py.
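
As a rough sketch of the idea (assuming a Gluon ``net``, a hypothetical ``checkpoints_dir``, and MXNet 1.3+ ``save_parameters``):

.. code:: python

    import os

    def save_checkpoint(net, checkpoints_dir, epoch):
        # One parameter file per epoch, so training can resume from the
        # most recent checkpoint after an interruption.
        path = os.path.join(checkpoints_dir, 'epoch-%03d.params' % epoch)
        net.save_parameters(path)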

+Save the Model
+--------------

-Update your MXNet training script
----------------------------------
-
-The structure for training scripts changed with MXNet version 1.3.
-The ``train`` function is no longer required; instead, the training script must be able to run as a standalone script.
-In this way, the training script is similar to a training script you might run outside of Amazon SageMaker.
-
-There are a few steps needed to make a training script in the old format compatible with the new format.
-
-First, add a `main guard <https://docs.python.org/3/library/__main__.html>`__ (``if __name__ == '__main__':``).
-The code executed from your main guard needs to:
-
-1. Set hyperparameters and directory locations
-2. Initiate training
-3. Save the model
-
-Hyperparameters are passed as command-line arguments to your training script.
-In addition, the container defines the locations of input data and where to save the model artifacts and output data as environment variables rather than passing that information as arguments to the ``train`` function.
-You can find the full list of available environment variables in the `SageMaker Containers README <https://github.com/aws/sagemaker-containers#list-of-provided-environment-variables-by-sagemaker-containers>`__.
-
-We recommend using `an argument parser <https://docs.python.org/3.5/howto/argparse.html>`__ for this part.
-Using the ``argparse`` library as an example, the code looks something like this:
-
-.. code:: python
-
-    import argparse
-    import os
-
-    if __name__ == '__main__':
-        parser = argparse.ArgumentParser()
-
-        # Hyperparameters sent by the client are passed as command-line arguments to the script.
-        parser.add_argument('--epochs', type=int, default=10)
-        parser.add_argument('--batch-size', type=int, default=100)
-        parser.add_argument('--learning-rate', type=float, default=0.1)
-
-        # Input data and model directories
-        parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
-        parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
-        parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])
-
-        args, _ = parser.parse_known_args()
-
-The code in the main guard should also take care of training and saving the model.
-This can be as simple as calling the ``train`` and ``save`` methods used in the previous training script format:
-
-.. code:: python
-
-    if __name__ == '__main__':
-        # arg parsing (shown above) goes here
-
-        model = train(args.batch_size, args.epochs, args.learning_rate, args.train, args.test)
-        save(args.model_dir, model)
-
-Note that saving the model is no longer done by default; this must be done by the training script.
-If you were previously relying on the default save method, you can import one from the container:
+There is a default save method that can be imported when training on SageMaker:

.. code:: python

-    from sagemaker_mxnet_container.training_utils import save
+    from sagemaker_mxnet_training.training_utils import save

    if __name__ == '__main__':
        # arg parsing and training (shown above) goes here

        save(args.model_dir, model)
-Lastly, if you were relying on the container launching a parameter server for use with distributed training, you must set ``distributions`` to the following dictionary when creating an MXNet estimator:
+The default serialization system generates three files:

-.. code:: python
+- ``model-shapes.json``: A JSON list, containing a serialization of the
+  ``Module`` ``data_shapes`` property. Each object in the list contains
+  the serialization of one ``DataShape`` in the returned ``Module``.
+  Each object has a ``name`` property, containing the ``DataShape``
+  name, and a ``shape`` property, which is a list of the dimensions of
+  that ``DataShape``. For example:

-    from sagemaker.mxnet import MXNet
+  .. code:: javascript

-    estimator = MXNet('path-to-distributed-training-script.py',
-                      ...,
-                      distributions={'parameter_server': {'enabled': True}})
+      [
+          {"name":"images", "shape":[100, 1, 28, 28]},
+          {"name":"labels", "shape":[100, 1]}
+      ]

+- ``model-symbol.json``: The MXNet ``Module`` ``Symbol`` serialization,
+  produced by invoking ``save`` on the ``symbol`` property of the
+  ``Module`` being saved.
+- ``model.params``: The MXNet ``Module`` parameters, produced by
+  invoking ``save_params`` on the ``Module`` being saved.
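These artifacts can be loaded back with standard MXNet calls; a sketch, assuming the files are in the current directory:

.. code:: python

    import json
    import mxnet as mx

    # Recover the computation graph and the recorded input shapes
    # from the files produced by the default save.
    sym = mx.sym.load('model-symbol.json')
    with open('model-shapes.json') as f:
        data_shapes = [(d['name'], tuple(d['shape'])) for d in json.load(f)]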

Use third-party libraries
--------------------------
+=========================

When your training script runs on Amazon SageMaker, it has access to some pre-installed third-party libraries, including ``mxnet``, ``numpy``, ``onnx``, and ``keras-mxnet``.
For more information on the runtime environment, including specific package versions, see `SageMaker MXNet Containers <#sagemaker-mxnet-containers>`__.
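
A trivial sketch of relying on those pre-installed packages (``keras-mxnet`` is imported as ``keras``):

.. code:: python

    import keras   # provided by the keras-mxnet package
    import mxnet as mx
    import numpy as np
    import onnx

    print(mx.__version__, np.__version__, onnx.__version__, keras.__version__)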
