
Commit df38515

LLaMA-Adapter support
1 parent e3574a1 commit df38515

3 files changed: +79, -46 lines

README.md

Lines changed: 4 additions & 3 deletions

@@ -1,9 +1,10 @@
-# LLaMa MPS fork
+# LLaMa MPS fork (llama-adapter branch)

-This is a fork of https://github.com/markasoftware/llama-cpu which is a fork of https://github.com/facebookresearch/llama. The goal of this fork is to use GPU acceleration on Apple M1/M2 devices.
+This is a fork of https://github.com/markasoftware/llama-cpu which is a fork of https://github.com/facebookresearch/llama. The goal of this fork is to use GPU acceleration on Apple M1/M2 devices.
+This branch provides support for [LLaMA-Adapter](https://github.com/ZrrSkywalker/LLaMA-Adapter).

 Please check the original repos for installation instructions. After you're done, run this
-`torchrun example.py --ckpt_dir ../7B --tokenizer_path ../tokenizer.model --max_batch_size=1` with the correct paths to the models. You might need to set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1.
+`torchrun example.py --ckpt_dir ../7B --tokenizer_path ../tokenizer.model --max_batch_size=1 --adapter_path ../llama_adapter_len10_layer30_release.pth` with the correct paths to the models. You might need to set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1.

 This fork is experimental, currently at the stage which allows running a full non-quantized model with MPS.
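For reference, PYTORCH_ENABLE_MPS_FALLBACK=1 tells PyTorch to run any operator that has no MPS kernel on the CPU instead of raising an error. It is normally exported in the shell before launching torchrun; the minimal sketch below is not part of the commit and only shows the Python equivalent, assuming the variable is set before torch is imported.

```python
# Minimal sketch (not from the repo): the fallback flag must be in the
# environment before torch is imported; it routes unsupported MPS ops to CPU.
import os
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch
print(torch.backends.mps.is_available())  # True on an Apple Silicon build of PyTorch
```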

example.py

Lines changed: 44 additions & 33 deletions
@@ -33,6 +33,7 @@ def setup_model_parallel() -> Tuple[int, int]:
 def load(
     ckpt_dir: str,
     tokenizer_path: str,
+    adapter_path: str,
     local_rank: int,
     world_size: int,
     max_seq_len: int,
@@ -45,19 +46,26 @@ def load(
     ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
     ckpt_path = checkpoints[local_rank]
     print("Loading")
-    checkpoint = torch.load(ckpt_path, map_location="cpu")
+    checkpoint = torch.load(ckpt_path, torch.device("cpu"))
+    if adapter_path:
+        adapter_checkpoint = torch.load(adapter_path, torch.device("cpu"))
     with open(Path(ckpt_dir) / "params.json", "r") as f:
         params = json.loads(f.read())

     model_args: ModelArgs = ModelArgs(
         max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
     )
+    if adapter_path:
+        model_args.adapter_layer = int(adapter_checkpoint['adapter_query.weight'].shape[0] / model_args.adapter_len)
     tokenizer = Tokenizer(model_path=tokenizer_path)
     model_args.vocab_size = tokenizer.n_words
     torch.set_default_tensor_type(torch.HalfTensor)
     model = Transformer(model_args)
-    torch.set_default_tensor_type(torch.FloatTensor)
     model.load_state_dict(checkpoint, strict=False)
+    if adapter_path:
+        model.load_state_dict(adapter_checkpoint, strict=False)
+        del adapter_checkpoint
+    del checkpoint
     model = model.to('mps')
     generator = LLaMA(model, tokenizer)
     print(f"Loaded in {time.time() - start_time:.2f} seconds")
@@ -67,6 +75,7 @@ def load(
 def main(
     ckpt_dir: str,
     tokenizer_path: str,
+    adapter_path: str = None,
     temperature: float = 0.8,
     top_p: float = 0.95,
     max_seq_len: int = 512,
@@ -77,47 +86,49 @@ def main(
         sys.stdout = open(os.devnull, "w")

     generator = load(
-        ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
+        ckpt_dir, tokenizer_path, adapter_path, local_rank, world_size, max_seq_len, max_batch_size
     )

-    prompts = [
-        # For these prompts, the expected answer is the natural continuation of the prompt
-        "I believe the meaning of life is",
-        "Simply put, the theory of relativity states that ",
-        "Building a website can be done in 10 simple steps:\n",
-        # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
-        """Tweet: "I hate it when my phone battery dies."
-Sentiment: Negative
-###
-Tweet: "My day has been 👍"
-Sentiment: Positive
-###
-Tweet: "This is the link to the article"
-Sentiment: Neutral
-###
-Tweet: "This new music video was incredibile"
-Sentiment:""",
-        """Translate English to French:
-
-sea otter => loutre de mer
-
-peppermint => menthe poivrée
-
-plush girafe => girafe peluche
-
-cheese =>""",
+    PROMPT_DICT = {
+        "prompt_input": (
+            "Below is an instruction that describes a task, paired with an input that provides further context. "
+            "Write a response that appropriately completes the request.\n\n"
+            "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+        ),
+        "prompt_no_input": (
+            "Below is an instruction that describes a task. "
+            "Write a response that appropriately completes the request.\n\n"
+            "### Instruction:\n{instruction}\n\n### Response:"
+        ),
+    }
+
+    instructs = [
+        "Tell me about alpacas.",
+        "Tell me about the president of Mexico in 2019.",
+        "Tell me about the king of France in 2019.",
+        "List all Canadian provinces in alphabetical order.",
+        "Write a Python program that prints the first 10 Fibonacci numbers.",
+        "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",
+        "Tell me five words that rhyme with 'shock'.",
+        "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
+        "Count up from 1 to 500."
     ]
+    prompts = [PROMPT_DICT['prompt_no_input'].format_map({'instruction': x, 'input': ''}) for x in instructs]
     # results = generator.generate(
     #     prompts, max_gen_len=256, temperature=temperature, top_p=top_p
     # )
-    results = [generator.generate(
-        [prompt], max_gen_len=32, temperature=temperature, top_p=top_p
-    ) for prompt in prompts]
+    gen_start_time = time.time()
+
+    with torch.inference_mode(mode=True):
+        results = [generator.generate(
+            [prompt], max_gen_len=32, temperature=temperature, top_p=top_p
+        ) for prompt in prompts]

     for result in results:
-        print("\n==================================\n")
         print(result)
         print("\n==================================\n")
+    print(f"Generated in {time.time() - gen_start_time:.2f} seconds")


 if __name__ == "__main__":
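The new prompts follow the Alpaca-style instruction template defined in PROMPT_DICT above. A minimal sketch, reusing the same template string, of what prompt_no_input produces for one instruction:

```python
# Minimal sketch: formatting a single instruction with the Alpaca-style
# template (same text as PROMPT_DICT['prompt_no_input'] in the diff above).
prompt_no_input = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)

print(prompt_no_input.format_map({"instruction": "Tell me about alpacas.", "input": ""}))
# Below is an instruction that describes a task. Write a response that
# appropriately completes the request.
#
# ### Instruction:
# Tell me about alpacas.
#
# ### Response:
```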

llama/model.py

Lines changed: 31 additions & 10 deletions
@@ -28,7 +28,9 @@ class ModelArgs:

     max_batch_size: int = 32
     max_seq_len: int = 2048
-
+
+    adapter_len: int = 10
+    adapter_layer: int = 8

 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
@@ -65,10 +67,8 @@ def apply_rotary_emb(
     xk: torch.Tensor,
     freqs_cis: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    xq = xq.to('cpu')
-    xk = xk.to('cpu')
-    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2).to('cpu'))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2).to('cpu'))
     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
@@ -117,8 +117,9 @@ def __init__(self, args: ModelArgs):
         self.cache_v = torch.zeros(
             (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
         ).to('mps')  # .cuda()
+        self.gate = torch.nn.Parameter(torch.zeros(1))

-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
         bsz, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

@@ -137,6 +138,13 @@ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask
         keys = self.cache_k[:bsz, : start_pos + seqlen]
         values = self.cache_v[:bsz, : start_pos + seqlen]

+        if adapter is not None:
+            adapter_len = adapter.shape[1]
+            adapter_k = self.wk(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
+            adapter_v = self.wv(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
+            adapter_k = adapter_k.transpose(1, 2)
+            adapter_v = adapter_v.transpose(1, 2)
+
         xq = xq.transpose(1, 2)
         keys = keys.transpose(1, 2)
         values = values.transpose(1, 2)
@@ -145,6 +153,10 @@ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask
             scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
         output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
+        if adapter is not None:
+            adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)
+            adapter_scores = self.gate * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)
+            output = output + torch.matmul(adapter_scores, adapter_v)
         output = output.transpose(
             1, 2
         ).contiguous().view(bsz, seqlen, -1)
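These two hunks implement LLaMA-Adapter's zero-init attention: the adapter prompt is projected through the layer's existing wk/wv into extra keys and values, scored against the same queries, and added to the normal attention output scaled by self.gate, which starts at zero so an untrained adapter leaves the base model's output unchanged. Below is a minimal standalone sketch with toy dimensions (not the repo's code) of the same computation.

```python
# Minimal standalone sketch of gated adapter attention with toy shapes.
import math
import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim, adapter_len = 1, 2, 5, 8, 10

xq = torch.randn(bsz, n_heads, seqlen, head_dim)              # queries from the input tokens
values = torch.randn(bsz, n_heads, seqlen, head_dim)
scores = torch.randn(bsz, n_heads, seqlen, seqlen).softmax(dim=-1)
output = torch.matmul(scores, values)                          # ordinary self-attention output

adapter_k = torch.randn(bsz, n_heads, adapter_len, head_dim)   # keys/values from the adapter prompt
adapter_v = torch.randn(bsz, n_heads, adapter_len, head_dim)
gate = torch.zeros(1)                                          # learned gate, zero-initialised

adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(head_dim)
adapter_scores = gate * F.softmax(adapter_scores.float(), dim=-1)
output = output + torch.matmul(adapter_scores, adapter_v)      # unchanged while gate == 0
```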
@@ -191,8 +203,8 @@ def __init__(self, layer_id: int, args: ModelArgs):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
-        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
+        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, adapter)
         out = h + self.feed_forward.forward(self.ffn_norm(h))
         return out

@@ -221,20 +233,29 @@ def __init__(self, params: ModelArgs):
             self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
         )

+        self.adapter_query = nn.Embedding(params.adapter_len * params.adapter_layer, params.dim)
+        self.adapter_len = params.adapter_len
+        self.adapter_layer = params.adapter_layer
+
     @torch.inference_mode()
     def forward(self, tokens: torch.Tensor, start_pos: int):
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         #self.freqs_cis = self.freqs_cis.float().to(h.device)
         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+        prompt = self.adapter_query.weight.reshape(self.params.adapter_layer, self.params.adapter_len, self.params.dim).unsqueeze(1)

         mask = None
         if seqlen > 1:
             mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=torch.device('cpu'))
             mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

-        for layer in self.layers:
-            h = layer(h, start_pos, freqs_cis, (mask.to('mps') if mask is not None else mask))
+        for layer in self.layers[: -1 * self.params.adapter_layer]:
+            h = layer(h, start_pos, freqs_cis, (mask.to('mps') if mask is not None else None))
+        layer_index = 0
+        for layer in self.layers[-1 * self.params.adapter_layer:]:
+            h = layer(h, start_pos, freqs_cis, (mask.to('mps') if mask is not None else None), prompt[layer_index])
+            layer_index = layer_index + 1
         h = self.norm(h)
         output = self.output(h[:, -1, :])  # only compute last logits
         return output.float()
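The adapter prompts for all adapted layers live in one nn.Embedding table; forward() reshapes it into an (adapter_layer, 1, adapter_len, dim) tensor and hands one slice to each of the last adapter_layer blocks, while earlier blocks run unchanged. A minimal sketch with toy sizes (not the repo's code) of that reshape and slicing:

```python
# Minimal sketch: splitting one embedding table into per-layer adapter
# prompts and applying them only to the last adapter_layer blocks.
import torch
import torch.nn as nn

dim, adapter_len, adapter_layer, n_layers = 16, 10, 3, 8

adapter_query = nn.Embedding(adapter_len * adapter_layer, dim)
# (adapter_layer, 1, adapter_len, dim): one prompt block per adapted layer,
# with a broadcastable batch dimension of 1.
prompt = adapter_query.weight.reshape(adapter_layer, adapter_len, dim).unsqueeze(1)

layers = list(range(n_layers))                 # stand-ins for TransformerBlock modules
for layer in layers[: -adapter_layer]:
    pass                                       # early layers: plain forward, no adapter
for i, layer in enumerate(layers[-adapter_layer:]):
    per_layer_prompt = prompt[i]               # shape (1, adapter_len, dim)
    # late layers: forward(..., adapter=per_layer_prompt)
```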
