 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+    parallelize_module,
+)


 @dataclass
@@ -27,6 +36,7 @@ class ModelArgs:
     # If `True`, then each transformer block init uses its layer ID, and if
     # `False`, each uses the total number of transformer blocks
     depth_init: bool = True
+    device: str = "cuda"


 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
@@ -168,14 +178,22 @@ class Attention(nn.Module):
     def __init__(self, model_args: ModelArgs):
         super().__init__()
         self.n_heads = model_args.n_heads
-        self.n_kv_heads = model_args.n_heads if model_args.n_kv_heads is None else model_args.n_kv_heads
+        self.n_kv_heads = (
+            model_args.n_heads
+            if model_args.n_kv_heads is None
+            else model_args.n_kv_heads
+        )
         self.n_rep = self.n_heads // self.n_kv_heads
         self.head_dim = model_args.dim // model_args.n_heads

-        self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False)
+        self.wq = nn.Linear(
+            model_args.dim, model_args.n_heads * self.head_dim, bias=False
+        )
         self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)
+        self.wo = nn.Linear(
+            model_args.n_heads * self.head_dim, model_args.dim, bias=False
+        )

     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -216,7 +234,9 @@ def forward(

         # we use a causal mask for training
         output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
-        output = output.transpose(1, 2).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
+        output = output.transpose(
+            1, 2
+        ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
         output = output.view(bs, seqlen, -1)
         return self.wo(output)

@@ -330,7 +350,7 @@ def init_weights(self):
         self.feed_forward.init_weights(self.weight_init_std)


-class Transformer(nn.Module):
+class ParallelTransformer(nn.Module):
     """Transformer Module.

     Args:
@@ -348,13 +368,15 @@ class Transformer(nn.Module):

     """

-    def __init__(self, model_args: ModelArgs):
+    def __init__(self, model_args: ModelArgs, tp_mesh: DeviceMesh):
         super().__init__()
         self.model_args = model_args
         self.vocab_size = model_args.vocab_size
         self.n_layers = model_args.n_layers

         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
+        self.tok_embeddings.to(model_args.device)
+        self.tok_embeddings = self.parallel_embeddings(self.tok_embeddings, tp_mesh)

         # TODO persistent should be set to false, since this buffer can be recomputed.
         # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411,
@@ -363,17 +385,83 @@ def __init__(self, model_args: ModelArgs):
         # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
         # initialized by the checkpoint, or we need to add a separate initializer for
         # just the non-persistent buffers that is called after loading checkpoints.
-        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
+        self.register_buffer(
+            "freqs_cis",
+            self._precompute_freqs_cis().to(model_args.device),
+            persistent=True,
+        )

-        self.layers = torch.nn.ModuleDict()
+        self.layers = torch.nn.ModuleDict().to(model_args.device)
         for layer_id in range(model_args.n_layers):
-            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
+            block = TransformerBlock(layer_id, model_args).to(model_args.device)
+            self.layers[str(layer_id)] = block
+            self.parallel_transformer_block(self.layers[str(layer_id)], tp_mesh)
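+            # each block is parallelized in place: its parameters become DTensors on tp_mesh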

-        self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
-
-        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
+        self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps).to(
+            model_args.device
+        )
+        self.norm = self.parallel_norm(self.norm, tp_mesh)
+        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False).to(
+            model_args.device
+        )
+        self.output = self.parallel_output(self.output, tp_mesh)
         self.init_weights()

+    def parallel_transformer_block(self, transformer_block, tp_mesh):
+        if tp_mesh.size() <= 1:
+            return
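+        # Layout convention for the plan below: activations move between
+        # sublayers sharded on the sequence dim (Shard(1)). PrepareModuleInput
+        # gathers them to Replicate before attention/feed_forward, the
+        # Colwise -> Rowwise pair shards the inner computation over heads /
+        # hidden dim, and the Rowwise output re-shards back to Shard(1) so the
+        # SequenceParallel norms work on sequence-sharded activations.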
+        plan = {
+            "attention": PrepareModuleInput(
+                input_layouts=(Shard(1), None),
+                desired_input_layouts=(Replicate(), None),
+            ),
+            "attention.wq": ColwiseParallel(),
+            "attention.wk": ColwiseParallel(),
+            "attention.wv": ColwiseParallel(),
+            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
+            "attention_norm": SequenceParallel(),
+            "feed_forward": PrepareModuleInput(
+                input_layouts=(Shard(1),),
+                desired_input_layouts=(Replicate(),),
+            ),
+            "feed_forward.w1": ColwiseParallel(),
+            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+            "feed_forward.w3": ColwiseParallel(),
+            "ffn_norm": SequenceParallel(),
+        }
+
+        # Adjust attention module to use the local number of heads
+        attn_layer = transformer_block.attention
+        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
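+        # (assumes n_heads and n_kv_heads are divisible by the TP degree)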
+
+        # Apply the plan for the current transformer block
+        parallelize_module(transformer_block, tp_mesh, plan)
+
+    def parallel_embeddings(self, embedding, tp_mesh):
+        # The embedding module itself is passed in, so apply the style directly
+        # (it has no child submodule to key a dict plan by). Rowwise sharding
+        # splits the vocab dimension; the output stays sequence-sharded for the
+        # first transformer block.
+        return parallelize_module(
+            embedding,
+            tp_mesh,
+            RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+        )
+
+    def parallel_output(self, output, tp_mesh):
+        # As above, apply the style directly to the passed-in linear layer;
+        # the final projection is column-parallel over the vocab dimension.
+        return parallelize_module(
+            output,
+            tp_mesh,
+            ColwiseParallel(input_layouts=Shard(1)),
+        )
+
+    def parallel_norm(self, norm, tp_mesh):
+        # The final norm runs sequence-parallel over activations sharded on dim 1.
+        return parallelize_module(norm, tp_mesh, SequenceParallel())
+
     def reset_parameters(self):
         with torch.device(self.freqs_cis.device):
             self.freqs_cis = self._precompute_freqs_cis()
@@ -447,4 +535,4 @@ def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
             Transformer: Transformer model.

         """
-        return cls(model_args)
+        return cls(model_args)
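A minimal driver sketch for the class above; it is not part of the diff, only an illustration of how the DeviceMesh is expected to be built and passed in. It assumes the file is importable as `model.py` (hypothetical name), a `torchrun --nproc-per-node=<N>` launch, and placeholder `ModelArgs` values:

    # tp_driver.py (hypothetical, not part of this change)
    import os

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    from model import ModelArgs, ParallelTransformer

    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # 1-D mesh: every rank participates in tensor parallelism.
    tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("tp",))

    args = ModelArgs(dim=256, n_layers=2, n_heads=8, vocab_size=1024)
    model = ParallelTransformer(args, tp_mesh)

    tokens = torch.randint(0, args.vocab_size, (2, 128), device="cuda")
    logits = model(tokens)  # each rank holds its shard of the logits, split on the vocab dim
    print(logits.shape)

    dist.destroy_process_group()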