Simplify Tensor Parallel implementation with PyTorch TP #34184
Changes from 2 commits
@@ -141,6 +141,16 @@ class LlamaConfig(PretrainedConfig):
     model_type = "llama"
     keys_to_ignore_at_inference = ["past_key_values"]
+    # Default tensor parallel plan for base model `LlamaModel`
+    _base_model_tp_plan = {
Review comment on `class PretrainedConfig(PushToHubMixin):`
`{}`, which I believe is the best possible default for any config subclass to inherit.
Thanks @kmehant @ArthurZucker for the suggestion. I moved base_model_tp_plan to PretrainedConfig in the latest commit.
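To make the suggested layout concrete, here is a minimal sketch; the class bodies and the example plan entries are illustrative placeholders, not the exact code added in this PR:

```python
from typing import Dict, Optional


class PretrainedConfig:
    # Sketch only: a None/empty default so every config subclass inherits the attribute.
    base_model_tp_plan: Optional[Dict[str, str]] = None


class LlamaConfig(PretrainedConfig):
    model_type = "llama"
    # Illustrative entries only; the real plan maps `LlamaModel` sub-module name
    # patterns (attention and MLP projections) to neutral style strings.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
    }
```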
@@ -20,6 +20,11 @@
 from packaging import version
 from safetensors.torch import storage_ptr, storage_size
 from torch import nn
+from torch.distributed.tensor import Replicate
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+)

 from .utils import is_torch_xla_available, logging

@@ -326,3 +331,24 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
     else:
         # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
         return torch.isin(elements, test_elements)
+
+
+def translate_to_torch_parallel_style(style: str):
+    """
+    In model configurations, we use a neutral type (string) to specify parallel
+    styles, here we translate them into torch.distributed tensor-parallel
+    types.
+    """
+    if not isinstance(style, str):
+        raise ValueError(
+            f"Unsupported parallel style type {type(style)}, expected str"
+        )
+
+    if style == "colwise":
+        return ColwiseParallel()
+    elif style == "rowwise":
+        return RowwiseParallel()
+    elif style == "colwise_rep":
+        return ColwiseParallel(output_layouts=Replicate())
+    else:
+        raise ValueError(f"Unsupported parallel style value: {style}")
Collaborator
A mapping from tp style to the correct function might be better.
Contributor
Author
Thanks for the comment! Indeed, a mapping style would look better.
Contributor
Pitching in :) We should be able to use the same object, since it applies the required parallel operation to the module and returns a new copy - https://github.com/pytorch/pytorch/blob/86d4b7d60b264cae5a04a1b20719bcd7a5752a4c/torch/distributed/tensor/parallel/api.py#L95. Have also tested it empirically while benchmarking (#34194). Thanks!
Collaborator
Sounds good!
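Taking the thread's suggestion literally, a mapping-based variant could look roughly like the sketch below; this illustrates the idea being discussed, not necessarily the code that was merged:

```python
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# Sharing instances across calls should be fine: the parallel style object holds only
# configuration and is not mutated when it is applied to a module (see the thread above).
_TP_STYLE_TO_PARALLEL_STYLE = {
    "colwise": ColwiseParallel(),
    "rowwise": RowwiseParallel(),
    "colwise_rep": ColwiseParallel(output_layouts=Replicate()),
}


def translate_to_torch_parallel_style(style: str):
    """Translate a neutral string style from a model config into a torch TP style."""
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
    try:
        return _TP_STYLE_TO_PARALLEL_STYLE[style]
    except KeyError:
        raise ValueError(f"Unsupported parallel style value: {style}") from None
```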