
Commit 74e9734

Make visual input/adapter use optional
1 parent fa7a04f commit 74e9734
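
In practice, the `generate` method in llama/llama_adapter.py no longer requires an image: `imgs` and `prompts` gain defaults of `None`, a non-tensor `imgs` (for example an empty list) simply produces no visual query, and a new `use_adapter` flag (default `True`) lets the learned adapter prompts be skipped so the underlying LLaMA runs unmodified. The two diffs below show the demo usage and the corresponding model changes.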

File tree

demo.py
llama/llama_adapter.py

2 files changed: 32 additions & 15 deletions

demo.py

Lines changed: 11 additions & 1 deletion
@@ -16,8 +16,18 @@
 result = model.generate(img, [prompt])[0]

 print(result)
-
+# use llama with visual input and adapter
 prompt = llama.format_prompt('Describe this image.')
 result = model.generate(img, [prompt])[0]

+print(result)
+# use llama with adapter without visual input
+prompt = llama.format_prompt('Give me a random number')
+result = model.generate([], [prompt])[0]
+
+print(result)
+# use llama without visual model and adapter (useful for not alpaca-formatted tasks)
+prompt = llama.format_prompt('Give me a random number')
+result = model.generate([], [prompt], use_adapter=False)[0]
+
 print(result)
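
Taken together, the three calls above exercise the modes this commit enables: visual input with the adapter (`model.generate(img, [prompt])`), adapter prompts without any image (`imgs` passed as an empty list), and plain LLaMA inference with `use_adapter=False`, which skips both the visual query and the adapter prompts.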

llama/llama_adapter.py

Lines changed: 21 additions & 14 deletions
@@ -94,6 +94,8 @@ def clip_encode_image(self, x):
         return x

     def forward_visual(self, imgs):
+        if type(imgs) is not torch.Tensor:
+            return None
         clip_feats = self.clip_encode_image(imgs)
         clip_feats = self.clip_proj_norm(self.clip_proj(clip_feats.half()))

@@ -110,7 +112,7 @@ def forward_visual(self, imgs):
         return visual_query

     @torch.inference_mode()
-    def forward(self, visual_query, tokens, start_pos: int):
+    def forward(self, visual_query, tokens, start_pos: int, use_adapter):
         _bsz, seqlen = tokens.shape
         h = self.llama.tok_embeddings(tokens)
         freqs_cis = self.llama.freqs_cis#.to(h.device)
@@ -120,17 +122,19 @@ def forward(self, visual_query, tokens, start_pos: int):
                                float("-inf"), device=torch.device('cpu'))
             mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

-        for layer in self.llama.layers[:-1 * self.query_layer]:
+        for layer in (self.llama.layers[:-1 * self.query_layer] if use_adapter else self.llama.layers):
             h = layer(h, start_pos, freqs_cis, mask.to('mps') if mask is not None else None)

-        adapter = self.adapter_query.weight.reshape(
-            self.query_layer, self.query_len, -1).unsqueeze(1)
-        adapter_index = 0
-        for layer in self.llama.layers[-1 * self.query_layer:]:
-            dynamic_adapter = adapter[adapter_index].repeat(_bsz, 1, 1)
-            dynamic_adapter = dynamic_adapter + visual_query
-            h = layer(h, start_pos, freqs_cis, mask, dynamic_adapter)
-            adapter_index = adapter_index + 1
+        if use_adapter:
+            adapter = self.adapter_query.weight.reshape(
+                self.query_layer, self.query_len, -1).unsqueeze(1)
+            adapter_index = 0
+            for layer in self.llama.layers[-1 * self.query_layer:]:
+                dynamic_adapter = adapter[adapter_index].repeat(_bsz, 1, 1)
+                if visual_query is not None:
+                    dynamic_adapter = dynamic_adapter + visual_query
+                h = layer(h, start_pos, freqs_cis, mask, dynamic_adapter)
+                adapter_index = adapter_index + 1

         h = self.llama.norm(h)
         output = self.llama.output(h[:, -1, :])
@@ -139,15 +143,18 @@ def forward(self, visual_query, tokens, start_pos: int):

     @torch.inference_mode()
     def generate(
-        self, imgs, prompts,
+        self, imgs=None, prompts=None,
         max_gen_len: int = 256,
         temperature: float = 0.1,
         top_p: float = 0.75,
+        use_adapter: bool = True
     ):
-        bsz = len(imgs)
+        use_visual_input = type(imgs) is torch.Tensor
+        bsz = len(imgs) if use_visual_input else len(prompts)
         params = self.llama.params
         assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
-        assert len(imgs) == len(prompts)
+        if use_visual_input:
+            assert len(imgs) == len(prompts)

         visual_query = self.forward_visual(imgs)

@@ -169,7 +176,7 @@ def generate(
         start_pos = min_prompt_size
         prev_pos = 0
         for cur_pos in range(start_pos, total_len):
-            logits = self.forward(visual_query, tokens[:, prev_pos:cur_pos], prev_pos)
+            logits = self.forward(visual_query, tokens[:, prev_pos:cur_pos], prev_pos, use_adapter)
             if temperature > 0:
                 probs = torch.softmax(logits / temperature, dim=-1)
                 next_token = sample_top_p(probs, top_p)
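
For a quick end-to-end read of what the new flags do, here is a self-contained toy sketch (not code from this repo) that mirrors only the dispatch rules introduced above: the `type(imgs) is torch.Tensor` guard, the batch-size fallback to `len(prompts)`, and the `use_adapter` switch. The feature extraction and the returned strings are placeholders invented for illustration.

# Illustrative toy, not repo code: reproduces the commit's dispatch rules
# (optional image tensor, optional adapter) with stand-in computations.
import torch

def forward_visual(imgs):
    # New guard from the diff: a non-tensor input yields no visual query.
    if type(imgs) is not torch.Tensor:
        return None
    return imgs.mean(dim=(-2, -1))  # stand-in for CLIP features + projection

def generate(imgs=None, prompts=None, use_adapter=True):
    # Batch size comes from the images when a tensor is given, else from the prompts.
    use_visual_input = type(imgs) is torch.Tensor
    bsz = len(imgs) if use_visual_input else len(prompts)
    if use_visual_input:
        assert len(imgs) == len(prompts)
    visual_query = forward_visual(imgs)
    # Three resulting modes, matching the demo.py examples above.
    if visual_query is not None:
        mode = "visual input + adapter"
    elif use_adapter:
        mode = "adapter only"
    else:
        mode = "plain llama"
    return [f"[{mode}] {p}" for p in prompts[:bsz]]

print(generate(torch.zeros(1, 3, 224, 224), ["Describe this image."]))
print(generate([], ["Give me a random number"]))
print(generate([], ["Give me a random number"], use_adapter=False))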
