-
Notifications
You must be signed in to change notification settings - Fork 31.7k
Simplify Tensor Parallel implementation with PyTorch TP #34184
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 11 commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
e60fb87
Simplify Tensor Parallel implementation with PyTorch TP
kwen2501 fd7f7c7
Move tp_plan to config
kwen2501 9224cab
Merge remote-tracking branch 'origin/main' into tp_llama
kwen2501 79cc524
Lint
kwen2501 a2934b3
Format and warning
kwen2501 a8fc418
Disable copy-from check
kwen2501 e84a388
Conditionally get attr from config
kwen2501 396d158
make fix-copies
kwen2501 7b346b5
Move base_model_tp_plan to PretrainedConfig
kwen2501 d60679b
Merge remote-tracking branch 'origin/main' into tp_llama
kwen2501 dda058a
Merge remote-tracking branch 'origin/main' into tp_llama
kwen2501 12fbbe7
Move TP into from_pretrained
kwen2501 02c8c39
Add device context for load
kwen2501 073c521
Do not serialize
kwen2501 db6e5ee
Move _tp_plan setting to post_init
kwen2501 5bb294e
Add has_tp_plan
kwen2501 290a7f1
Add test_tp
kwen2501 bd2e89c
Add 'Multi-gpu inference' doc
kwen2501 4892cef
Merge remote-tracking branch 'origin/main' into tp_llama
kwen2501 9648f31
Add backward support for device type identification
kwen2501 93ba283
Auto-detect accelerator
kwen2501 73524c9
supports_tp_plan
kwen2501 f312e55
copyright year
kwen2501 ca93bdb
Merge remote-tracking branch 'origin/main' into tp_llama
kwen2501 dc2672f
Merge branch 'main' into tp_llama
kwen2501 1e27d6f
Fix copy
kwen2501 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -720,7 +720,11 @@ def __init__(self, config: GemmaConfig): | |
| [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] | ||
| ) | ||
| self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||
|
|
||
| self.gradient_checkpointing = False | ||
| self._tp_plan = config.base_model_tp_plan | ||
|
||
| if getattr(config, "pretraining_tp", 1) != 1: | ||
| logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") | ||
|
|
||
| # Initialize weights and apply final processing | ||
| self.post_init() | ||
|
|
@@ -982,6 +986,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( | |
|
|
||
| class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): | ||
| _tied_weights_keys = ["lm_head.weight"] | ||
| _tp_plan = {"lm_head": "colwise_rep"} | ||
|
|
||
| def __init__(self, config): | ||
| super().__init__(config) | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.