@@ -140,17 +140,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
140140 new_item = new_item .replace ("norm.weight" , "group_norm.weight" )
141141 new_item = new_item .replace ("norm.bias" , "group_norm.bias" )
142142
143- new_item = new_item .replace ("q.weight" , "query .weight" )
144- new_item = new_item .replace ("q.bias" , "query .bias" )
143+ new_item = new_item .replace ("q.weight" , "to_q .weight" )
144+ new_item = new_item .replace ("q.bias" , "to_q .bias" )
145145
146- new_item = new_item .replace ("k.weight" , "key .weight" )
147- new_item = new_item .replace ("k.bias" , "key .bias" )
146+ new_item = new_item .replace ("k.weight" , "to_k .weight" )
147+ new_item = new_item .replace ("k.bias" , "to_k .bias" )
148148
149- new_item = new_item .replace ("v.weight" , "value .weight" )
150- new_item = new_item .replace ("v.bias" , "value .bias" )
149+ new_item = new_item .replace ("v.weight" , "to_v .weight" )
150+ new_item = new_item .replace ("v.bias" , "to_v .bias" )
151151
152- new_item = new_item .replace ("proj_out.weight" , "proj_attn .weight" )
153- new_item = new_item .replace ("proj_out.bias" , "proj_attn .bias" )
152+ new_item = new_item .replace ("proj_out.weight" , "to_out.0 .weight" )
153+ new_item = new_item .replace ("proj_out.bias" , "to_out.0 .bias" )
154154
155155 new_item = shave_segments (new_item , n_shave_prefix_segments = n_shave_prefix_segments )
156156
@@ -204,8 +204,12 @@ def assign_to_checkpoint(
204204 new_path = new_path .replace (replacement ["old" ], replacement ["new" ])
205205
206206 # proj_attn.weight has to be converted from conv 1D to linear
207- if "proj_attn.weight" in new_path :
207+ is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path )
208+ shape = old_checkpoint [path ["old" ]].shape
209+ if is_attn_weight and len (shape ) == 3 :
208210 checkpoint [new_path ] = old_checkpoint [path ["old" ]][:, :, 0 ]
211+ elif is_attn_weight and len (shape ) == 4 :
212+ checkpoint [new_path ] = old_checkpoint [path ["old" ]][:, :, 0 , 0 ]
209213 else :
210214 checkpoint [new_path ] = old_checkpoint [path ["old" ]]
211215
0 commit comments