@@ -135,11 +135,8 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
135
135
return ops .RoIPool ((pool_h , pool_w ), spatial_scale )(x , rois )
136
136
137
137
def get_script_fn (self , rois , pool_size ):
138
- @torch .jit .script
139
- def script_fn (input , rois , pool_size ):
140
- # type: (Tensor, Tensor, int) -> Tensor
141
- return ops .roi_pool (input , rois , pool_size , 1.0 )[0 ]
142
- return lambda x : script_fn (x , rois , pool_size )
138
+ scriped = torch .jit .script (ops .roi_pool )
139
+ return lambda x : scriped (x , rois , pool_size )
143
140
144
141
def expected_fn (self , x , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 ,
145
142
device = None , dtype = torch .float64 ):
@@ -177,11 +174,8 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
177
174
return ops .PSRoIPool ((pool_h , pool_w ), 1 )(x , rois )
178
175
179
176
def get_script_fn (self , rois , pool_size ):
180
- @torch .jit .script
181
- def script_fn (input , rois , pool_size ):
182
- # type: (Tensor, Tensor, int) -> Tensor
183
- return ops .ps_roi_pool (input , rois , pool_size , 1.0 )[0 ]
184
- return lambda x : script_fn (x , rois , pool_size )
177
+ scriped = torch .jit .script (ops .ps_roi_pool )
178
+ return lambda x : scriped (x , rois , pool_size )
185
179
186
180
def expected_fn (self , x , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 ,
187
181
device = None , dtype = torch .float64 ):
@@ -257,11 +251,8 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligne
257
251
sampling_ratio = sampling_ratio , aligned = aligned )(x , rois )
258
252
259
253
def get_script_fn (self , rois , pool_size ):
260
- @torch .jit .script
261
- def script_fn (input , rois , pool_size ):
262
- # type: (Tensor, Tensor, int) -> Tensor
263
- return ops .roi_align (input , rois , pool_size , 1.0 )[0 ]
264
- return lambda x : script_fn (x , rois , pool_size )
254
+ scriped = torch .jit .script (ops .roi_align )
255
+ return lambda x : scriped (x , rois , pool_size )
265
256
266
257
def expected_fn (self , in_data , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 , aligned = False ,
267
258
device = None , dtype = torch .float64 ):
@@ -311,11 +302,8 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
311
302
sampling_ratio = sampling_ratio )(x , rois )
312
303
313
304
def get_script_fn (self , rois , pool_size ):
314
- @torch .jit .script
315
- def script_fn (input , rois , pool_size ):
316
- # type: (Tensor, Tensor, int) -> Tensor
317
- return ops .ps_roi_align (input , rois , pool_size , 1.0 )[0 ]
318
- return lambda x : script_fn (x , rois , pool_size )
305
+ scriped = torch .jit .script (ops .ps_roi_align )
306
+ return lambda x : scriped (x , rois , pool_size )
319
307
320
308
def expected_fn (self , in_data , rois , pool_h , pool_w , device , spatial_scale = 1 ,
321
309
sampling_ratio = - 1 , dtype = torch .float64 ):
0 commit comments