@@ -23,6 +23,7 @@ This API document assumes you use the following import statements in your training script.
to learn how to use the following API in your PyTorch training script.

.. class:: smp.DistributedModel
+   :noindex:

   A sub-class of ``torch.nn.Module`` which specifies the model to be
   partitioned. Accepts a ``torch.nn.Module`` object ``module`` which is

@@ -157,6 +158,7 @@ This API document assumes you use the following import statements in your training script.
   **Methods**

.. function:: backward(tensors, grad_tensors)
+   :noindex:

   Triggers a distributed backward
   pass across model partitions. Example usage provided in the previous

@@ -165,42 +167,49 @@ This API document assumes you use the following import statements in your training script.
   ``retain_grad`` and ``create_graph`` flags are not supported.

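For orientation, the wrapper and the distributed ``backward`` call are typically used together inside an ``smp.step``-decorated training function. The sketch below only relies on the APIs documented on this page plus the standard ``smdistributed.modelparallel.torch`` import; ``Net()``, the optimizer choice, and the loss are illustrative placeholders.

.. code:: python

   import torch
   import torch.nn.functional as F
   import smdistributed.modelparallel.torch as smp

   smp.init()
   torch.cuda.set_device(smp.local_rank())

   model = smp.DistributedModel(Net())  # Net() is a placeholder torch.nn.Module
   optimizer = smp.DistributedOptimizer(
       torch.optim.SGD(model.parameters(), lr=0.01)
   )

   @smp.step
   def train_step(model, data, target):
       output = model(data)
       loss = F.nll_loss(output, target)
       model.backward(loss)  # replaces loss.backward() for the partitioned model
       return loss
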

.. function:: local_buffers( )
+   :noindex:

   Returns an iterator over buffers for the modules in
   the partitioned model that have been assigned to the current process.

.. function:: local_named_buffers( )
+   :noindex:

   Returns an iterator over buffers for the
   modules in the partitioned model that have been assigned to the current
   process. This yields both the name of the buffer as well as the buffer
   itself.

.. function:: local_parameters( )
+   :noindex:

   Returns an iterator over parameters for the
   modules in the partitioned model that have been assigned to the current
   process.

.. function:: local_named_parameters( )
+   :noindex:

   Returns an iterator over parameters for
   the modules in the partitioned model that have been assigned to the
   current process. This yields both the name of the parameter as well as
   the parameter itself.

.. function:: local_modules( )
+   :noindex:

   Returns an iterator over the modules in the
   partitioned model that have been assigned to the current process.

.. function:: local_named_modules( )
+   :noindex:

   Returns an iterator over the modules in the
   partitioned model that have been assigned to the current process. This
   yields both the name of the module as well as the module itself.
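
These local iterators only see the partition owned by the calling process, which makes them useful for per-rank inspection. A small illustrative sketch, assuming ``model`` is an ``smp.DistributedModel`` and partitioning has already happened:

.. code:: python

   # Count what this mp_rank actually holds after partitioning.
   local_param_count = sum(p.numel() for _, p in model.local_named_parameters())
   local_module_count = sum(1 for _ in model.local_modules())

   print(f"mp_rank {smp.mp_rank()}: "
         f"{local_param_count} local parameter elements in "
         f"{local_module_count} local modules")
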

.. function:: local_state_dict( )
+   :noindex:

   Returns the ``state_dict`` that contains local
   parameters that belong to the current \ ``mp_rank``. This ``state_dict``

@@ -210,34 +219,39 @@ This API document assumes you use the following import statements in your training script.
   partition, or to the entire model.

.. function:: state_dict( )
+   :noindex:

   Returns the ``state_dict`` that contains parameters
   for the entire model. It first collects the \ ``local_state_dict`` and
   gathers and merges the \ ``local_state_dict`` from all ``mp_rank``\ s to
   create a full ``state_dict``.

.. function:: load_state_dict( )
+   :noindex:

   Same as the ``torch.module.load_state_dict()``,
   except: It first gathers and merges the ``state_dict``\ s across
   ``mp_rank``\ s, if they are partial. The actual loading happens after the
   model partition so that each rank knows its local parameters.

.. function:: register_post_partition_hook(hook)
+   :noindex:

   Registers a callable ``hook`` to
   be executed after the model is partitioned. This is useful in situations
   where an operation needs to be executed after the model partition during
-   the first call to ``smp.step``, but before the actual execution of the
+   the first call to ``smp.step`` but before the actual execution of the
   first forward pass. Returns a ``RemovableHandle`` object ``handle``,
   which can be used to remove the hook by calling ``handle.remove()``.
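
A sketch of how such a hook might be wired up. Only ``register_post_partition_hook`` and ``handle.remove()`` come from the API above; the hook body and its ``*args`` signature are assumptions made to stay signature-agnostic.

.. code:: python

   def on_partitioned(*args, **kwargs):
       # Runs once after partitioning, before the first forward pass.
       print(f"mp_rank {smp.mp_rank()} received its model partition")

   handle = model.register_post_partition_hook(on_partitioned)
   # ... later, if the hook is no longer needed:
   handle.remove()
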

- .. function:: cpu( )
+ .. function:: cpu( )
+   :noindex:

   Allgathers parameters and buffers across all ``mp_rank``\ s and moves them
   to the CPU.

.. class:: smp.DistributedOptimizer
+   :noindex:

   **Parameters**

   - ``optimizer``

@@ -246,13 +260,15 @@ This API document assumes you use the following import statements in your training script.
   returns ``optimizer`` with the following methods overridden:

.. function:: state_dict( )
+   :noindex:

   Returns the ``state_dict`` that contains optimizer state for the entire model.
   It first collects the ``local_state_dict`` and gathers and merges
   the ``local_state_dict`` from all ``mp_rank``\ s to create a full
   ``state_dict``.
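
For orientation, wrapping an optimizer and pulling either the merged or the per-rank state might look like the following; the ``Adam`` choice and learning rate are illustrative assumptions.

.. code:: python

   optimizer = smp.DistributedOptimizer(
       torch.optim.Adam(model.parameters(), lr=1e-3)
   )

   # Full, merged optimizer state (gathered across mp_ranks):
   full_state = optimizer.state_dict()

   # Only the shard that belongs to this mp_rank:
   local_state = optimizer.local_state_dict()
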

.. function:: load_state_dict( )
+   :noindex:

   Same as the ``torch.optimizer.load_state_dict()``, except:

@@ -262,6 +278,7 @@ This API document assumes you use the following import statements in your training script.
   rank knows its local parameters.

.. function:: local_state_dict( )
+   :noindex:

   Returns the ``state_dict`` that contains the
   local optimizer state that belongs to the current \ ``mp_rank``. This

@@ -308,70 +325,79 @@ This API document assumes you use the following import statements in your training script.
      self.child3 = Child3()  # child3 on default_partition

.. function:: smp.get_world_process_group( )
+   :noindex:

   Returns a ``torch.distributed`` ``ProcessGroup`` that consists of all
   processes, which can be used with the ``torch.distributed`` API.
   Requires ``"ddp": True`` in SageMaker Python SDK parameters.

.. function:: smp.get_mp_process_group( )
+   :noindex:

   Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the
   processes in the ``MP_GROUP`` which contains the current process, which
   can be used with the \ ``torch.distributed`` API. Requires
   ``"ddp": True`` in SageMaker Python SDK parameters.

.. function:: smp.get_dp_process_group( )
+   :noindex:

   Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the
   processes in the ``DP_GROUP`` which contains the current process, which
   can be used with the \ ``torch.distributed`` API. Requires
   ``"ddp": True`` in SageMaker Python SDK parameters.

.. function:: smp.is_initialized( )
+   :noindex:

   Returns ``True`` if ``smp.init`` has already been called for the
   process, and ``False`` otherwise.

.. function:: smp.is_tracing( )
+   :noindex:

   Returns ``True`` if the current process is running the tracing step, and
   ``False`` otherwise.

.. data:: smp.nn.FusedLayerNorm
+   :noindex:

   `Apex Fused Layer Norm <https://nvidia.github.io/apex/layernorm.html>`__ is currently not
   supported by the library. ``smp.nn.FusedLayerNorm`` replaces ``apex``
   ``FusedLayerNorm`` and provides the same functionality. This requires
   ``apex`` to be installed on the system.

.. data:: smp.optimizers.FusedNovoGrad
-
+   :noindex:

   `Fused Novo Grad optimizer <https://nvidia.github.io/apex/optimizers.html#apex.optimizers.FusedNovoGrad>`__ is
   currently not supported by the library. ``smp.optimizers.FusedNovoGrad`` replaces ``apex`` ``FusedNovoGrad``
   optimizer and provides the same functionality. This requires ``apex`` to
   be installed on the system.

.. data:: smp.optimizers.FusedLamb
-
+   :noindex:

   `FusedLamb optimizer <https://nvidia.github.io/apex/optimizers.html#apex.optimizers.FusedLAMB>`__
   currently doesn’t work with the library. ``smp.optimizers.FusedLamb`` replaces
   ``apex`` ``FusedLamb`` optimizer and provides the same functionality.
   This requires ``apex`` to be installed on the system.

.. data:: smp.amp.GradScaler
+   :noindex:

   `Torch AMP Gradscaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`__
   currently doesn’t work with the library. ``smp.amp.GradScaler`` replaces
   ``torch.amp.GradScaler`` and provides the same functionality.
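
Since it mirrors the ``torch.cuda.amp.GradScaler`` interface, a drop-in sketch of the scale/step/update cycle could look like the following. How these calls are split between code inside and outside ``smp.step`` is not specified here, so the placement below is an assumption; only the method names follow from the torch-compatible interface.

.. code:: python

   scaler = smp.amp.GradScaler()

   with torch.cuda.amp.autocast():
       output = model(data)
       loss = F.nll_loss(output, target)

   # Same scale/step/update cycle as torch.cuda.amp.GradScaler.
   scaled_loss = scaler.scale(loss)
   model.backward(scaled_loss)
   scaler.step(optimizer)
   scaler.update()
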

.. _pytorch_saving_loading:
+   :noindex:

APIs for Saving and Loading
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. function:: smp.save( )
+   :noindex:

   Saves an object. This operation is similar to ``torch.save()``, except
   it has an additional keyword argument, ``partial``, and accepts only

@@ -394,6 +420,7 @@ APIs for Saving and Loading
   override the default protocol.

.. function:: smp.load( )
+   :noindex:

   Loads an object saved with ``smp.save()`` from a file.

@@ -418,6 +445,7 @@ APIs for Saving and Loading
   Should be used when loading a model trained with the library.
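
Putting the two together, a partial (per-``mp_rank``) checkpoint round trip could be sketched as follows. The file naming, the checkpoint dictionary keys, and the ``dp_rank() == 0`` guard are illustrative assumptions rather than part of the API.

.. code:: python

   ckpt_path = f"/opt/ml/checkpoints/model_mp_{smp.mp_rank()}.pt"

   # Save only this mp_rank's shard; guard so each shard is written once.
   if smp.dp_rank() == 0:
       smp.save(
           {"model": model.local_state_dict(),
            "optimizer": optimizer.local_state_dict()},
           ckpt_path,
           partial=True,
       )

   # Later: load the shard back and let the library merge as needed.
   checkpoint = smp.load(ckpt_path, partial=True)
   model.load_state_dict(checkpoint["model"])
   optimizer.load_state_dict(checkpoint["optimizer"])
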

.. _pytorch_saving_loading_instructions:
+   :noindex:

General Instruction For Saving and Loading
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^