|
4 | 4 | from torch.testing._internal.common_utils import run_tests |
5 | 5 | from torch_tensorrt import Input |
6 | 6 | from parameterized import parameterized |
7 | | -from .harness import DispatchTestCase |
| 7 | +from harness import DispatchTestCase |
8 | 8 |
|
9 | 9 | class TestGridConverter(DispatchTestCase): |
10 | 10 | @parameterized.expand( |
11 | 11 | [ |
12 | | - ("input_grid_interpolation_nearest_sample_fill", [5,5], [5,2], 0, 0), |
13 | | - ("input_grid_interpolation_nearest_sample_clamp", [5,5], [5,2], 0, 1), |
14 | | - ("input_grid_interpolation_nearest_sample_reflect", [5,5], [5,2], 0, 2), |
15 | | - ("input_grid_interpolation_linear_sample_fill", [5,5], [5,2], 1, 0), |
16 | | - ("input_grid_interpolation_linear_sample_clamp", [5,5], [5,2], 1, 1), |
17 | | - ("input_grid_interpolation_linear_sample_reflect", [5,5], [5,2], 1, 2), |
18 | | - ("input_grid_interpolation_cubic_sample_fill", [5,5], [5,2], 2, 0), |
19 | | - ("input_grid_interpolation_cubic_sample_clamp", [5,5], [5,2], 2, 1), |
20 | | - ("input_grid_interpolation_cubic_sample_reflect", [5,5], [5,2], 2, 2), |
| 12 | + ("input_grid_interpolation_nearest_sample_fill", [1,1,5,5], [1,5,2,2], 0, 0), |
| 13 | + ("input_grid_interpolation_nearest_sample_clamp", [1,1,5,5], [1,5,2,2], 0, 1), |
| 14 | + ("input_grid_interpolation_nearest_sample_reflect", [1,1,5,5], [1,5,2,2], 0, 2), |
| 15 | + ("input_grid_interpolation_linear_sample_fill", [1,1,5,5], [1,5,2,2], 1, 0), |
| 16 | + ("input_grid_interpolation_linear_sample_clamp", [1,1,5,5], [1,5,2,2], 1, 1), |
| 17 | + ("input_grid_interpolation_linear_sample_reflect", [1,1,5,5], [1,5,2,2], 1, 2), |
| 18 | + ("input_grid_interpolation_cubic_sample_fill", [1,1,5,5], [1,5,2,2], 2, 0), |
| 19 | + ("input_grid_interpolation_cubic_sample_clamp", [1,1,5,5], [1,5,2,2], 2, 1), |
| 20 | + ("input_grid_interpolation_cubic_sample_reflect", [1,1,5,5], [1,5,2,2], 2, 2), |
21 | 21 | ] |
22 | 22 | ) |
23 | | - def test_grid(self,_, input_shape, dim_shape, interpolation, sample): |
| 23 | + def test_grid(self, _, input_shape, dim_shape, interpolation, sample): |
24 | 24 | class TestModule(nn.Module): |
25 | | - def forward(self, x): |
26 | | - input = torch.randn(10).reshape(input_shape) |
27 | | - grid = torch.randint(-1, 1, dim_shape) |
28 | | - return nn.functional.grid(input, grid, interpolation, sample) |
29 | | - |
30 | | - inputs = [torch.randn(1, 10)] |
31 | | - self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out}) |
| 25 | + def forward(self, x): |
| 26 | + grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32) |
| 27 | + return torch.ops.aten.grid_sampler(x, grid, interpolation, sample, True) |
| 28 | + inputs = [torch.randn(input_shape, dtype = torch.float32)] |
| 29 | + self.run_test(TestModule(), inputs) |
32 | 30 |
|
| 31 | +if __name__ == "__main__": |
| 32 | + run_tests() |
33 | 33 |
|
34 | 34 |
|
35 | 35 |
|
|
0 commit comments