Prepare an MXNet Training Script
================================

The training script is very similar to a training script you might run outside of Amazon SageMaker, but you can access useful properties about the training environment through various environment variables, including the following:

* ``SM_MODEL_DIR``: A string that represents the path where the training job writes the model artifacts to.
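Because these values are exposed as environment variables inside the container, and hyperparameters arrive as command-line arguments, a common pattern is to read both with ``argparse``. The following is a minimal sketch; the hyperparameter names (``--epochs``, ``--learning-rate``) are illustrative, not required by SageMaker:

```python
import argparse
import os


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    # Hyperparameters arrive as command-line arguments (names are illustrative).
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--learning-rate', type=float, default=0.1)
    # Training-environment paths come from SM_* environment variables;
    # the fallbacks let the same script run outside of SageMaker.
    parser.add_argument('--model-dir', type=str,
                        default=os.environ.get('SM_MODEL_DIR', './model'))
    parser.add_argument('--train', type=str,
                        default=os.environ.get('SM_CHANNEL_TRAIN', './data/train'))
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    print('model artifacts will be written to', args.model_dir)
```

The defaults make the script directly runnable on a laptop, while inside a training job the ``SM_*`` variables take over.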
For more on training environment variables, please visit `SageMaker Containers <https://github.com/aws/sagemaker-containers>`_.

.. note::
    If you want to use MXNet 1.2 or lower, see `an older version of this page <https://sagemaker.readthedocs.io/en/v1.61.0/frameworks/mxnet/using_mxnet.html>`_.

Save a Checkpoint
-----------------
For a complete example of an MXNet training script that implements checkpointing, see https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_gluon_cifar10/cifar10.py.
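The core of the pattern used in that example is writing a parameter snapshot to a local checkpoint directory at the end of each epoch, so the files can be synced to Amazon S3 and used to resume an interrupted job. A hedged sketch, where the directory and helper name are illustrative (configure the actual path through your estimator's checkpoint settings) and ``save_parameters`` is the Gluon serialization call:

```python
import os

# /opt/ml/checkpoints is the conventional local checkpoint path; treat it as
# an assumption here and configure it via the estimator's checkpoint settings.
DEFAULT_CHECKPOINT_DIR = '/opt/ml/checkpoints'


def save_checkpoint(net, epoch, checkpoint_dir=DEFAULT_CHECKPOINT_DIR):
    """Write one parameter snapshot per epoch (illustrative helper)."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, 'epoch-%03d.params' % epoch)
    net.save_parameters(path)  # Gluon API call on the network being trained
    return path
```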
Save the Model
--------------
There is a default save method that can be imported when training on SageMaker:

.. code:: python

    from sagemaker_mxnet_training.training_utils import save

    if __name__ == '__main__':
        # arg parsing and training (shown above) goes here

        save(args.model_dir, model)
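The default save method works with MXNet Module API ``Module`` objects. If your training code produces a Gluon ``net`` instead, you need to serialize its parameters yourself. A sketch of such a custom save function, where the output file name is illustrative:

```python
import os


def save(model, model_dir):
    """Custom save hook: write a Gluon net's parameters into model_dir.

    Anything written to model_dir is uploaded to Amazon S3 after the
    training job completes.
    """
    # save_parameters is the Gluon API serialization call; the file name
    # 'model.params' is an assumption for this sketch.
    model.save_parameters(os.path.join(model_dir, 'model.params'))
```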
The default serialization system generates three files:

- ``model-shapes.json``: A JSON list containing a serialization of the
  ``Module`` ``data_shapes`` property. Each object in the list contains
  the serialization of one ``DataShape`` in the returned ``Module``.
  Each object has a ``name`` property, containing the ``DataShape``
  name, and a ``shape`` property, which is a list of the dimensions of
  the shape of that ``DataShape``. For example:

  .. code:: javascript

      [
          {"name": "images", "shape": [100, 1, 28, 28]},
          {"name": "labels", "shape": [100, 1]}
      ]

- ``model-symbol.json``: The MXNet ``Module`` ``Symbol`` serialization,
  produced by invoking ``save`` on the ``symbol`` property of the
  ``Module`` being saved.
- ``model.params``: The MXNet ``Module`` parameters, produced by
  invoking ``save_params`` on the ``Module`` being saved.
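As a quick way to confirm that all three artifacts were written before the job uploads them, a small helper can be used; it is a hedged sketch, with the file names taken from the list above:

```python
import os

# The three files the default serialization system generates.
EXPECTED_ARTIFACTS = {'model-shapes.json', 'model-symbol.json', 'model.params'}


def missing_artifacts(model_dir):
    """Return which of the default artifacts are absent from model_dir."""
    return sorted(EXPECTED_ARTIFACTS - set(os.listdir(model_dir)))
```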
Use third-party libraries
=========================

When you run your training script on Amazon SageMaker, it has access to some pre-installed third-party libraries, including ``mxnet``, ``numpy``, ``onnx``, and ``keras-mxnet``.
For more information on the runtime environment, including specific package versions, see `SageMaker MXNet Containers <#sagemaker-mxnet-containers>`__.