diff --git a/tests/torchtune/models/clip/test_pos_embedding_interpolation.py b/tests/torchtune/models/clip/test_pos_embedding_interpolation.py
index 559403d9f5..c46a4a28a9 100644
--- a/tests/torchtune/models/clip/test_pos_embedding_interpolation.py
+++ b/tests/torchtune/models/clip/test_pos_embedding_interpolation.py
@@ -201,6 +201,7 @@ def test_tile_resize_position_embedding(self, params):
             embedding, tgt_max_num_tiles
         )

+        assert resized_pos_embed.is_contiguous()
         assert_expected(resized_pos_embed, expected_output, atol=1e-3, rtol=1e-4)

     @pytest.mark.parametrize("params", local_pos_emb_test_cases)
@@ -215,6 +216,7 @@ def test_resize_local_position_embedding(self, params):
             )
         )

+        assert resized_pos_embed.is_contiguous()
         assert_expected(resized_pos_embed, expected_output, atol=1e-3, rtol=1e-4)

     @pytest.mark.parametrize("params", global_pos_emb_test_cases)
@@ -230,6 +232,7 @@ def test_resize_global_position_embedding(self, params):
             )
         )

+        assert resized_pos_embed.is_contiguous()
         assert_expected(resized_pos_embed, expected_output, atol=1e-3, rtol=1e-4)

     @pytest.mark.parametrize(
diff --git a/torchtune/models/clip/_position_embeddings.py b/torchtune/models/clip/_position_embeddings.py
index cd1ea5947c..09a98862e1 100644
--- a/torchtune/models/clip/_position_embeddings.py
+++ b/torchtune/models/clip/_position_embeddings.py
@@ -319,7 +319,7 @@ def _resize_local_position_embedding(
         # add cls token back in
         local_pos_embed = torch.cat([cls_token, local_pos_embed], dim=0)

-        return local_pos_embed
+        return local_pos_embed.contiguous()

     # TODO: Switch to public method after 2.5 is stable
     @staticmethod
@@ -436,7 +436,7 @@ def _resize_global_position_embedding(
         # add cls token back in
         global_pos_embed = torch.cat([cls_embed, pos_embed], dim=2)

-        return global_pos_embed
+        return global_pos_embed.contiguous()

     def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
         """
@@ -633,7 +633,7 @@ def _resize_position_embedding(
         )
         # permute to the original shape
         embedding = embedding.permute(2, 3, 0, 1)
-        return embedding
+        return embedding.contiguous()

     def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
         """
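
A minimal sketch (not part of the patch) of the behavior the change addresses: torch.Tensor.permute returns a strided view rather than a contiguous copy, so the resized embeddings returned by these helpers could be non-contiguous until .contiguous() is applied, which is what the updated tests assert. The tensor shape below is illustrative only.

    # sketch: permute() yields a non-contiguous view; .contiguous() materializes it
    import torch

    embedding = torch.randn(2, 3, 4, 5)          # illustrative shape, not from the patch
    permuted = embedding.permute(2, 3, 0, 1)      # same permutation as _resize_position_embedding
    assert not permuted.is_contiguous()           # view with rearranged strides

    resized = permuted.contiguous()               # copies data into the new memory order
    assert resized.is_contiguous()                # matches the new test assertions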