Skip to content

Commit 08fce7c

Browse files
fmassafacebook-github-bot
authored andcommitted
[fbsync] Pass indexing param to meshgrid to avoid warning (#4645)
Reviewed By: datumbox Differential Revision: D31898207 fbshipit-source-id: e23c0b73c6cad555f5247674c214e567eb87ba7f
1 parent 0b11da5 commit 08fce7c

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

test/test_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
def _create_video_frames(num_frames, height, width):
26-
y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
26+
y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width), indexing="ij")
2727
data = []
2828
for i in range(num_frames):
2929
xc = float(i) / num_frames

torchvision/models/detection/anchor_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]])
104104
# For output anchor, compute [x_center, y_center, x_center, y_center]
105105
shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
106106
shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
107-
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
107+
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
108108
shift_x = shift_x.reshape(-1)
109109
shift_y = shift_y.reshape(-1)
110110
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
@@ -222,7 +222,7 @@ def _grid_default_boxes(
222222

223223
shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
224224
shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
225-
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
225+
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
226226
shift_x = shift_x.reshape(-1)
227227
shift_y = shift_y.reshape(-1)
228228

0 commit comments

Comments
 (0)