
Commit 1d84c6b

Authored by anirudh2290 and Talia Chopra
documentation: Add SMP 1.2.0 API docs (#2098)
* Documentation: Add SMP API changes
* documentation: adding noindex to distributed model parallel v1.1.0 docs
* Documentation: SMP API Add 1.6 to supported versions

Co-authored-by: Talia Chopra <[email protected]>
1 parent d33d082 commit 1d84c6b

9 files changed: +1207 −8 lines changed

doc/api/training/smd_model_parallel.rst

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ Select a version to see the API documentation for version. To use the library, r
 .. toctree::
    :maxdepth: 1

+   smp_versions/v1_2_0.rst
    smp_versions/v1_1_0.rst

 It is recommended to use this documentation alongside `SageMaker Distributed Model Parallel

doc/api/training/smp_versions/v1.1.0/smd_model_parallel_common_api.rst

Lines changed: 22 additions & 1 deletion
@@ -24,10 +24,12 @@ The following SageMaker distribute model parallel APIs are common across all fra


 .. function:: smp.init( )
+   :noindex:

    Initialize the library. Must be called at the beginning of training script.

 .. function:: @smp.step(non_split_inputs, input_split_axes, [*args, **kwargs])
+   :noindex:

    A decorator that must be placed over a function that represents a single
    forward and backward pass (for training use cases), or a single forward
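For orientation only, here is a minimal sketch (not part of this diff) of how these two calls are typically combined, assuming the PyTorch flavor of the library and its documented import path; the model, loss, and data names are placeholders:

    # Minimal sketch -- assumes a PyTorch training script; placeholders for model/data.
    import torch
    import smdistributed.modelparallel.torch as smp

    smp.init()  # must be called at the beginning of the training script

    @smp.step
    def train_step(model, inputs, targets):
        # One forward and backward pass; the library splits the batch into
        # microbatches and pipelines them across model partitions.
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        model.backward(loss)  # distributed backward (see the PyTorch API file below)
        return loss

Outside the decorated function, the returned value is a ``StepOutput``-like object whose per-microbatch tensors can be reduced as described in the next hunks.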
@@ -159,6 +161,7 @@ The following SageMaker distribute model parallel APIs are common across all fra


 .. class:: StepOutput
+   :noindex:


    A class that encapsulates all versions of a ``tf.Tensor``
@@ -188,27 +191,32 @@ The following SageMaker distribute model parallel APIs are common across all fra
    post-processing operations on tensors.

 .. data:: StepOutput.outputs
+   :noindex:

    Returns a list of the underlying tensors, indexed by microbatch.

 .. function:: StepOutput.reduce_mean( )
+   :noindex:

    Returns a ``tf.Tensor``, ``torch.Tensor`` that averages the constituent ``tf.Tensor``\ s
    ``torch.Tensor``\ s. This is commonly used for averaging loss and gradients across microbatches.

 .. function:: StepOutput.reduce_sum( )
+   :noindex:

    Returns a ``tf.Tensor`` /
    ``torch.Tensor`` that sums the constituent
    ``tf.Tensor``\ s/\ ``torch.Tensor``\ s.

 .. function:: StepOutput.concat( )
+   :noindex:

    Returns a
    ``tf.Tensor``/``torch.Tensor`` that concatenates tensors along the
    batch dimension using ``tf.concat`` / ``torch.cat``.

 .. function:: StepOutput.stack( )
+   :noindex:

    Applies ``tf.stack`` / ``torch.stack``
    operation to the list of constituent ``tf.Tensor``\ s /
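As an illustrative sketch (PyTorch flavor assumed; ``train_step``, ``model``, ``inputs``, and ``targets`` are the placeholders from the earlier sketch, not part of this diff), these reductions are applied to the object returned by the decorated step function:

    # StepOutput reductions on the object returned by an @smp.step function.
    loss_mb = train_step(model, inputs, targets)   # per-microbatch outputs

    loss = loss_mb.reduce_mean()   # average over microbatches, e.g. for logging
    total = loss_mb.reduce_sum()   # sum over microbatches
    per_mb = loss_mb.outputs       # raw list of tensors, indexed by microbatch
    stacked = loss_mb.stack()      # torch.stack over the microbatch tensors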
@@ -217,13 +225,15 @@ The following SageMaker distribute model parallel APIs are common across all fra
    **TensorFlow-only methods**

 .. function:: StepOutput.merge( )
+   :noindex:

    Returns a ``tf.Tensor`` that
    concatenates the constituent ``tf.Tensor``\ s along the batch
    dimension. This is commonly used for merging the model predictions
    across microbatches.

 .. function:: StepOutput.accumulate(method="variable", var=None)
+   :noindex:

    Functionally the same as ``StepOutput.reduce_mean()``. However, it is
    more memory-efficient, especially for large numbers of microbatches,
@@ -249,6 +259,7 @@ The following SageMaker distribute model parallel APIs are common across all fra
    ignored.

 .. _mpi_basics:
+   :noindex:

 MPI Basics
 ^^^^^^^^^^
@@ -271,7 +282,8 @@ The library exposes the following basic MPI primitives to its Python API:
 - ``smp.get_dp_group()``: The list of ranks that hold different
   replicas of the same model partition.

-.. _communication_api:
+.. _communication_api:
+   :noindex:

 Communication API
 ^^^^^^^^^^^^^^^^^
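A small hedged sketch of the primitive shown in the hunk above (assumes ``smp.init()`` has already run; the print statement is a placeholder):

    # Which ranks hold replicas of this process's model partition?
    dp_ranks = smp.get_dp_group()
    print(f"data-parallel peers of this rank: {dp_ranks}")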
@@ -285,6 +297,7 @@ should involve.
    **Helper structures**

 .. data:: smp.CommGroup
+   :noindex:

    An ``enum`` that takes the values
    ``CommGroup.WORLD``, ``CommGroup.MP_GROUP``, and ``CommGroup.DP_GROUP``.
@@ -303,6 +316,7 @@ should involve.
    themselves.

 .. data:: smp.RankType
+   :noindex:

    An ``enum`` that takes the values
    ``RankType.WORLD_RANK``, ``RankType.MP_RANK``, and ``RankType.DP_RANK``.
@@ -318,6 +332,7 @@ should involve.
    **Communication primitives:**

 .. function:: smp.broadcast(obj, group)
+   :noindex:

    Sends the object to all processes in the
    group. The receiving process must call ``smp.recv_from`` to receive the
@@ -350,6 +365,7 @@ should involve.
     smp.recv_from(0, rank_type=smp.RankType.WORLD_RANK)

 .. function:: smp.send(obj, dest_rank, rank_type)
+   :noindex:

    Sends the object ``obj`` to
    ``dest_rank``, which is of a type specified by ``rank_type``.
@@ -373,6 +389,7 @@ should involve.
    ``recv_from`` call.

 .. function:: smp.recv_from(src_rank, rank_type)
+   :noindex:

    Receive an object from a peer process. Can be used with a matching
    ``smp.send`` or a ``smp.broadcast`` call.
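A hedged sketch of a matched send/receive pair (assumes ``smp.init()`` has run and that a rank query such as ``smp.rank()`` from the MPI basics section is available; the payload dictionary is a placeholder):

    # Point-to-point exchange between world ranks 0 and 1.
    if smp.rank() == 0:
        smp.send({"step": 100, "lr": 0.001}, 1, smp.RankType.WORLD_RANK)
    elif smp.rank() == 1:
        state = smp.recv_from(0, rank_type=smp.RankType.WORLD_RANK)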
@@ -398,6 +415,7 @@ should involve.
    ``broadcast`` call, and the object is received.

 .. function:: smp.allgather(obj, group)
+   :noindex:

    A collective call that gathers all the
    submitted objects across all ranks in the specified ``group``. Returns a
@@ -431,6 +449,7 @@ should involve.
     out = smp.allgather(obj2, smp.CommGroup.MP_GROUP)  # returns [obj1, obj2]

 .. function:: smp.barrier(group=smp.WORLD)
+   :noindex:

    A statement that hangs until all
    processes in the specified group reach the barrier statement, similar to
@@ -452,12 +471,14 @@ should involve.
    processes outside that ``mp_group``.

 .. function:: smp.dp_barrier()
+   :noindex:

    Same as passing ``smp.DP_GROUP``\ to ``smp.barrier()``.
    Waits for the processes in the same \ ``dp_group`` as
    the current process to reach the same point in execution.

 .. function:: smp.mp_barrier()
+   :noindex:

    Same as passing ``smp.MP_GROUP`` to
    ``smp.barrier()``. Waits for the processes in the same ``mp_group`` as
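A short hedged sketch of the barrier variants documented above (assumes ``smp.init()`` has run):

    # Synchronization at different scopes.
    smp.barrier()                    # all processes (group defaults to smp.WORLD)
    smp.barrier(group=smp.DP_GROUP)  # only this process's dp_group
    smp.dp_barrier()                 # shorthand for the line above
    smp.mp_barrier()                 # same idea for the mp_group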

doc/api/training/smp_versions/v1.1.0/smd_model_parallel_pytorch.rst

Lines changed: 32 additions & 4 deletions
@@ -23,6 +23,7 @@ This API document assumes you use the following import statements in your traini
    to learn how to use the following API in your PyTorch training script.

 .. class:: smp.DistributedModel
+   :noindex:

    A sub-class of ``torch.nn.Module`` which specifies the model to be
    partitioned. Accepts a ``torch.nn.Module`` object ``module`` which is
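For orientation, a minimal sketch of wrapping a model (the ``Net`` module and its layers are placeholders; only ``smp.init`` and ``smp.DistributedModel`` come from the documented API):

    import torch
    import torch.nn as nn
    import smdistributed.modelparallel.torch as smp

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(128, 256)
            self.fc2 = nn.Linear(256, 10)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))

    smp.init()
    model = smp.DistributedModel(Net())  # partitioned during the first smp.step call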
@@ -157,6 +158,7 @@ This API document assumes you use the following import statements in your traini
    **Methods**

 .. function:: backward(tensors, grad_tensors)
+   :noindex:

    Triggers a distributed backward
    pass across model partitions. Example usage provided in the previous
@@ -165,42 +167,49 @@ This API document assumes you use the following import statements in your traini
    ``retain_grad`` and ``create_graph``  flags are not supported.

 .. function:: local_buffers( )
+   :noindex:

    Returns an iterator over buffers for the modules in
    the partitioned model that have been assigned to the current process.

 .. function:: local_named_buffers( )
+   :noindex:

    Returns an iterator over buffers for the
    modules in the partitioned model that have been assigned to the current
    process. This yields both the name of the buffer as well as the buffer
    itself.

 .. function:: local_parameters( )
+   :noindex:

    Returns an iterator over parameters for the
    modules in the partitioned model that have been assigned to the current
    process.

 .. function:: local_named_parameters( )
+   :noindex:

    Returns an iterator over parameters for
    the modules in the partitioned model that have been assigned to the
    current process. This yields both the name of the parameter as well as
    the parameter itself.

 .. function:: local_modules( )
+   :noindex:

    Returns an iterator over the modules in the
    partitioned model that have been assigned to the current process.

 .. function:: local_named_modules( )
+   :noindex:

    Returns an iterator over the modules in the
    partitioned model that have been assigned to the current process. This
    yields both the name of the module as well as the module itself.

 .. function:: local_state_dict( )
+   :noindex:

    Returns the ``state_dict`` that contains local
    parameters that belong to the current \ ``mp_rank``. This ``state_dict``
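A hedged sketch of the ``local_*`` iterators (``model`` is assumed to be the ``smp.DistributedModel`` from the earlier sketch):

    # Count only the parameters assigned to this process's partition.
    local_params = 0
    for name, param in model.local_named_parameters():
        local_params += param.numel()
    print(f"parameters on this mp_rank: {local_params}")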
@@ -210,34 +219,39 @@ This API document assumes you use the following import statements in your traini
    partition, or to the entire model.

 .. function:: state_dict( )
+   :noindex:

    Returns the ``state_dict`` that contains parameters
    for the entire model. It first collects the \ ``local_state_dict``  and
    gathers and merges the \ ``local_state_dict`` from all ``mp_rank``\ s to
    create a full ``state_dict``.

 .. function:: load_state_dict( )
+   :noindex:

    Same as the ``torch.module.load_state_dict()`` ,
    except: It first gathers and merges the ``state_dict``\ s across
    ``mp_rank``\ s, if they are partial. The actual loading happens after the
    model partition so that each rank knows its local parameters.

 .. function:: register_post_partition_hook(hook)
+   :noindex:

    Registers a callable ``hook`` to
    be executed after the model is partitioned. This is useful in situations
    where an operation needs to be executed after the model partition during
-   the first call to ``smp.step``, but before the actual execution of the
+   the first call to ``smp.step`` but before the actual execution of the
    first forward pass. Returns a ``RemovableHandle`` object ``handle``,
    which can be used to remove the hook by calling ``handle.remove()``.

-.. function:: cpu( )
+.. function:: cpu( )
+   :noindex:

    Allgathers parameters and buffers across all ``mp_rank``\ s and moves them
    to the CPU.

 .. class:: smp.DistributedOptimizer
+   :noindex:

    **Parameters**
    - ``optimizer``
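A hedged sketch of wrapping the optimizer and registering a post-partition hook (``torch.optim.SGD`` and the hook body are placeholders, and the exact arguments the library passes to the hook are not shown in this diff, hence the ``*args, **kwargs`` signature):

    import torch

    optimizer = smp.DistributedOptimizer(
        torch.optim.SGD(model.parameters(), lr=0.1)
    )

    def on_partition(*args, **kwargs):
        # Placeholder body; runs after partitioning, before the first forward pass.
        print("model has been partitioned")

    handle = model.register_post_partition_hook(on_partition)
    # handle.remove()  # deregister the hook if it is no longer needed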
@@ -246,13 +260,15 @@ This API document assumes you use the following import statements in your traini
    returns ``optimizer`` with the following methods overridden:

 .. function:: state_dict( )
+   :noindex:

    Returns the ``state_dict`` that contains optimizer state for the entire model.
    It first collects the ``local_state_dict`` and gathers and merges
    the ``local_state_dict`` from all ``mp_rank``s to create a full
    ``state_dict``.

 .. function:: load_state_dict( )
+   :noindex:

    Same as the ``torch.optimizer.load_state_dict()`` , except:
@@ -262,6 +278,7 @@ This API document assumes you use the following import statements in your traini
    rank knows its local parameters.

 .. function:: local_state_dict( )
+   :noindex:

    Returns the ``state_dict`` that contains the
    local optimizer state that belongs to the current \ ``mp_rank``. This
@@ -308,70 +325,79 @@ This API document assumes you use the following import statements in your traini
         self.child3 = Child3()                # child3 on default_partition

 .. function:: smp.get_world_process_group( )
+   :noindex:

    Returns a ``torch.distributed`` ``ProcessGroup`` that consists of all
    processes, which can be used with the ``torch.distributed`` API.
    Requires ``"ddp": True`` in SageMaker Python SDK parameters.

 .. function:: smp.get_mp_process_group( )
+   :noindex:

    Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the
    processes in the ``MP_GROUP`` which contains the current process, which
    can be used with the \ ``torch.distributed`` API. Requires
    ``"ddp": True`` in SageMaker Python SDK parameters.

 .. function:: smp.get_dp_process_group( )
+   :noindex:

    Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the
    processes in the ``DP_GROUP`` which contains the current process, which
    can be used with the \ ``torch.distributed`` API. Requires
    ``"ddp": True`` in SageMaker Python SDK parameters.

 .. function:: smp.is_initialized( )
+   :noindex:

    Returns ``True`` if ``smp.init`` has already been called for the
    process, and ``False`` otherwise.

 .. function:: smp.is_tracing( )
+   :noindex:

    Returns ``True`` if the current process is running the tracing step, and
    ``False`` otherwise.

 .. data:: smp.nn.FusedLayerNorm
+   :noindex:

    `Apex Fused Layer Norm <https://nvidia.github.io/apex/layernorm.html>`__ is currently not
    supported by the library. ``smp.nn.FusedLayerNorm`` replaces ``apex``
    ``FusedLayerNorm`` and provides the same functionality. This requires
    ``apex`` to be installed on the system.

 .. data:: smp.optimizers.FusedNovoGrad
-
+   :noindex:

    `Fused Novo Grad optimizer <https://nvidia.github.io/apex/optimizers.html#apex.optimizers.FusedNovoGrad>`__ is
    currently not supported by the library. ``smp.optimizers.FusedNovoGrad`` replaces ``apex`` ``FusedNovoGrad``
    optimizer and provides the same functionality. This requires ``apex`` to
    be installed on the system.

 .. data:: smp.optimizers.FusedLamb
-
+   :noindex:

    `FusedLamb optimizer <https://nvidia.github.io/apex/optimizers.html#apex.optimizers.FusedLAMB>`__
    currently doesn’t work with the library. ``smp.optimizers.FusedLamb`` replaces
    ``apex`` ``FusedLamb`` optimizer and provides the same functionality.
    This requires ``apex`` to be installed on the system.

 .. data:: smp.amp.GradScaler
+   :noindex:

    `Torch AMP Gradscaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`__
    currently doesn’t work with the library. ``smp.amp.GradScaler`` replaces
    ``torch.amp.GradScaler`` and provides the same functionality.

 .. _pytorch_saving_loading:
+   :noindex:

 APIs for Saving and Loading
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^

 .. function:: smp.save( )
+   :noindex:

    Saves an object. This operation is similar to ``torch.save()``, except
    it has an additional keyword argument, ``partial``, and accepts only
@@ -394,6 +420,7 @@ APIs for Saving and Loading
    override the default protocol.

 .. function:: smp.load( )
+   :noindex:

    Loads an object saved with ``smp.save()`` from a file.
@@ -418,6 +445,7 @@ APIs for Saving and Loading
    Should be used when loading a model trained with the library.

 .. _pytorch_saving_loading_instructions:
+   :noindex:

 General Instruction For Saving and Loading
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
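To make the save/load flow above concrete, here is a hedged sketch of partial (per-``mp_rank``) checkpointing; the file path is a placeholder, and since the ``smp.save``/``smp.load`` signatures are abbreviated in this diff, the keyword usage is an assumption:

    # Save per-partition state, then restore it (partial checkpointing sketch).
    checkpoint = {
        "model": model.local_state_dict(),
        "optimizer": optimizer.local_state_dict(),
    }
    smp.save(checkpoint, "/opt/ml/checkpoints/ckpt.pt", partial=True)

    loaded = smp.load("/opt/ml/checkpoints/ckpt.pt", partial=True)
    model.load_state_dict(loaded["model"])        # gathers/merges partial dicts
    optimizer.load_state_dict(loaded["optimizer"])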
