Skip to content

Commit a520ec8

Browse files
Gasoonjiadbort
authored andcommitted
make op_split_with_sizes_copy support dynamic shape (#3152)
Summary: Pull Request resolved: #3152 as title Reviewed By: SS-JIA Differential Revision: D56333587 fbshipit-source-id: deecbb2a394257dc146dd1af50cc0e7158ac79ed (cherry picked from commit bd07c75)
1 parent 6a1703e commit a520ec8

File tree

2 files changed

+81
-60
lines changed

2 files changed

+81
-60
lines changed

kernels/portable/cpu/op_split_with_sizes_copy.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@ void split_with_sizes_copy_out(
5555
target_out_sizes[dim] = static_cast<Tensor::SizesType>(split_sizes[i]);
5656
ET_KERNEL_CHECK(
5757
ctx,
58-
tensor_is_broadcastable_to(
59-
{target_out_sizes, target_out_ndim}, out[i].sizes()),
58+
resize_tensor(out[i], {target_out_sizes, target_out_ndim}) == Error::Ok,
6059
InvalidArgument, );
6160
}
6261

kernels/test/op_split_with_sizes_copy_test.cpp

Lines changed: 80 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -27,66 +27,88 @@ class OpSplitWithSizesCopyOutTest : public OperatorTest {
2727
return torch::executor::aten::split_with_sizes_copy_outf(
2828
context_, self, split_sizes, dim, out);
2929
}
30+
31+
void test_tensor_shape_dynamism(exec_aten::TensorShapeDynamism dynamism) {
32+
torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float>
33+
tfFloat;
34+
35+
exec_aten::Tensor self = tfFloat.make(
36+
{2, 6, 3},
37+
{-31.25, -92.75, -39.75, -3.25, 53.875, 88.25, -0.625, -1.125,
38+
14.75, 42.0, 89.875, -21.125, -8.0, -64.125, 23.0, 37.0,
39+
46.125, -83.25, -58.125, 19.625, -71.125, 64.75, -1.375, -83.5,
40+
-61.375, 13.125, 28.625, -94.0, -67.0, -8.625, -88.875, -79.125,
41+
0.375, -61.375, 65.0, -99.375});
42+
::std::vector<int64_t> split_sizes_vec = {3, 1, 2};
43+
exec_aten::ArrayRef<int64_t> split_sizes = exec_aten::ArrayRef<int64_t>(
44+
split_sizes_vec.data(), split_sizes_vec.size());
45+
int64_t dim = 1;
46+
47+
::std::vector<exec_aten::Tensor> out_vec;
48+
if (dynamism == exec_aten::TensorShapeDynamism::STATIC) {
49+
out_vec = {
50+
tfFloat.zeros({2, 3, 3}),
51+
tfFloat.zeros({2, 1, 3}),
52+
tfFloat.zeros({2, 2, 3})};
53+
} else { // dynamism == exec_aten::TensorShapeDynamism::DYNAMIC_BOUND
54+
out_vec = {
55+
tfFloat.zeros(
56+
{2, 3, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND),
57+
tfFloat.zeros(
58+
{2, 1, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND),
59+
tfFloat.zeros(
60+
{2, 2, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND)};
61+
}
62+
63+
exec_aten::TensorList out =
64+
exec_aten::TensorList(out_vec.data(), out_vec.size());
65+
::std::vector<exec_aten::Tensor> out_expected_vec = {
66+
tfFloat.make(
67+
{2, 3, 3},
68+
{-31.25,
69+
-92.75,
70+
-39.75,
71+
-3.25,
72+
53.875,
73+
88.25,
74+
-0.625,
75+
-1.125,
76+
14.75,
77+
-58.125,
78+
19.625,
79+
-71.125,
80+
64.75,
81+
-1.375,
82+
-83.5,
83+
-61.375,
84+
13.125,
85+
28.625}),
86+
tfFloat.make({2, 1, 3}, {42.0, 89.875, -21.125, -94.0, -67.0, -8.625}),
87+
tfFloat.make(
88+
{2, 2, 3},
89+
{-8.0,
90+
-64.125,
91+
23.0,
92+
37.0,
93+
46.125,
94+
-83.25,
95+
-88.875,
96+
-79.125,
97+
0.375,
98+
-61.375,
99+
65.0,
100+
-99.375})};
101+
exec_aten::TensorList out_expected =
102+
exec_aten::TensorList(out_expected_vec.data(), out_expected_vec.size());
103+
op_split_with_sizes_copy_out(self, split_sizes, dim, out);
104+
EXPECT_TENSOR_LISTS_CLOSE(out, out_expected);
105+
}
30106
};
31107

32108
TEST_F(OpSplitWithSizesCopyOutTest, SanityCheckDim1) {
33-
torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
109+
test_tensor_shape_dynamism(exec_aten::TensorShapeDynamism::STATIC);
110+
}
34111

35-
exec_aten::Tensor self = tfFloat.make(
36-
{2, 6, 3},
37-
{-31.25, -92.75, -39.75, -3.25, 53.875, 88.25, -0.625, -1.125,
38-
14.75, 42.0, 89.875, -21.125, -8.0, -64.125, 23.0, 37.0,
39-
46.125, -83.25, -58.125, 19.625, -71.125, 64.75, -1.375, -83.5,
40-
-61.375, 13.125, 28.625, -94.0, -67.0, -8.625, -88.875, -79.125,
41-
0.375, -61.375, 65.0, -99.375});
42-
::std::vector<int64_t> split_sizes_vec = {3, 1, 2};
43-
exec_aten::ArrayRef<int64_t> split_sizes = exec_aten::ArrayRef<int64_t>(
44-
split_sizes_vec.data(), split_sizes_vec.size());
45-
int64_t dim = 1;
46-
::std::vector<exec_aten::Tensor> out_vec = {
47-
tfFloat.zeros({2, 3, 3}),
48-
tfFloat.zeros({2, 1, 3}),
49-
tfFloat.zeros({2, 2, 3})};
50-
exec_aten::TensorList out =
51-
exec_aten::TensorList(out_vec.data(), out_vec.size());
52-
::std::vector<exec_aten::Tensor> out_expected_vec = {
53-
tfFloat.make(
54-
{2, 3, 3},
55-
{-31.25,
56-
-92.75,
57-
-39.75,
58-
-3.25,
59-
53.875,
60-
88.25,
61-
-0.625,
62-
-1.125,
63-
14.75,
64-
-58.125,
65-
19.625,
66-
-71.125,
67-
64.75,
68-
-1.375,
69-
-83.5,
70-
-61.375,
71-
13.125,
72-
28.625}),
73-
tfFloat.make({2, 1, 3}, {42.0, 89.875, -21.125, -94.0, -67.0, -8.625}),
74-
tfFloat.make(
75-
{2, 2, 3},
76-
{-8.0,
77-
-64.125,
78-
23.0,
79-
37.0,
80-
46.125,
81-
-83.25,
82-
-88.875,
83-
-79.125,
84-
0.375,
85-
-61.375,
86-
65.0,
87-
-99.375})};
88-
exec_aten::TensorList out_expected =
89-
exec_aten::TensorList(out_expected_vec.data(), out_expected_vec.size());
90-
op_split_with_sizes_copy_out(self, split_sizes, dim, out);
91-
EXPECT_TENSOR_LISTS_CLOSE(out, out_expected);
112+
TEST_F(OpSplitWithSizesCopyOutTest, DynamicShape) {
113+
test_tensor_shape_dynamism(exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
92114
}

0 commit comments

Comments
 (0)