@@ -8,6 +8,15 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+    parallelize_module,
+)


 @dataclass
@@ -27,6 +36,7 @@ class ModelArgs:
     # If `True`, then each transformer block init uses its layer ID, and if
     # `False`, each uses the total number of transformer blocks
     depth_init: bool = True
+    device: str = "cuda"


 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
@@ -168,14 +178,22 @@ class Attention(nn.Module):
     def __init__(self, model_args: ModelArgs):
         super().__init__()
         self.n_heads = model_args.n_heads
-        self.n_kv_heads = model_args.n_heads if model_args.n_kv_heads is None else model_args.n_kv_heads
+        self.n_kv_heads = (
+            model_args.n_heads
+            if model_args.n_kv_heads is None
+            else model_args.n_kv_heads
+        )
         self.n_rep = self.n_heads // self.n_kv_heads
         self.head_dim = model_args.dim // model_args.n_heads

-        self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False)
+        self.wq = nn.Linear(
+            model_args.dim, model_args.n_heads * self.head_dim, bias=False
+        )
         self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)
+        self.wo = nn.Linear(
+            model_args.n_heads * self.head_dim, model_args.dim, bias=False
+        )

     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -216,7 +234,9 @@ def forward(

         # we use casual mask for training
         output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
-        output = output.transpose(1, 2).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
+        output = output.transpose(
+            1, 2
+        ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
         output = output.view(bs, seqlen, -1)
         return self.wo(output)

@@ -330,7 +350,7 @@ def init_weights(self):
         self.feed_forward.init_weights(self.weight_init_std)


-class Transformer(nn.Module):
+class ParallelTransformer(nn.Module):
     """Transformer Module.

     Args:
@@ -348,13 +368,16 @@ class Transformer(nn.Module):

     """

-    def __init__(self, model_args: ModelArgs):
+    def __init__(self, model_args: ModelArgs, tp_mesh: DeviceMesh):
+        # Initialize and parallelize the model layer by layer to avoid
+        # running out of memory on a single device.
         super().__init__()
         self.model_args = model_args
         self.vocab_size = model_args.vocab_size
         self.n_layers = model_args.n_layers

         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
+        self.tok_embeddings.to(model_args.device)
+        self.tok_embeddings = self.parallel_embeddings(self.tok_embeddings, tp_mesh)

         # TODO persistent should be set to false, since this buffer can be recomputed.
         # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411,
@@ -363,17 +386,83 @@ def __init__(self, model_args: ModelArgs):
         # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
         # initialized by the checkpoint, or we need to add a separate initializer for
         # just the non-persistent buffers that is called after loading checkpoints.
-        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
+        self.register_buffer(
+            "freqs_cis",
+            self._precompute_freqs_cis().to(model_args.device),
+            persistent=True,
+        )

-        self.layers = torch.nn.ModuleDict()
+        self.layers = torch.nn.ModuleDict().to(model_args.device)
         for layer_id in range(model_args.n_layers):
-            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
+            block = TransformerBlock(layer_id, model_args).to(model_args.device)
+            self.layers[str(layer_id)] = block
+            self.parallel_transformer_block(self.layers[str(layer_id)], tp_mesh)

-        self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
-
-        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
+        self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps).to(
+            model_args.device
+        )
+        self.norm = self.parallel_norm(self.norm, tp_mesh)
+        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False).to(
+            model_args.device
+        )
+        self.output = self.parallel_output(self.output, tp_mesh)
         self.init_weights()

+    def parallel_transformer_block(self, transformer_block, tp_mesh):
+        if tp_mesh.size() <= 1:
+            return
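+        # Tensor parallel plan for one block: wq/wk/wv and w1/w3 are sharded
+        # column-wise (each rank keeps a slice of the heads / hidden dim),
+        # while wo and w2 are sharded row-wise and emit sequence-sharded
+        # (Shard(1)) activations so the norms can run as sequence parallel.
+        # PrepareModuleInput gathers the sequence-sharded input back to a
+        # replicated layout before attention and the feed-forward block.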
+        plan = {
+            "attention": PrepareModuleInput(
+                input_layouts=(Shard(1), None),
+                desired_input_layouts=(Replicate(), None),
+            ),
+            "attention.wq": ColwiseParallel(),
+            "attention.wk": ColwiseParallel(),
+            "attention.wv": ColwiseParallel(),
+            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
+            "attention_norm": SequenceParallel(),
+            "feed_forward": PrepareModuleInput(
+                input_layouts=(Shard(1),),
+                desired_input_layouts=(Replicate(),),
+            ),
+            "feed_forward.w1": ColwiseParallel(),
+            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+            "feed_forward.w3": ColwiseParallel(),
+            "ffn_norm": SequenceParallel(),
+        }
+
+        # Adjust attention module to use the local number of heads
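+        # (each rank owns n_heads // tp_mesh.size() query heads and the matching
+        # share of key/value heads once wq/wk/wv are column-sharded; head_dim
+        # and the head ratio used for n_rep are unchanged)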
+        attn_layer = transformer_block.attention
+        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
+
+        # Apply the plan for the current transformer block
+        parallelize_module(transformer_block, tp_mesh, plan)
+
+    def parallel_embeddings(self, embedding, tp_mesh):
+        # The nn.Embedding module itself is passed in, so apply the parallel
+        # style to it directly; a plan keyed by "tok_embeddings" would not
+        # match any submodule of the embedding.
+        plan = RowwiseParallel(
+            input_layouts=Replicate(),
+            output_layouts=Shard(1),
+        )
+        return parallelize_module(embedding, tp_mesh, plan)
+
+    def parallel_output(self, output, tp_mesh):
+        # As above, parallelize the output nn.Linear directly rather than via
+        # a plan keyed by "output", which has no matching submodule here.
+        plan = ColwiseParallel(
+            input_layouts=Shard(1),
+        )
+        return parallelize_module(output, tp_mesh, plan)
+
+    def parallel_norm(self, norm, tp_mesh):
+        # Likewise, apply SequenceParallel to the final RMSNorm directly.
+        plan = SequenceParallel()
+        return parallelize_module(norm, tp_mesh, plan)
+
     def reset_parameters(self):
         with torch.device(self.freqs_cis.device):
             self.freqs_cis = self._precompute_freqs_cis()
@@ -447,4 +536,4 @@ def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
             Transformer: Transformer model.

         """
-        return cls(model_args)
+        return cls(model_args)
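
For context, a minimal sketch of how the new `ParallelTransformer` constructor could be driven. This is not part of the commit: the `torchrun` launch, the 1-D mesh over all ranks, and the small model sizes below are illustrative assumptions.

```python
# Illustrative usage sketch (not part of this commit). Assumes a launch like
# `torchrun --nproc_per_node=8 run_tp.py`, with every rank in one TP group,
# and that ModelArgs / ParallelTransformer are importable from the file above.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh

local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
world_size = int(os.environ["WORLD_SIZE"])  # set by torchrun
torch.cuda.set_device(local_rank)

# 1-D device mesh over all ranks: this is the tp_mesh the constructor expects.
tp_mesh = init_device_mesh("cuda", (world_size,))

# Small illustrative sizes; n_heads must be divisible by the TP degree.
model_args = ModelArgs(
    dim=256, n_layers=2, n_heads=16, vocab_size=32000, device="cuda"
)
model = ParallelTransformer(model_args, tp_mesh)

# Each rank now holds only its shard of the parallelized weights; token ids
# are fed replicated, as in a single-device run, and the sequence length must
# be divisible by the TP degree for the Shard(1) activations.
tokens = torch.randint(0, model_args.vocab_size, (2, 128), device="cuda")
logits = model(tokens)  # per-rank shard of the vocab dim under ColwiseParallel
```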