
Commit a9b319a

make the use of adapter optional in generation
1 parent df38515

File tree: 3 files changed (+11, -9 lines)


example.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def main(
 
     with torch.inference_mode(mode=True):
         results = [generator.generate(
-            [prompt], max_gen_len=32, temperature=temperature, top_p=top_p
+            [prompt], max_gen_len=32, temperature=temperature, top_p=top_p, use_adapter=bool(adapter_path)
         ) for prompt in prompts]
 
     for result in results:
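The caller derives the flag from whether an adapter checkpoint was supplied, so `use_adapter` is False exactly when `adapter_path` is None or empty. A minimal, runnable illustration of that derivation (the path values here are made up):

# bool() maps None and the empty string to False, any non-empty path to True.
for adapter_path in (None, "", "weights/adapter.pth"):
    use_adapter = bool(adapter_path)
    print(f"adapter_path={adapter_path!r} -> use_adapter={use_adapter}")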

llama/generation.py

Lines changed: 2 additions & 1 deletion
@@ -20,6 +20,7 @@ def generate(
         max_gen_len: int,
         temperature: float = 0.8,
         top_p: float = 0.95,
+        use_adapter: bool = True
     ) -> List[str]:
         bsz = len(prompts)
         params = self.model.params
@@ -39,7 +40,7 @@ def generate(
         start_pos = min_prompt_size
         prev_pos = 0
         for cur_pos in range(start_pos, total_len):
-            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos, use_adapter)
             if temperature > 0:
                 probs = torch.softmax(logits / temperature, dim=-1)
                 next_token = sample_top_p(probs, top_p)
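Because `use_adapter` defaults to True, existing callers of `generate()` keep the adapter-augmented behavior unchanged; the new keyword is simply threaded through to `model.forward`. A minimal sketch of that plumbing with toy stand-ins (ToyModel and ToyGenerator are hypothetical, not the real classes):

from typing import List

class ToyModel:
    def forward(self, tokens: str, start_pos: int, use_adapter: bool) -> str:
        # Stand-in for Transformer.forward in llama/model.py.
        return f"forward(tokens={tokens!r}, start_pos={start_pos}, use_adapter={use_adapter})"

class ToyGenerator:
    def __init__(self) -> None:
        self.model = ToyModel()

    def generate(self, prompts: List[str], use_adapter: bool = True) -> List[str]:
        # Defaulting to True preserves the pre-commit behavior for callers
        # that never pass the new keyword.
        return [self.model.forward(p, 0, use_adapter) for p in prompts]

gen = ToyGenerator()
print(gen.generate(["hello"]))                     # adapter path (old default)
print(gen.generate(["hello"], use_adapter=False))  # plain forward pass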

llama/model.py

Lines changed: 8 additions & 7 deletions
@@ -238,24 +238,25 @@ def __init__(self, params: ModelArgs):
         self.adapter_layer = params.adapter_layer
 
     @torch.inference_mode()
-    def forward(self, tokens: torch.Tensor, start_pos: int):
+    def forward(self, tokens: torch.Tensor, start_pos: int, use_adapter):
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         #self.freqs_cis = self.freqs_cis.float().to(h.device)
         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
-        prompt = self.adapter_query.weight.reshape(self.params.adapter_layer, self.params.adapter_len, self.params.dim).unsqueeze(1)
 
         mask = None
         if seqlen > 1:
             mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=torch.device('cpu'))
             mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
 
-        for layer in self.layers[: -1 * self.params.adapter_layer]:
+        for layer in (self.layers[: -1 * self.params.adapter_layer]) if use_adapter else self.layers:
             h = layer(h, start_pos, freqs_cis, (mask.to('mps') if mask is not None else None))
-        layer_index = 0
-        for layer in self.layers[-1 * self.params.adapter_layer:]:
-            h = layer(h, start_pos, freqs_cis, (mask.to('mps') if mask is not None else None), prompt[layer_index])
-            layer_index = layer_index + 1
+        if use_adapter:
+            prompt = self.adapter_query.weight.reshape(self.params.adapter_layer, self.params.adapter_len, self.params.dim).unsqueeze(1)
+            layer_index = 0
+            for layer in self.layers[-1 * self.params.adapter_layer:]:
+                h = layer(h, start_pos, freqs_cis, (mask.to('mps') if mask is not None else None), prompt[layer_index])
+                layer_index = layer_index + 1
         h = self.norm(h)
         output = self.output(h[:, -1, :])  # only compute last logits
         return output.float()
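With the adapter off, every block runs plainly; with it on, the last `adapter_layer` blocks each receive a slice of the learned adapter prompt, which is now only reshaped when actually needed. A minimal sketch of that gating with list-based stand-ins for the transformer blocks (no torch; the counts are arbitrary):

n_layers, adapter_layer = 8, 3
layers = [f"block_{i}" for i in range(n_layers)]
adapter_prompts = [f"prompt_{i}" for i in range(adapter_layer)]

def forward(use_adapter: bool) -> list:
    trace = []
    # Plain blocks: all of them when the adapter is off, otherwise
    # everything except the last `adapter_layer` blocks.
    for layer in (layers[: -adapter_layer] if use_adapter else layers):
        trace.append(layer)
    # Adapter blocks: entered only when the flag is set, each block
    # paired with its own slice of the adapter prompt.
    if use_adapter:
        for i, layer in enumerate(layers[-adapter_layer:]):
            trace.append(f"{layer}+{adapter_prompts[i]}")
    return trace

print(forward(True))   # last 3 blocks paired with adapter prompts
print(forward(False))  # all 8 blocks run plain

Using enumerate here is just a tidier equivalent of the manual `layer_index` bookkeeping in the committed code.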
