Skip to content

[FIX] Garbage collect temp buffers after checkpoint#94

Merged
tyler-griggs merged 20 commits intomainfrom
tgriggs/ckpt-debug
Jul 17, 2025
Merged

[FIX] Garbage collect temp buffers after checkpoint#94
tyler-griggs merged 20 commits intomainfrom
tgriggs/ckpt-debug

Conversation

@tyler-griggs
Copy link
Member

@tyler-griggs tyler-griggs commented Jul 16, 2025

What does this PR do?

Resolve issue where offloading optimizers failed after checkpointing.

As reported in #70, OOMs can occur during inference engine wake-up after checkpointing. The root-cause was the state_dict materialization in save_ckpt created temporary buffers that were not garbage collected before we try to wakeup the inference engine kv cache, causing an OOM.

This PR executes the garbage collection and resolves the OOM issue.

Tests

Added a GPU test to check for successful offloading after checkpointing. It fails before this PR.

What's next?

We should switch to using pytorch's distributed checkpointing APIs for checkpointing, which is much simpler.

@tyler-griggs tyler-griggs marked this pull request as ready for review July 17, 2025 00:37
"fsdp2",
],
)
def test_offload_after_ckpt(strategy):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: does the test fail before this PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes indeed! Should have mentioned that

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great

Copy link
Member

@SumanthRH SumanthRH left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a minor comment. Thanks!

@tyler-griggs tyler-griggs changed the title Tgriggs/ckpt debug [FIX] Garbage collect temp buffers after checkpoint Jul 17, 2025
@tyler-griggs tyler-griggs merged commit f556802 into main Jul 17, 2025
3 checks passed
@SumanthRH SumanthRH deleted the tgriggs/ckpt-debug branch July 23, 2025 08:25
fannie1208 pushed a commit to vinid/SkyRL that referenced this pull request Aug 19, 2025
## What does this PR do?
Resolve issue where offloading optimizers failed after checkpointing. 

As reported in NovaSky-AI#70, OOMs can occur during inference engine wake-up after
checkpointing. The root-cause was the `state_dict` materialization in
`save_ckpt` created temporary buffers that were not garbage collected
before we try to wakeup the inference engine kv cache, causing an OOM.

This PR executes the garbage collection and resolves the OOM issue.

## Tests
Added a GPU test to check for successful offloading after checkpointing.
It fails before this PR.

## What's next?
We should switch to using pytorch's distributed checkpointing APIs for
checkpointing, which is much simpler.

---------

Co-authored-by: Sumanth R Hegde <39546518+SumanthRH@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants