@@ -238,24 +238,25 @@ def __init__(self, params: ModelArgs):
         self.adapter_layer = params.adapter_layer

     @torch.inference_mode()
-    def forward(self, tokens: torch.Tensor, start_pos: int):
+    def forward(self, tokens: torch.Tensor, start_pos: int, use_adapter: bool):
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         # self.freqs_cis = self.freqs_cis.float().to(h.device)
         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
-        prompt = self.adapter_query.weight.reshape(self.params.adapter_layer, self.params.adapter_len, self.params.dim).unsqueeze(1)

         mask = None
         if seqlen > 1:
             mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=torch.device('cpu'))
             mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

-        for layer in self.layers[: -1 * self.params.adapter_layer]:
+        for layer in (self.layers[: -1 * self.params.adapter_layer]) if use_adapter else self.layers:
             h = layer(h, start_pos, freqs_cis, (mask.to('mps') if mask is not None else None))
-        layer_index = 0
-        for layer in self.layers[-1 * self.params.adapter_layer:]:
-            h = layer(h, start_pos, freqs_cis, (mask.to('mps') if mask is not None else None), prompt[layer_index])
-            layer_index = layer_index + 1
+        if use_adapter:
+            prompt = self.adapter_query.weight.reshape(self.params.adapter_layer, self.params.adapter_len, self.params.dim).unsqueeze(1)
+            layer_index = 0
+            for layer in self.layers[-1 * self.params.adapter_layer:]:
+                h = layer(h, start_pos, freqs_cis, (mask.to('mps') if mask is not None else None), prompt[layer_index])
+                layer_index = layer_index + 1
         h = self.norm(h)
         output = self.output(h[:, -1, :])  # only compute last logits
         return output.float()
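For reference, a minimal usage sketch of the new flag (not part of the commit; `model` and the token ids below are hypothetical stand-ins for an instantiated Transformer and a tokenized prompt):

import torch

# Hypothetical inputs: a batch of pre-tokenized prompt ids, shape (bsz, seqlen),
# and the position offset into the KV cache.
tokens = torch.tensor([[1, 3439, 17632]])
start_pos = 0

# use_adapter=True routes the last adapter_layer blocks through the learned
# prefix, reshaped from adapter_query to (adapter_layer, 1, adapter_len, dim).
logits_adapted = model.forward(tokens, start_pos, use_adapter=True)

# use_adapter=False falls back to a plain pass through every transformer block.
logits_base = model.forward(tokens, start_pos, use_adapter=False)

Since forward is wrapped in @torch.inference_mode(), both calls run without gradient tracking, and the returned tensor holds only the float logits for the last position.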