Skip to content

Commit 8cefd56

Browse files
committed
simplify _get_script_fn
1 parent afc502b commit 8cefd56

File tree

1 file changed

+8
-20
lines changed

1 file changed

+8
-20
lines changed

test/test_ops.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,8 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
135135
return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois)
136136

137137
def get_script_fn(self, rois, pool_size):
138-
@torch.jit.script
139-
def script_fn(input, rois, pool_size):
140-
# type: (Tensor, Tensor, int) -> Tensor
141-
return ops.roi_pool(input, rois, pool_size, 1.0)[0]
142-
return lambda x: script_fn(x, rois, pool_size)
138+
scriped = torch.jit.script(ops.roi_pool)
139+
return lambda x: scriped(x, rois, pool_size)
143140

144141
def expected_fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1,
145142
device=None, dtype=torch.float64):
@@ -177,11 +174,8 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
177174
return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)
178175

179176
def get_script_fn(self, rois, pool_size):
180-
@torch.jit.script
181-
def script_fn(input, rois, pool_size):
182-
# type: (Tensor, Tensor, int) -> Tensor
183-
return ops.ps_roi_pool(input, rois, pool_size, 1.0)[0]
184-
return lambda x: script_fn(x, rois, pool_size)
177+
scriped = torch.jit.script(ops.ps_roi_pool)
178+
return lambda x: scriped(x, rois, pool_size)
185179

186180
def expected_fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1,
187181
device=None, dtype=torch.float64):
@@ -257,11 +251,8 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligne
257251
sampling_ratio=sampling_ratio, aligned=aligned)(x, rois)
258252

259253
def get_script_fn(self, rois, pool_size):
260-
@torch.jit.script
261-
def script_fn(input, rois, pool_size):
262-
# type: (Tensor, Tensor, int) -> Tensor
263-
return ops.roi_align(input, rois, pool_size, 1.0)[0]
264-
return lambda x: script_fn(x, rois, pool_size)
254+
scriped = torch.jit.script(ops.roi_align)
255+
return lambda x: scriped(x, rois, pool_size)
265256

266257
def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False,
267258
device=None, dtype=torch.float64):
@@ -311,11 +302,8 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
311302
sampling_ratio=sampling_ratio)(x, rois)
312303

313304
def get_script_fn(self, rois, pool_size):
314-
@torch.jit.script
315-
def script_fn(input, rois, pool_size):
316-
# type: (Tensor, Tensor, int) -> Tensor
317-
return ops.ps_roi_align(input, rois, pool_size, 1.0)[0]
318-
return lambda x: script_fn(x, rois, pool_size)
305+
scriped = torch.jit.script(ops.ps_roi_align)
306+
return lambda x: scriped(x, rois, pool_size)
319307

320308
def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1,
321309
sampling_ratio=-1, dtype=torch.float64):

0 commit comments

Comments
 (0)