
Conversation


@sfc-gh-sbekman sfc-gh-sbekman commented Feb 11, 2025

This PR implements/ports Sequence Parallelism via DeepSpeed Ulysses.

For PR reviewers:

Readiness status:

  • Code
  • Docs
  • Tests

Related:

Dependencies:

Comment on lines 327 to 328:

```python
# XXX: this was incorrect for GAS
return self.config.epochs * len(self.train_dataloader)  # // self.config.gradient_accumulation_steps
```
@sfc-gh-sbekman sfc-gh-sbekman May 30, 2025

@sfc-gh-mwyatt, please confirm that my correction is kosher and I will remove the comments - this PR makes partial progress on reporting and accounting with GAS>1 (enough to make the loss, counters, and wandb reporting correct, but it'll need more work to complete / make it smooth).

Collaborator Author

I wonder if it's because of SP? But then the original doesn't take DP into account either, so perhaps it needs more work?

@sfc-gh-sbekman sfc-gh-sbekman May 30, 2025

With SP>1, `len(self.train_dataloader)` is `sp_world_size * len(original_train_dataloader)`.

The training loop needs to make all the iterations - it can't make GAS times fewer iterations - it's only the accounting that should skip reporting any iterations that weren't at a GAS boundary.
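
To make the GAS accounting concrete - a minimal sketch of the pattern being described, not ArcticTraining's actual code; `trainer.report()` is a hypothetical reporting hook, and the other attribute names just mirror the trainer-style calls discussed in this PR:

```python
# Illustrative sketch only: run every micro-batch, but step the optimizer and
# report loss/counters only at gradient-accumulation (GAS) boundaries.

def train_epoch(trainer, gas: int):
    running_loss = 0.0
    for micro_step, batch in enumerate(trainer.train_dataloader, start=1):
        loss = trainer.loss(batch)      # forward
        trainer.model.backward(loss)    # backward accumulates gradients
        running_loss += loss.item()
        if micro_step % gas == 0:
            trainer.model.step()        # optimizer step at the GAS boundary
            trainer.report(             # hypothetical reporting hook (wandb, counters)
                step=micro_step // gas,
                loss=running_loss / gas,
            )
            running_loss = 0.0

# The number of *reported* steps per epoch is len(train_dataloader) // gas,
# while the loop itself still iterates len(train_dataloader) times - which is
# the distinction the corrected accounting has to preserve.
```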

stas00 added a commit to deepspeedai/DeepSpeed that referenced this pull request May 31, 2025
This is the Deepspeed counterpart of
snowflakedb/ArcticTraining#45 - as the new
feature(s) require changes on both sides.


For PR reviewers: 

Readiness status:
- [x] Code
- [x] Tests
- [ ] Docs - working on it


Features:

- [x] add support for delaying grad addition via
`param.ds_grad_is_ready` flag (used when performing tiled compute in an
autograd function)
- [x] add light sp-only mpu version (Jeff Rasley)
- [x] improved debug
- [x] added `all_gather_object` to `dist`
- [x] `UlyssesSPAttentionHF` (port of UlyssesAttention from
Megatron-Deepspeed plus modern MHA-variations)
- [x] `UlyssesSPDataLoaderAdapter` - DL adapter to shard the normal DL
batches to be used by `UlyssesSPAttentionHF`
- [x] `SequenceTiledCompute` - generic autograd function to perform
compute after tiling on the sequence dimension
- [x] `TiledMLP` - a specific autograd function to perform tiled MLP
(it's much easier to understand before trying to grok
`SequenceTiledCompute`)
- [x] added a differentiable `_DimZeroAllToAll` (Samyam Rajbhandari)
- [x] torch-dist-check now allows `torch.distributed.nn` (which is
needed since deepspeed's dist is not up to date with
`torch.distributed.nn`)

---------

Signed-off-by: Stas Bekman <[email protected]>
Signed-off-by: Stas Bekman <[email protected]>
Co-authored-by: Stas Bekman <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Signed-off-by: Stas Bekman <[email protected]>
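
To give a rough feel for the tiled-compute idea in the feature list above - this is only an illustrative forward-pass sketch, not DeepSpeed's `TiledMLP`/`SequenceTiledCompute`, which are custom autograd functions that also tile the backward pass and use the new `param.ds_grad_is_ready` flag to delay grad addition across tiles:

```python
import torch

def tiled_mlp_forward(mlp, hidden_states: torch.Tensor, num_tiles: int) -> torch.Tensor:
    # hidden_states: [batch, seq_len, hidden]. Split along the sequence dimension,
    # run the MLP tile by tile, and stitch the outputs back together along the
    # same dimension. The real implementations wrap this idea in torch.autograd
    # Functions so the backward pass is tiled as well.
    tiles = torch.chunk(hidden_states, num_tiles, dim=1)
    return torch.cat([mlp(tile) for tile in tiles], dim=1)

# usage sketch:
# mlp = torch.nn.Sequential(torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024))
# out = tiled_mlp_forward(mlp, torch.randn(2, 8192, 1024), num_tiles=8)
```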
@sfc-gh-mwyatt sfc-gh-mwyatt left a comment

Overall this looks good; however, I'm confused about the SFTTrainer loss implementation. In the INTEGRATIONS doc file you show that the step function should call a different loss function when SP is enabled, but this is not reflected in the code.

Comment on lines 60 to 73
in `step`:

```python
if self.config.sequence_parallel_size == 1:
# this is the original code
loss = self.loss(batch)
self.model.backward(loss)
...
else:
# sp will do backward inside sp_fwd_bwd_loss
# the returned loss is already averaged across ranks
loss = self.sp_fwd_bwd_loss(batch)
```
Collaborator

This does not seem to be reflected in the code. Am I missing something?

Collaborator Author

The doc is outdated - it was written for the original implementation. As I mentioned in the Slack post, only the .py code is ready for review.

I was planning to work on the doc, but got pulled into working on plots. Will get to it now.

@sfc-gh-sbekman sfc-gh-sbekman Jun 3, 2025

I have just rewritten them to reflect reality.

Collaborator Author

Once reviewed, I think we should move INTEGRATION.md to the DeepSpeed repo, since that's where the components are. What do you think?

Collaborator Author

Heads up - this doc has moved to deepspeedai/DeepSpeed#7331.


```yaml
logger:
  level: WARNING
  # level: INFO
```
Collaborator

Should we remove these (presumably) debug comments before merging?

Collaborator Author

I removed most of these already; I thought it'd make for a good final version to keep an easy way for the user to switch between the various options they are likely to want. These aren't debug leftovers.

e.g. I also left:

  #attn_implementation: sdpa

and datasets.

But I can remove them if you feel the user is better off not seeing the other options they might want to quickly turn on/off.
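
For reference, the pattern being described - keeping the alternative value commented out right next to the active one - looks roughly like this (an illustrative YAML sketch; the active values and exact nesting are placeholders, not this PR's actual recipe):

```yaml
logger:
  level: WARNING
  # level: INFO

model:
  attn_implementation: flash_attention_2
  # attn_implementation: sdpa
```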

@sfc-gh-sbekman

@sfc-gh-mwyatt, docs are now ready for review.

@sfc-gh-sbekman sfc-gh-sbekman merged commit 9d33fd2 into main Jun 3, 2025
4 checks passed
@sfc-gh-sbekman sfc-gh-sbekman deleted the stas/sp branch June 3, 2025 18:19
deepcharm pushed a commit to deepcharm/DeepSpeed that referenced this pull request Jun 16, 2025
Antlera pushed a commit to Antlera/DeepSpeed that referenced this pull request Jun 27, 2025
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Oct 4, 2025