@@ -94,6 +94,8 @@ def clip_encode_image(self, x):
         return x

     def forward_visual(self, imgs):
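+        # Text-only mode: bail out early when no image tensor is supplied, so the
+        # CLIP encoder is skipped and no visual query is produced.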
+        if type(imgs) is not torch.Tensor:
+            return None
         clip_feats = self.clip_encode_image(imgs)
         clip_feats = self.clip_proj_norm(self.clip_proj(clip_feats.half()))

@@ -110,7 +112,7 @@ def forward_visual(self, imgs):
         return visual_query

     @torch.inference_mode()
-    def forward(self, visual_query, tokens, start_pos: int):
+    def forward(self, visual_query, tokens, start_pos: int, use_adapter):
         _bsz, seqlen = tokens.shape
         h = self.llama.tok_embeddings(tokens)
         freqs_cis = self.llama.freqs_cis  # .to(h.device)
@@ -120,17 +122,19 @@ def forward(self, visual_query, tokens, start_pos: int):
                 float("-inf"), device=torch.device('cpu'))
             mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

-        for layer in self.llama.layers[:-1 * self.query_layer]:
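+        # With the adapter disabled, run every transformer block here; otherwise
+        # keep the last query_layer blocks for the adapter loop below.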
+        for layer in (self.llama.layers[:-1 * self.query_layer] if use_adapter else self.llama.layers):
             h = layer(h, start_pos, freqs_cis, mask.to('mps') if mask is not None else None)

-        adapter = self.adapter_query.weight.reshape(
-            self.query_layer, self.query_len, -1).unsqueeze(1)
-        adapter_index = 0
-        for layer in self.llama.layers[-1 * self.query_layer:]:
-            dynamic_adapter = adapter[adapter_index].repeat(_bsz, 1, 1)
-            dynamic_adapter = dynamic_adapter + visual_query
-            h = layer(h, start_pos, freqs_cis, mask, dynamic_adapter)
-            adapter_index = adapter_index + 1
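+        # Inject the learned adapter prompts into the last query_layer blocks,
+        # adding the visual query on top only when an image was actually encoded.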
+        if use_adapter:
+            adapter = self.adapter_query.weight.reshape(
+                self.query_layer, self.query_len, -1).unsqueeze(1)
+            adapter_index = 0
+            for layer in self.llama.layers[-1 * self.query_layer:]:
+                dynamic_adapter = adapter[adapter_index].repeat(_bsz, 1, 1)
+                if visual_query is not None:
+                    dynamic_adapter = dynamic_adapter + visual_query
+                h = layer(h, start_pos, freqs_cis, mask, dynamic_adapter)
+                adapter_index = adapter_index + 1

         h = self.llama.norm(h)
         output = self.llama.output(h[:, -1, :])
@@ -139,15 +143,18 @@ def forward(self, visual_query, tokens, start_pos: int):

     @torch.inference_mode()
     def generate(
-        self, imgs, prompts,
+        self, imgs=None, prompts=None,
         max_gen_len: int = 256,
         temperature: float = 0.1,
         top_p: float = 0.75,
+        use_adapter: bool = True
     ):
-        bsz = len(imgs)
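+        # Images are now optional: take the batch size from the prompts when no
+        # image tensor is passed, which enables text-only generation.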
+        use_visual_input = type(imgs) is torch.Tensor
+        bsz = len(imgs) if use_visual_input else len(prompts)
         params = self.llama.params
         assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
-        assert len(imgs) == len(prompts)
+        if use_visual_input:
+            assert len(imgs) == len(prompts)

         visual_query = self.forward_visual(imgs)

@@ -169,7 +176,7 @@ def generate(
         start_pos = min_prompt_size
         prev_pos = 0
         for cur_pos in range(start_pos, total_len):
-            logits = self.forward(visual_query, tokens[:, prev_pos:cur_pos], prev_pos)
+            logits = self.forward(visual_query, tokens[:, prev_pos:cur_pos], prev_pos, use_adapter)
             if temperature > 0:
                 probs = torch.softmax(logits / temperature, dim=-1)
                 next_token = sample_top_p(probs, top_p)
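A minimal usage sketch of the new call paths (illustrative only; model, img_batch, and prompts are assumed names, not part of the diff):

# Multimodal, adapter enabled -- the original behaviour:
results = model.generate(imgs=img_batch, prompts=prompts)

# Text-only: imgs defaults to None, forward_visual() returns None, and the
# adapter prompts are injected without a visual query:
results = model.generate(prompts=prompts)

# Plain LLaMA decoding: skip the adapter layers entirely:
results = model.generate(prompts=prompts, use_adapter=False)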