@@ -141,7 +141,7 @@ def __init__(
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
-            use_qkv_parallel=False,
+            use_qkv_parallel=True,
             use_context_forward=use_context_forward,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
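Context on the `use_qkv_parallel=True` switch: a fused QKV projection produces query, key, and value with one matmul instead of three. A minimal sketch of the idea, assuming plain `nn.Linear` layers rather than sglang's actual `VisionAttention`/parallel-linear internals; sizes are illustrative:

```python
import torch
import torch.nn as nn

embed_dim = 1280  # illustrative size, not taken from the diff
x = torch.randn(4, 16, embed_dim)

# Fused path: one projection, then split into Q, K, V.
qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
q, k, v = qkv_proj(x).chunk(3, dim=-1)

# Unfused path: three separate projections computing the same shapes.
q_proj = nn.Linear(embed_dim, embed_dim)
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)
q2, k2, v2 = q_proj(x), k_proj(x), v_proj(x)
```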
@@ -325,7 +325,7 @@ def get_window_index(self, grid_thw):

     @property
     def dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.gate_proj.weight.dtype
+        return self.patch_embed.proj.weight.dtype

     @property
     def device(self) -> torch.device:
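Why the `dtype` property moved off the transformer blocks: if the blocks are quantized (e.g. bitsandbytes 4-bit packs weights as `uint8`), `blocks[0].mlp.gate_proj.weight.dtype` no longer reflects the model's compute dtype, whereas the patch-embedding projection stays in the original floating-point type. A hedged illustration of the mismatch:

```python
import torch

# Stand-in for a bnb 4-bit packed weight; real Params4bit stores uint8 data.
packed = torch.empty(10, dtype=torch.uint8)
print(packed.dtype)  # torch.uint8 -- not the float16/bfloat16 compute dtype
```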
@@ -572,10 +572,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
             ("gate_up_proj", "gate_proj", 0),
         ]
-        # Just for bnb 4bit
-        is_bnb_weights = hasattr(
-            weights, "gi_code"
-        ) and weights.gi_code.co_name.startswith("_quantized_4bit_generator")
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
@@ -596,23 +592,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader(param, loaded_weight, shard_id)
                     break
                 else:
-                    if "visual" in name and "qkv.weight" in name and not is_bnb_weights:
-                        visual_num_heads = self.config.vision_config.num_heads
-                        visual_embed_dim = self.config.vision_config.hidden_size
-                        head_size = visual_embed_dim // visual_num_heads
-                        loaded_weight = loaded_weight.view(
-                            3, visual_num_heads, head_size, visual_embed_dim
-                        )
-                        loaded_weight = loaded_weight.transpose(0, 1)
-                        loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                    elif "visual" in name and "qkv.bias" in name and not is_bnb_weights:
-                        visual_num_heads = self.config.vision_config.num_heads
-                        visual_embed_dim = self.config.vision_config.hidden_size
-                        head_size = visual_embed_dim // visual_num_heads
-                        loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
-                        loaded_weight = loaded_weight.transpose(0, 1)
-                        loaded_weight = loaded_weight.reshape(-1)
-
                     if "visual" in name:
                         # adapt to VisionAttention
                         name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
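The deleted branch permuted the checkpoint's stacked [q; k; v] rows into a per-head interleaved layout; with `use_qkv_parallel=True`, splitting the fused weight is presumably left to the parallel-linear weight loader instead. A toy reproduction of the old permutation, with illustrative sizes:

```python
import torch

num_heads, head_size = 2, 3          # illustrative, not Qwen2-VL's real sizes
embed_dim = num_heads * head_size

# Checkpoint layout: all Q rows, then all K rows, then all V rows.
w = torch.arange(3 * embed_dim * embed_dim, dtype=torch.float32)
w = w.view(3 * embed_dim, embed_dim)

w = w.view(3, num_heads, head_size, embed_dim)  # (qkv, head, head_dim, in)
w = w.transpose(0, 1)                           # (head, qkv, head_dim, in)
w = w.reshape(-1, embed_dim)                    # rows now interleaved per head
```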