@@ -104,7 +104,7 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]])
104
104
# For output anchor, compute [x_center, y_center, x_center, y_center]
105
105
shifts_x = torch .arange (0 , grid_width , dtype = torch .int32 , device = device ) * stride_width
106
106
shifts_y = torch .arange (0 , grid_height , dtype = torch .int32 , device = device ) * stride_height
107
- shift_y , shift_x = torch .meshgrid (shifts_y , shifts_x )
107
+ shift_y , shift_x = torch .meshgrid (shifts_y , shifts_x , indexing = "ij" )
108
108
shift_x = shift_x .reshape (- 1 )
109
109
shift_y = shift_y .reshape (- 1 )
110
110
shifts = torch .stack ((shift_x , shift_y , shift_x , shift_y ), dim = 1 )
@@ -222,7 +222,7 @@ def _grid_default_boxes(
222
222
223
223
shifts_x = ((torch .arange (0 , f_k [1 ]) + 0.5 ) / x_f_k ).to (dtype = dtype )
224
224
shifts_y = ((torch .arange (0 , f_k [0 ]) + 0.5 ) / y_f_k ).to (dtype = dtype )
225
- shift_y , shift_x = torch .meshgrid (shifts_y , shifts_x )
225
+ shift_y , shift_x = torch .meshgrid (shifts_y , shifts_x , indexing = "ij" )
226
226
shift_x = shift_x .reshape (- 1 )
227
227
shift_y = shift_y .reshape (- 1 )
228
228
0 commit comments