[model loading] don't gc.collect() if only 1 shard is used
#36721
What does this PR do?
Fixes a test-related regression introduced by #36033. Despite speeding up `from_pretrained` in general, most tests that relied on `from_pretrained` became much slower (see tests section below).

Measuring time, the entire slowdown can be traced to a single line, a `gc.collect()`. This line was not added in #36033, it was moved: before #36033, if the checkpoint was not sharded, a `state_dict` would have been passed to `_load_pretrained_model`, and the `gc.collect()` branch would not be reached.

This PR adds a tiny `if` to skip `gc.collect()` if the checkpoint is not sharded. Since many tests rely on unsharded `from_pretrained` calls, we can immediately observe faster tests.

Tests
`py.test tests/models/gpt2/test_modeling_gpt2.py` times, which include a mix of tests with and without `from_pretrained`, on my machine:
- before the `from_pretrained` changes (#36033): 25.41s
- main: 61.14s

(the `from_pretrained` changes in #36033 result in a 15%+ test speedup, after this fix)
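For illustration, the guard described above can be sketched as follows. This is a minimal standalone sketch, not the actual `transformers` loading code: `load_state_dicts` and `load_fn` are hypothetical names, and the real change is a one-line `if` around the existing `gc.collect()` in the shard-loading loop.

```python
import gc

def load_state_dicts(shard_files, load_fn):
    """Illustrative sketch: load each checkpoint shard, free its state dict,
    and only trigger a full garbage collection when the checkpoint is
    actually sharded. Returns the number of gc.collect() calls made."""
    collections = 0
    for shard_file in shard_files:
        state_dict = load_fn(shard_file)
        # ... tensors would be copied from state_dict into the model here ...
        del state_dict
        # gc.collect() is expensive; with a single shard there is no
        # accumulated garbage worth reclaiming, so skip it (the tiny `if`).
        if len(shard_files) > 1:
            gc.collect()
            collections += 1
    return collections

# Single-shard checkpoint: the costly collection is skipped entirely.
print(load_state_dicts(["model.bin"], lambda f: {}))       # -> 0
# Sharded checkpoint: one collection per shard, as before.
print(load_state_dicts(["a", "b", "c"], lambda f: {}))     # -> 3
```

The design point is that the fast path (unsharded checkpoints, which most tests exercise) pays zero collection cost, while sharded loading keeps its existing memory-release behavior.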