@@ -27,66 +27,88 @@ class OpSplitWithSizesCopyOutTest : public OperatorTest {
27
27
return torch::executor::aten::split_with_sizes_copy_outf (
28
28
context_, self, split_sizes, dim, out);
29
29
}
30
+
31
+ void test_tensor_shape_dynamism (exec_aten::TensorShapeDynamism dynamism) {
32
+ torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float>
33
+ tfFloat;
34
+
35
+ exec_aten::Tensor self = tfFloat.make (
36
+ {2 , 6 , 3 },
37
+ {-31.25 , -92.75 , -39.75 , -3.25 , 53.875 , 88.25 , -0.625 , -1.125 ,
38
+ 14.75 , 42.0 , 89.875 , -21.125 , -8.0 , -64.125 , 23.0 , 37.0 ,
39
+ 46.125 , -83.25 , -58.125 , 19.625 , -71.125 , 64.75 , -1.375 , -83.5 ,
40
+ -61.375 , 13.125 , 28.625 , -94.0 , -67.0 , -8.625 , -88.875 , -79.125 ,
41
+ 0.375 , -61.375 , 65.0 , -99.375 });
42
+ ::std::vector<int64_t > split_sizes_vec = {3 , 1 , 2 };
43
+ exec_aten::ArrayRef<int64_t > split_sizes = exec_aten::ArrayRef<int64_t >(
44
+ split_sizes_vec.data (), split_sizes_vec.size ());
45
+ int64_t dim = 1 ;
46
+
47
+ ::std::vector<exec_aten::Tensor> out_vec;
48
+ if (dynamism == exec_aten::TensorShapeDynamism::STATIC) {
49
+ out_vec = {
50
+ tfFloat.zeros ({2 , 3 , 3 }),
51
+ tfFloat.zeros ({2 , 1 , 3 }),
52
+ tfFloat.zeros ({2 , 2 , 3 })};
53
+ } else { // dynamism == exec_aten::TensorShapeDynamism::DYNAMIC_BOUND
54
+ out_vec = {
55
+ tfFloat.zeros (
56
+ {2 , 3 , 10 }, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND),
57
+ tfFloat.zeros (
58
+ {2 , 1 , 10 }, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND),
59
+ tfFloat.zeros (
60
+ {2 , 2 , 10 }, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND)};
61
+ }
62
+
63
+ exec_aten::TensorList out =
64
+ exec_aten::TensorList (out_vec.data (), out_vec.size ());
65
+ ::std::vector<exec_aten::Tensor> out_expected_vec = {
66
+ tfFloat.make (
67
+ {2 , 3 , 3 },
68
+ {-31.25 ,
69
+ -92.75 ,
70
+ -39.75 ,
71
+ -3.25 ,
72
+ 53.875 ,
73
+ 88.25 ,
74
+ -0.625 ,
75
+ -1.125 ,
76
+ 14.75 ,
77
+ -58.125 ,
78
+ 19.625 ,
79
+ -71.125 ,
80
+ 64.75 ,
81
+ -1.375 ,
82
+ -83.5 ,
83
+ -61.375 ,
84
+ 13.125 ,
85
+ 28.625 }),
86
+ tfFloat.make ({2 , 1 , 3 }, {42.0 , 89.875 , -21.125 , -94.0 , -67.0 , -8.625 }),
87
+ tfFloat.make (
88
+ {2 , 2 , 3 },
89
+ {-8.0 ,
90
+ -64.125 ,
91
+ 23.0 ,
92
+ 37.0 ,
93
+ 46.125 ,
94
+ -83.25 ,
95
+ -88.875 ,
96
+ -79.125 ,
97
+ 0.375 ,
98
+ -61.375 ,
99
+ 65.0 ,
100
+ -99.375 })};
101
+ exec_aten::TensorList out_expected =
102
+ exec_aten::TensorList (out_expected_vec.data (), out_expected_vec.size ());
103
+ op_split_with_sizes_copy_out (self, split_sizes, dim, out);
104
+ EXPECT_TENSOR_LISTS_CLOSE (out, out_expected);
105
+ }
30
106
};
31
107
32
108
TEST_F (OpSplitWithSizesCopyOutTest, SanityCheckDim1) {
33
- torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
109
+ test_tensor_shape_dynamism (exec_aten::TensorShapeDynamism::STATIC);
110
+ }
34
111
35
- exec_aten::Tensor self = tfFloat.make (
36
- {2 , 6 , 3 },
37
- {-31.25 , -92.75 , -39.75 , -3.25 , 53.875 , 88.25 , -0.625 , -1.125 ,
38
- 14.75 , 42.0 , 89.875 , -21.125 , -8.0 , -64.125 , 23.0 , 37.0 ,
39
- 46.125 , -83.25 , -58.125 , 19.625 , -71.125 , 64.75 , -1.375 , -83.5 ,
40
- -61.375 , 13.125 , 28.625 , -94.0 , -67.0 , -8.625 , -88.875 , -79.125 ,
41
- 0.375 , -61.375 , 65.0 , -99.375 });
42
- ::std::vector<int64_t > split_sizes_vec = {3 , 1 , 2 };
43
- exec_aten::ArrayRef<int64_t > split_sizes = exec_aten::ArrayRef<int64_t >(
44
- split_sizes_vec.data (), split_sizes_vec.size ());
45
- int64_t dim = 1 ;
46
- ::std::vector<exec_aten::Tensor> out_vec = {
47
- tfFloat.zeros ({2 , 3 , 3 }),
48
- tfFloat.zeros ({2 , 1 , 3 }),
49
- tfFloat.zeros ({2 , 2 , 3 })};
50
- exec_aten::TensorList out =
51
- exec_aten::TensorList (out_vec.data (), out_vec.size ());
52
- ::std::vector<exec_aten::Tensor> out_expected_vec = {
53
- tfFloat.make (
54
- {2 , 3 , 3 },
55
- {-31.25 ,
56
- -92.75 ,
57
- -39.75 ,
58
- -3.25 ,
59
- 53.875 ,
60
- 88.25 ,
61
- -0.625 ,
62
- -1.125 ,
63
- 14.75 ,
64
- -58.125 ,
65
- 19.625 ,
66
- -71.125 ,
67
- 64.75 ,
68
- -1.375 ,
69
- -83.5 ,
70
- -61.375 ,
71
- 13.125 ,
72
- 28.625 }),
73
- tfFloat.make ({2 , 1 , 3 }, {42.0 , 89.875 , -21.125 , -94.0 , -67.0 , -8.625 }),
74
- tfFloat.make (
75
- {2 , 2 , 3 },
76
- {-8.0 ,
77
- -64.125 ,
78
- 23.0 ,
79
- 37.0 ,
80
- 46.125 ,
81
- -83.25 ,
82
- -88.875 ,
83
- -79.125 ,
84
- 0.375 ,
85
- -61.375 ,
86
- 65.0 ,
87
- -99.375 })};
88
- exec_aten::TensorList out_expected =
89
- exec_aten::TensorList (out_expected_vec.data (), out_expected_vec.size ());
90
- op_split_with_sizes_copy_out (self, split_sizes, dim, out);
91
- EXPECT_TENSOR_LISTS_CLOSE (out, out_expected);
112
+ TEST_F (OpSplitWithSizesCopyOutTest, DynamicShape) {
113
+ test_tensor_shape_dynamism (exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
92
114
}
0 commit comments