Commit 22c9c54

Added Llama 3 Tensor Parallelism
1 parent 5967027 commit 22c9c54

File tree

examples/distributed_inference/llama3_model.py
examples/distributed_inference/tensor_parallel_llama3.py

2 files changed: +136 additions, -95 deletions

examples/distributed_inference/llama3_model.py

Lines changed: 102 additions & 13 deletions
@@ -8,6 +8,15 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+    parallelize_module,
+)


 @dataclass
@@ -27,6 +36,7 @@ class ModelArgs:
     # If `True`, then each transformer block init uses its layer ID, and if
     # `False`, each uses the total number of transformer blocks
     depth_init: bool = True
+    device: str = "cuda"


 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
@@ -168,14 +178,22 @@ class Attention(nn.Module):
     def __init__(self, model_args: ModelArgs):
         super().__init__()
         self.n_heads = model_args.n_heads
-        self.n_kv_heads = model_args.n_heads if model_args.n_kv_heads is None else model_args.n_kv_heads
+        self.n_kv_heads = (
+            model_args.n_heads
+            if model_args.n_kv_heads is None
+            else model_args.n_kv_heads
+        )
         self.n_rep = self.n_heads // self.n_kv_heads
         self.head_dim = model_args.dim // model_args.n_heads

-        self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False)
+        self.wq = nn.Linear(
+            model_args.dim, model_args.n_heads * self.head_dim, bias=False
+        )
         self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)
+        self.wo = nn.Linear(
+            model_args.n_heads * self.head_dim, model_args.dim, bias=False
+        )

     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -216,7 +234,9 @@ def forward(

         # we use causal mask for training
         output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
-        output = output.transpose(1, 2).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
+        output = output.transpose(
+            1, 2
+        ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
         output = output.view(bs, seqlen, -1)
         return self.wo(output)

@@ -330,7 +350,7 @@ def init_weights(self):
         self.feed_forward.init_weights(self.weight_init_std)


-class Transformer(nn.Module):
+class ParallelTransformer(nn.Module):
     """Transformer Module.

     Args:
@@ -348,13 +368,16 @@ class Transformer(nn.Module):

     """

-    def __init__(self, model_args: ModelArgs):
+    def __init__(self, model_args: ModelArgs, tp_mesh: DeviceMesh):
+        # Here we use distributed model initialization to avoid memory overflow
         super().__init__()
         self.model_args = model_args
         self.vocab_size = model_args.vocab_size
         self.n_layers = model_args.n_layers

         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
+        self.tok_embeddings.to(model_args.device)
+        self.tok_embeddings = self.parallel_embeddings(self.tok_embeddings, tp_mesh)

         # TODO persistent should be set to false, since this buffer can be recomputed.
         # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411,
@@ -363,17 +386,83 @@ def __init__(self, model_args: ModelArgs):
         # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
         # initialized by the checkpoint, or we need to add a separate initializer for
         # just the non-persistent buffers that is called after loading checkpoints.
-        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
+        self.register_buffer(
+            "freqs_cis",
+            self._precompute_freqs_cis().to(model_args.device),
+            persistent=True,
+        )

-        self.layers = torch.nn.ModuleDict()
+        self.layers = torch.nn.ModuleDict().to(model_args.device)
         for layer_id in range(model_args.n_layers):
-            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
+            block = TransformerBlock(layer_id, model_args).to(model_args.device)
+            self.layers[str(layer_id)] = block
+            self.parallel_transformer_block(self.layers[str(layer_id)], tp_mesh)
+            print(layer_id)

-        self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
-
-        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
+        self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps).to(
+            model_args.device
+        )
+        self.norm = self.parallel_norm(self.norm, tp_mesh)
+        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False).to(
+            model_args.device
+        )
+        self.output = self.parallel_output(self.output, tp_mesh)
         self.init_weights()

+    def parallel_transformer_block(self, transformer_block, tp_mesh):
+        if tp_mesh.size() <= 1:
+            return
+        plan = {
+            "attention": PrepareModuleInput(
+                input_layouts=(Shard(1), None),
+                desired_input_layouts=(Replicate(), None),
+            ),
+            "attention.wq": ColwiseParallel(),
+            "attention.wk": ColwiseParallel(),
+            "attention.wv": ColwiseParallel(),
+            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
+            "attention_norm": SequenceParallel(),
+            "feed_forward": PrepareModuleInput(
+                input_layouts=(Shard(1),),
+                desired_input_layouts=(Replicate(),),
+            ),
+            "feed_forward.w1": ColwiseParallel(),
+            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+            "feed_forward.w3": ColwiseParallel(),
+            "ffn_norm": SequenceParallel(),
+        }
+
+        # Adjust attention module to use the local number of heads
+        attn_layer = transformer_block.attention
+        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
+
+        # Apply the plan for the current transformer block
+        parallelize_module(transformer_block, tp_mesh, plan)
+
+    def parallel_embeddings(self, embedding, tp_mesh):
+        plan = {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            )
+        }
+        return parallelize_module(embedding, tp_mesh, plan)
+
+    def parallel_output(self, output, tp_mesh):
+        plan = {
+            "output": ColwiseParallel(
+                input_layouts=Shard(1),
+            ),
+        }
+        return parallelize_module(output, tp_mesh, plan)
+
+    def parallel_norm(self, norm, tp_mesh):
+        plan = {
+            "norm": SequenceParallel(),
+        }
+        return parallelize_module(norm, tp_mesh, plan)
+
     def reset_parameters(self):
         with torch.device(self.freqs_cis.device):
             self.freqs_cis = self._precompute_freqs_cis()
447536
Transformer: Transformer model.
448537
449538
"""
450-
return cls(model_args)
539+
return cls(model_args)

examples/distributed_inference/tensor_parallel_llama3.py

Lines changed: 34 additions & 82 deletions
@@ -1,82 +1,29 @@
+import os
+import time
+
 import torch
 import torch_tensorrt
-from llama3_model import Transformer, ModelArgs
+from llama3_model import ModelArgs, ParallelTransformer
 from torch.distributed._composable.fsdp import MixedPrecisionPolicy
 from torch.distributed._composable.fsdp.fully_shard import fully_shard
 from torch.distributed._tensor import Replicate, Shard
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
-from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    checkpoint_wrapper,
+)
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     PrepareModuleInput,
     RowwiseParallel,
     SequenceParallel,
     parallelize_module,
 )
-import time
-from torch.distributed.device_mesh import init_device_mesh
-import os

 # Taken and modified from PyTorch Lightning
 # https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
-def parallelize(model: Transformer, tp_mesh: DeviceMesh) -> Transformer:
-    """Apply parallelisms and activation checkpointing to the model.
-
-    NOTE: The passed-in model preferably should be on meta device. Otherwise,
-    the model must fit on GPU or CPU memory.
-
-    """
-
-    if tp_mesh.size() > 1:
-        # 1. Parallelize the first embedding and the last linear proj layer
-        # 2. Parallelize the root norm layer over the sequence dim
-        # 3. Shard the first transformer block's inputs
-
-        # Parallelize the first embedding and the last linear out projection
-        plan = {
-            "tok_embeddings": RowwiseParallel(input_layouts=Replicate(),
-                                              output_layouts=Shard(1),),
-            "output": ColwiseParallel(
-                input_layouts=Shard(1),
-            ),
-            "norm": SequenceParallel(),
-        }
-        model = parallelize_module(model, tp_mesh, plan)
-
-        # Parallelize each transformer block
-        for transformer_block in model.layers.values():
-            plan = {
-                "attention": PrepareModuleInput(
-                    input_layouts=(Shard(1), None),
-                    desired_input_layouts=(Replicate(), None),
-                ),
-                "attention.wq": ColwiseParallel(),
-                "attention.wk": ColwiseParallel(),
-                "attention.wv": ColwiseParallel(),
-                "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-                "attention_norm": SequenceParallel(),
-                "feed_forward": PrepareModuleInput(
-                    input_layouts=(Shard(1),),
-                    desired_input_layouts=(Replicate(),),
-                ),
-                "feed_forward.w1": ColwiseParallel(),
-                "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
-                "feed_forward.w3": ColwiseParallel(),
-                "ffn_norm": SequenceParallel(),
-            }
-
-            # Adjust attention module to use the local number of heads
-            attn_layer = transformer_block.attention
-            attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
-            # Apply the plan for the current transformer block
-            parallelize_module(transformer_block, tp_mesh, plan)
-
-    return model


-tp_size = 4
+tp_size = 8

 # understand world topology
 _rank = int(os.environ["RANK"])
@@ -85,29 +32,34 @@ def parallelize(model: Transformer, tp_mesh: DeviceMesh) -> Transformer:

 tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))

-model_args = ModelArgs(vocab_size=128256, dim=8192, n_layers=80, n_heads=64, rope_theta=500000.0, n_kv_heads=8)
-
-# model_args = ModelArgs(vocab_size=32000, dim=2048, n_layers=8, n_heads=32)
-model = Transformer(model_args).to("cuda")
-model = parallelize(model, tp_mesh)
-model.eval()
-torch.manual_seed(0)
-inp = torch.randint(32000, (8, 256), device="cuda")
-python_result = model(inp)
-torch_tensorrt.runtime.set_multi_device_safe_mode(True)
-model = torch.compile(
-    model,
-    fullgraph=True,
-    backend="torch_tensorrt",
-    options={
-        "truncate_long_and_double": True,
-        "enabled_precisions": {torch.float32, torch.float16},
-        "use_python_runtime": True,
-    },
-    dynamic=False,
+model_args = ModelArgs(
+    vocab_size=32000,
+    dim=8192,
+    n_layers=80,
+    n_heads=64,
+    rope_theta=500000.0,
+    n_kv_heads=8,
+    device="cuda",
 )

 with torch.no_grad():
+    model = ParallelTransformer(model_args, tp_mesh)
+    torch.manual_seed(0)
+    inp = torch.randint(32000, (8, 256), device="cuda")
+    python_result = model(inp)
+    torch_tensorrt.runtime.set_multi_device_safe_mode(True)
+    model = torch.compile(
+        model,
+        fullgraph=True,
+        backend="torch_tensorrt",
+        options={
+            "truncate_long_and_double": True,
+            "enabled_precisions": {torch.float32, torch.float16},
+            "use_python_runtime": True,
+            "workspace_size": 1 << 33,
+        },
+        dynamic=False,
+    )
     for i in range(15):
         # seeding with dp_rank to ensure identical inputs for TP groups
         torch.manual_seed(i)
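Editorial note, not part of the commit: the driver script reads its rank and world size from environment variables and hard-codes tp_size = 8, so it is presumably meant to be started by a distributed launcher on an 8-GPU node. A hypothetical invocation (not shown in this commit) would be:

torchrun --nproc_per_node=8 examples/distributed_inference/tensor_parallel_llama3.py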
