@@ -379,3 +379,78 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru
379379 progress = progress ,
380380 ** kwargs ,
381381 )
382+
383+
def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """This function helps interpolating positional embeddings during checkpoint loading,
    especially when you want to apply a pre-trained model on images with different resolution.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
        reset_heads (bool): If true, not copying the state of heads. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.

    Raises:
        ValueError: If the stored position embedding does not have a leading batch
            dimension of 1.

    NOTE(review): when interpolation is needed, ``model_state["encoder.pos_embedding"]``
    is overwritten in place on the caller's dict (original behavior, preserved here).
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    # +1 accounts for the class token, which gets its own (non-spatial) embedding.
    new_seq_length = (image_size // patch_size) ** 2 + 1

    # Need to interpolate the weights for the position embedding.
    # We do this by reshaping the positions embeddings to a 2d grid, performing
    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated so we split it up.
        seq_length -= 1
        new_seq_length -= 1
        pos_embedding_token = pos_embedding[:, :1, :]
        pos_embedding_img = pos_embedding[:, 1:, :]

        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
        # math.isqrt is exact for ints, unlike int(math.sqrt(...)) which can
        # round incorrectly for very large values.
        seq_length_1d = math.isqrt(seq_length)
        torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")

        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
        new_seq_length_1d = image_size // patch_size

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        # NOTE(review): align_corners is only valid for linear/bilinear/bicubic/trilinear
        # modes; passing e.g. interpolation_mode="nearest" will raise inside interpolate.
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)

        model_state["encoder.pos_embedding"] = new_pos_embedding

    # Bug fix: this used to be nested inside the `if new_seq_length != seq_length:`
    # branch, so reset_heads=True was silently ignored whenever the checkpoint
    # already matched the target resolution. The documented contract is that the
    # heads are dropped whenever reset_heads is true, regardless of interpolation.
    if reset_heads:
        model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
        for k, v in model_state.items():
            if not k.startswith("heads"):
                model_state_copy[k] = v
        model_state = model_state_copy

    return model_state
0 commit comments