Merged
Changes from 6 commits

26 commits
e60fb87
Simplify Tensor Parallel implementation with PyTorch TP
kwen2501 Oct 15, 2024
fd7f7c7
Move tp_plan to config
kwen2501 Oct 23, 2024
9224cab
Merge remote-tracking branch 'origin/main' into tp_llama
kwen2501 Oct 30, 2024
79cc524
Lint
kwen2501 Oct 30, 2024
a2934b3
Format and warning
kwen2501 Oct 30, 2024
a8fc418
Disable copy-from check
kwen2501 Oct 30, 2024
e84a388
Conditionally get attr from config
kwen2501 Oct 31, 2024
396d158
make fix-copies
kwen2501 Oct 31, 2024
7b346b5
Move base_model_tp_plan to PretrainedConfig
kwen2501 Oct 31, 2024
d60679b
Merge remote-tracking branch 'origin/main' into tp_llama
kwen2501 Oct 31, 2024
dda058a
Merge remote-tracking branch 'origin/main' into tp_llama
kwen2501 Nov 1, 2024
12fbbe7
Move TP into from_pretrained
kwen2501 Nov 7, 2024
02c8c39
Add device context for load
kwen2501 Nov 7, 2024
073c521
Do not serialize
kwen2501 Nov 7, 2024
db6e5ee
Move _tp_plan setting to post_init
kwen2501 Nov 7, 2024
5bb294e
Add has_tp_plan
kwen2501 Nov 14, 2024
290a7f1
Add test_tp
kwen2501 Nov 15, 2024
bd2e89c
Add 'Multi-gpu inference' doc
kwen2501 Nov 15, 2024
4892cef
Merge remote-tracking branch 'origin/main' into tp_llama
kwen2501 Nov 15, 2024
9648f31
Add backward support for device type identification
kwen2501 Nov 15, 2024
93ba283
Auto-detect accelerator
kwen2501 Nov 16, 2024
73524c9
supports_tp_plan
kwen2501 Nov 16, 2024
f312e55
copyright year
kwen2501 Nov 16, 2024
ca93bdb
Merge remote-tracking branch 'origin/main' into tp_llama
kwen2501 Nov 17, 2024
dc2672f
Merge branch 'main' into tp_llama
kwen2501 Nov 18, 2024
1e27d6f
Fix copy
kwen2501 Nov 18, 2024
36 changes: 36 additions & 0 deletions src/transformers/modeling_utils.py
@@ -55,6 +55,7 @@
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
@@ -5013,6 +5014,41 @@ def _is_quantized_training_enabled(self):

return self.hf_quantizer.is_trainable

def tensor_parallel(self, device_mesh):
"""
Tensor parallelize the model across the given device mesh.

Args:
device_mesh (`torch.distributed.DeviceMesh`):
The device mesh to use for tensor parallelism.
"""

# Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
# No op if `_tp_plan` attribute does not exist under the module.
# This is a helper function to be used with `model.apply` to recursively
# parallelize a model.
def tplize(mod: torch.nn.Module) -> None:
tp_plan = getattr(mod, "_tp_plan", None)
if tp_plan:
logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
# In model configs, we use a neutral type (string) to specify
# parallel styles, here we translate them into torch TP types.
# Using tree_map because `tp_plan` is a dict.
tp_plan = torch.utils._pytree.tree_map(
translate_to_torch_parallel_style,
tp_plan,
)
# Apply TP to current module.
torch.distributed.tensor.parallel.parallelize_module(
mod,
device_mesh=device_mesh,
parallelize_plan=tp_plan,
)

# `apply` is a native method of `nn.Module` that recursively applies a
# function to every submodule.
self.apply(tplize)

@property
@lru_cache
def loss_function(self):
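For reference, a minimal end-to-end sketch of the new `tensor_parallel` entry point added above. The checkpoint name, mesh size, and launch command are illustrative assumptions, not part of this diff:

```python
# Launch with e.g.:  torchrun --nproc-per-node=2 tp_inference.py
import os

import torch
from torch.distributed.device_mesh import init_device_mesh

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # illustrative; any checkpoint whose config carries a TP plan
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(local_rank)

# 1-D mesh over all ranks: every GPU participates in tensor parallelism.
device_mesh = init_device_mesh("cuda", (world_size,))

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model.tensor_parallel(device_mesh)  # shards weights according to each submodule's _tp_plan

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Tensor parallelism shards", return_tensors="pt").to(f"cuda:{local_rank}")
with torch.no_grad():
    logits = model(**inputs).logits
print(f"[rank {local_rank}] logits: {tuple(logits.shape)}")
```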
2 changes: 1 addition & 1 deletion src/transformers/models/cohere/modeling_cohere.py
@@ -1068,7 +1068,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask


# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

10 changes: 10 additions & 0 deletions src/transformers/models/llama/configuration_llama.py
@@ -141,6 +141,16 @@ class LlamaConfig(PretrainedConfig):

model_type = "llama"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `LlamaModel`
_base_model_tp_plan = {
Review thread:

Contributor: Do we want to allow this for external use by removing the leading `_`, so that users can define TP plan tweaks from `config.json`? Given that, shall we also allow passing a custom TP plan to `LlamaConfig()` that overrides the default?

Contributor Author: Yeah, good idea. We can make this public once we prove things work.

Collaborator: Yep, `base_model_tp_plan` should be supported as input to the `PretrainedConfig`!

Contributor: So this variable `base_model_tp_plan` has to be added to `PretrainedConfig` (`class PretrainedConfig(PushToHubMixin):`) with a default value of an empty dict `{}`, which I believe is the best possible default for any config subclass inheriting from it.

Contributor Author: Thanks @kmehant @ArthurZucker for the suggestion. I moved `base_model_tp_plan` to `PretrainedConfig` in the latest commit.
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

def __init__(
self,
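To make the plan format concrete, here is a minimal sketch of overriding the default plan. The keys are wildcard submodule-name patterns and the values are the neutral style strings; setting the private `_base_model_tp_plan` attribute directly is illustrative of the state at this point in the PR, before later commits move the plan to `PretrainedConfig`:

```python
# Illustrative only: swap in a custom Llama TP plan.
# Keys are wildcard module-name patterns, values are the neutral style
# strings understood by translate_to_torch_parallel_style.
from transformers import LlamaConfig

custom_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    # MLP entries omitted: submodules not matched by the plan are left unsharded.
}

config = LlamaConfig()
config._base_model_tp_plan = custom_plan  # private attribute as of this commit range
```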
81 changes: 20 additions & 61 deletions src/transformers/models/llama/modeling_llama.py
@@ -21,7 +21,6 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

@@ -240,25 +239,7 @@ def __init__(self, config):
self.act_fn = ACT2FN[config.hidden_act]

def forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)

gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj


@@ -320,31 +301,14 @@ def forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)

key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)

value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)

else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
@@ -386,12 +350,7 @@ def forward(

attn_output = attn_output.reshape(bsz, q_len, -1)

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None
@@ -564,9 +523,10 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

if position_embeddings is None:
logger.warning_once(
@@ -850,8 +810,11 @@ def __init__(self, config: LlamaConfig):
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False

self.gradient_checkpointing = False
self._tp_plan = config._base_model_tp_plan
if config.pretraining_tp != 1:
logger.warn("`pretraining_tp` is deprecated, please use `tensor_parallel` method instead.")
# Initialize weights and apply final processing
self.post_init()

@@ -1113,6 +1076,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...

class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}

def __init__(self, config):
super().__init__(config)
@@ -1211,13 +1175,8 @@ def forward(
)

hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

loss = None
if labels is not None:
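One reason the attention reshapes above switch to `-1`: once `q_proj`/`k_proj`/`v_proj` are sharded column-wise, each rank holds only a slice of the heads, so the config's `num_heads` no longer matches the local projection width. A small self-contained shape check, with illustrative sizes:

```python
# Why .view(bsz, q_len, -1, head_dim) still works when q_proj is sharded
# column-wise across tp_size ranks (sizes are hypothetical).
import torch

bsz, q_len = 2, 16
hidden_size, num_heads, tp_size = 4096, 32, 2
head_dim = hidden_size // num_heads        # 128

full_width = num_heads * head_dim          # 4096 on a single device
local_width = full_width // tp_size        # 2048 per rank under colwise TP

local_q = torch.randn(bsz, q_len, local_width)
q = local_q.view(bsz, q_len, -1, head_dim)  # -1 resolves to the 16 local heads
assert q.shape == (bsz, q_len, num_heads // tp_size, head_dim)
```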
2 changes: 1 addition & 1 deletion src/transformers/models/nemotron/modeling_nemotron.py
@@ -980,7 +980,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask


# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

2 changes: 1 addition & 1 deletion src/transformers/models/olmo/modeling_olmo.py
@@ -1020,7 +1020,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask


# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo
class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

2 changes: 1 addition & 1 deletion src/transformers/models/olmoe/modeling_olmoe.py
@@ -888,7 +888,7 @@ def _init_weights(self, module):
"The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
OLMOE_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe
class OlmoeModel(OlmoePreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoeDecoderLayer`]
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -775,7 +775,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
return causal_mask


# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma
class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

24 changes: 24 additions & 0 deletions src/transformers/pytorch_utils.py
@@ -20,6 +20,11 @@
from packaging import version
from safetensors.torch import storage_ptr, storage_size
from torch import nn
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
)

from .utils import is_torch_xla_available, logging

@@ -326,3 +331,22 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
else:
# Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
return torch.isin(elements, test_elements)


def translate_to_torch_parallel_style(style: str):
"""
In model configurations, we use a neutral type (string) to specify parallel
styles, here we translate them into torch.distributed tensor-parallel
types.
"""
if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

if style == "colwise":
return ColwiseParallel()
elif style == "rowwise":
return RowwiseParallel()
elif style == "colwise_rep":
return ColwiseParallel(output_layouts=Replicate())
else:
raise ValueError(f"Unsupported parallel style value: {style}")
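To see how these pieces fit together outside `from_pretrained`, a small sketch that mirrors the `tplize` helper added in `modeling_utils.py` above. The toy module and two-rank mesh are illustrative; run under `torchrun --nproc-per-node=2`:

```python
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from transformers.pytorch_utils import translate_to_torch_parallel_style


class ToyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.up = torch.nn.Linear(64, 256, bias=False)
        self.down = torch.nn.Linear(256, 64, bias=False)

    def forward(self, x):
        return self.down(torch.nn.functional.relu(self.up(x)))


torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Neutral strings, as they would appear in a config's TP plan, translated
# into torch TP styles and fed to parallelize_module.
plan = {"up": "colwise", "down": "rowwise"}
torch_plan = torch.utils._pytree.tree_map(translate_to_torch_parallel_style, plan)

mesh = init_device_mesh("cuda", (2,))  # two GPUs, launched via torchrun
mlp = parallelize_module(ToyMLP().cuda(), device_mesh=mesh, parallelize_plan=torch_plan)

x = torch.randn(8, 64, device="cuda")
print(mlp(x).shape)  # torch.Size([8, 64]); rowwise output is replicated across ranks
```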