
Commit bc2de23

swolchok authored and YIWENX14 committed
Support Half/BFloat16 in native_batch_norm (#7842)
1 parent b6c04f6 commit bc2de23

2 files changed (+180, -126 lines)

kernels/portable/cpu/op_native_batch_norm.cpp

Lines changed: 6 additions & 4 deletions
@@ -104,7 +104,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_training_out(

   constexpr auto name = "native_batch_norm_legit_no_training.out";

-  ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
+  ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
     const CTYPE* in_data = in.const_data_ptr<CTYPE>();
     CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

@@ -261,7 +261,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_stats_out(

   constexpr auto name = "_native_batch_norm_legit.no_stats_out";

-  ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
+  ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
     const CTYPE* in_data = in.const_data_ptr<CTYPE>();
     CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
     CTYPE* mean_data = mean_out.mutable_data_ptr<CTYPE>();
@@ -282,10 +282,12 @@ std::tuple<Tensor&, Tensor&, Tensor&> _native_batch_norm_legit_no_stats_out(
     }

     // Compute mean and invstd for each channel
+    const CTYPE elements_per_channel_ct =
+        static_cast<CTYPE>(elements_per_channel);
     for (size_t c = 0; c < C; ++c) {
-      CTYPE mean = mean_data[c] / elements_per_channel;
+      CTYPE mean = mean_data[c] / elements_per_channel_ct;
       // Var[x] = E[x^2] - E[x]^2
-      CTYPE var = invstd_data[c] / elements_per_channel - mean * mean;
+      CTYPE var = invstd_data[c] / elements_per_channel_ct - mean * mean;
       CTYPE invstd = 1.0 / std::sqrt(var + eps);
       mean_data[c] = mean;
       invstd_data[c] = invstd;
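The new elements_per_channel_ct local appears to exist because, once CTYPE can be Half or BFloat16, it is cleaner to convert the integral element count to CTYPE once than to divide a reduced-precision value by a size_t on every iteration. A standalone sketch of the finalization step this hunk touches, with float standing in for the dispatched CTYPE and the function and container names invented for the example:

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch of the per-channel mean/invstd finalization: on entry, mean_data
// holds sum(x) and invstd_data holds sum(x^2) for each channel.
template <typename CTYPE>
void finalize_mean_invstd(
    std::vector<CTYPE>& mean_data,
    std::vector<CTYPE>& invstd_data,
    std::size_t elements_per_channel,
    double eps) {
  // Convert the element count once so every division is CTYPE / CTYPE.
  const CTYPE elements_per_channel_ct =
      static_cast<CTYPE>(elements_per_channel);
  for (std::size_t c = 0; c < mean_data.size(); ++c) {
    CTYPE mean = mean_data[c] / elements_per_channel_ct;
    // Var[x] = E[x^2] - E[x]^2
    CTYPE var = invstd_data[c] / elements_per_channel_ct - mean * mean;
    CTYPE invstd = 1.0 / std::sqrt(var + eps);
    mean_data[c] = mean;
    invstd_data[c] = invstd;
  }
}

int main() {
  // Two channels, four elements each: per-channel sums and sums of squares.
  std::vector<float> sums = {10.0f, 20.0f};      // sum(x)
  std::vector<float> sq_sums = {30.0f, 120.0f};  // sum(x^2)
  finalize_mean_invstd(sums, sq_sums, 4, 1e-5);
  return 0;
}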

kernels/test/op_native_batch_norm_test.cpp

Lines changed: 174 additions & 122 deletions
@@ -44,6 +44,112 @@ class OpNativeBatchNormLegitNoTrainingOutTest : public OperatorTest {
         out1,
         out2);
   }
+
+  template <exec_aten::ScalarType DTYPE>
+  void test_2d_dtype() {
+    torch::executor::testing::TensorFactory<DTYPE> tf;
+
+    exec_aten::Tensor input = tf.make(
+        {4, 7}, {2.876736640930176, 7.67944860458374, 5.701690196990967,
+                 9.299789428710938, 3.023690700531006, 5.315116882324219,
+                 7.185585021972656, 6.911304473876953, 7.61051082611084,
+                 1.4963287115097046, 0.7381612062454224, 8.588483810424805,
+                 6.583977699279785, 8.831110000610352, 0.8165055513381958,
+                 7.087201118469238, 5.572513580322266, 4.446897983551025,
+                 4.444573402404785, 6.254056930541992, 5.906398296356201,
+                 9.971039772033691, 3.5423521995544434, 7.452159881591797,
+                 9.93700122833252, 1.8560808897018433, 1.524025797843933,
+                 7.3222975730896});
+    exec_aten::optional<exec_aten::Tensor> weight =
+        exec_aten::optional<exec_aten::Tensor>(tf.make(
+            {7},
+            {8.287437438964844,
+             8.227645874023438,
+             6.65926456451416,
+             9.436124801635742,
+             4.119281768798828,
+             8.593960762023926,
+             2.3760855197906494}));
+    exec_aten::optional<exec_aten::Tensor> bias =
+        exec_aten::optional<exec_aten::Tensor>(tf.make(
+            {7},
+            {7.824275970458984,
+             6.84327507019043,
+             8.354326248168945,
+             8.773970603942871,
+             3.89609694480896,
+             3.0753469467163086,
+             3.1105971336364746}));
+    exec_aten::Tensor running_mean = tf.make(
+        {7},
+        {9.700226783752441,
+         0.1234668493270874,
+         7.527220249176025,
+         8.993252754211426,
+         0.4736626148223877,
+         7.7135701179504395,
+         5.12320613861084});
+    exec_aten::Tensor running_var = tf.make(
+        {7},
+        {3.585531234741211,
+         6.615292549133301,
+         0.24084866046905518,
+         5.175800323486328,
+         0.5886000394821167,
+         6.23909854888916,
+         1.5029621124267578});
+    double momentum = 0.1;
+    double eps = 0;
+    exec_aten::Tensor out0 = tf.zeros({4, 7});
+    exec_aten::Tensor out1 = tf.zeros({0});
+    exec_aten::Tensor out2 = tf.zeros({0});
+    exec_aten::Tensor out0_expected = tf.make(
+        {4, 7}, {-22.039867401123047, 31.014127731323242, -16.416650772094727,
+                 10.04538631439209, 17.5877628326416, -5.17673921585083,
+                 7.1078033447265625, -4.381907939910889, 30.793603897094727,
+                 -73.48003387451172, -25.46548080444336, 47.46636962890625,
+                 -0.8111140131950378, 10.29708194732666, -31.056814193725586,
+                 29.119586944580078, -18.16947364807129, -10.082839965820312,
+                 25.216796875, -1.9462348222732544, 4.628543376922607,
+                 9.00953483581543, 17.779958724975586, 7.335818767547607,
+                 12.688335418701172, 11.318607330322266, -18.22031593322754,
+                 7.372773170471191});
+    exec_aten::Tensor out1_expected = tf.make({0}, {});
+    exec_aten::Tensor out2_expected = tf.make({0}, {});
+    op_native_batch_norm_legit_no_training_out(
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        momentum,
+        eps,
+        out0,
+        out1,
+        out2);
+    if (DTYPE == exec_aten::ScalarType::Half ||
+        DTYPE == exec_aten::ScalarType::BFloat16) {
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out0,
+          out0_expected,
+          4e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out1,
+          out1_expected,
+          2e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out2,
+          out2_expected,
+          2e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+    } else {
+      EXPECT_TENSOR_CLOSE(out0, out0_expected);
+      EXPECT_TENSOR_CLOSE(out1, out1_expected);
+      EXPECT_TENSOR_CLOSE(out2, out2_expected);
+    }
+  }
 };
 
 class OpNativeBatchNormLegitOutTest : public OperatorTest {
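The relaxed relative tolerances for Half and BFloat16 are in line with how coarse those formats are: BFloat16 keeps only 8 significand bits, so values near 22 are spaced 0.125 apart and a single rounding step already costs a couple of tenths of a percent before any accumulation error inside the kernel. A small hypothetical helper (not part of the commit) that rounds a float to the nearest BFloat16 shows the effect on the first expected output value:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Round a float to the nearest bfloat16 by keeping the top 16 bits
// (simplified round-half-up; ignores NaN/Inf corner cases).
static float round_to_bf16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 0x8000u;
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  const float expected = -22.039867401123047f;  // first value of out0_expected
  const float as_bf16 = round_to_bf16(expected);
  // Prints roughly -22.0399 -> -22, i.e. about 0.18% relative error from
  // storage alone, comfortably inside the 4e-2 rtol used for Half/BFloat16.
  std::printf("%.6f -> %.6f (rel err %.4f)\n",
              expected, as_bf16, (as_bf16 - expected) / expected);
  return 0;
}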
@@ -103,92 +209,72 @@ class OpNativeBatchNormLegitNoStatsOutTest : public OperatorTest {
         out1,
         out2);
   }
-};
 
-TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest2D) {
-  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
+  template <exec_aten::ScalarType DTYPE>
+  void test_2d_dtype() {
+    torch::executor::testing::TensorFactory<DTYPE> tf;
 
-  exec_aten::Tensor input = tfFloat.make(
-      {4, 7}, {2.876736640930176, 7.67944860458374, 5.701690196990967,
-               9.299789428710938, 3.023690700531006, 5.315116882324219,
-               7.185585021972656, 6.911304473876953, 7.61051082611084,
-               1.4963287115097046, 0.7381612062454224, 8.588483810424805,
-               6.583977699279785, 8.831110000610352, 0.8165055513381958,
-               7.087201118469238, 5.572513580322266, 4.446897983551025,
-               4.444573402404785, 6.254056930541992, 5.906398296356201,
-               9.971039772033691, 3.5423521995544434, 7.452159881591797,
-               9.93700122833252, 1.8560808897018433, 1.524025797843933,
-               7.3222975730896});
-  exec_aten::optional<exec_aten::Tensor> weight =
-      exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
-          {7},
-          {8.287437438964844,
-           8.227645874023438,
-           6.65926456451416,
-           9.436124801635742,
-           4.119281768798828,
-           8.593960762023926,
-           2.3760855197906494}));
-  exec_aten::optional<exec_aten::Tensor> bias =
-      exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
-          {7},
-          {7.824275970458984,
-           6.84327507019043,
-           8.354326248168945,
-           8.773970603942871,
-           3.89609694480896,
-           3.0753469467163086,
-           3.1105971336364746}));
-  exec_aten::Tensor running_mean = tfFloat.make(
-      {7},
-      {9.700226783752441,
-       0.1234668493270874,
-       7.527220249176025,
-       8.993252754211426,
-       0.4736626148223877,
-       7.7135701179504395,
-       5.12320613861084});
-  exec_aten::Tensor running_var = tfFloat.make(
-      {7},
-      {3.585531234741211,
-       6.615292549133301,
-       0.24084866046905518,
-       5.175800323486328,
-       0.5886000394821167,
-       6.23909854888916,
-       1.5029621124267578});
-  double momentum = 0.1;
-  double eps = 0;
-  exec_aten::Tensor out0 = tfFloat.zeros({4, 7});
-  exec_aten::Tensor out1 = tfFloat.zeros({0});
-  exec_aten::Tensor out2 = tfFloat.zeros({0});
-  exec_aten::Tensor out0_expected = tfFloat.make(
-      {4, 7}, {-22.039867401123047, 31.014127731323242, -16.416650772094727,
-               10.04538631439209, 17.5877628326416, -5.17673921585083,
-               7.1078033447265625, -4.381907939910889, 30.793603897094727,
-               -73.48003387451172, -25.46548080444336, 47.46636962890625,
-               -0.8111140131950378, 10.29708194732666, -31.056814193725586,
-               29.119586944580078, -18.16947364807129, -10.082839965820312,
-               25.216796875, -1.9462348222732544, 4.628543376922607,
-               9.00953483581543, 17.779958724975586, 7.335818767547607,
-               12.688335418701172, 11.318607330322266, -18.22031593322754,
-               7.372773170471191});
-  exec_aten::Tensor out1_expected = tfFloat.make({0}, {});
-  exec_aten::Tensor out2_expected = tfFloat.make({0}, {});
-  op_native_batch_norm_legit_no_training_out(
-      input,
-      weight,
-      bias,
-      running_mean,
-      running_var,
-      momentum,
-      eps,
-      out0,
-      out1,
-      out2);
-  EXPECT_TENSOR_CLOSE(out0, out0_expected);
-  EXPECT_TENSOR_CLOSE(out1, out1_expected);
-  EXPECT_TENSOR_CLOSE(out2, out2_expected);
+    exec_aten::Tensor input =
+        tf.make({3, 4}, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
+    exec_aten::optional<exec_aten::Tensor> weight =
+        exec_aten::optional<exec_aten::Tensor>();
+    exec_aten::optional<exec_aten::Tensor> bias =
+        exec_aten::optional<exec_aten::Tensor>();
+    bool training = true;
+    double momentum = 1e-3;
+    double eps = 1e-5;
+    exec_aten::Tensor out0 = tf.zeros({3, 4});
+    exec_aten::Tensor out1 = tf.zeros({4});
+    exec_aten::Tensor out2 = tf.zeros({4});
+    exec_aten::Tensor out0_expected = tf.make(
+        {3, 4},
+        {-0.98058063,
+         -1.03422451,
+         -1.06904495,
+         -1.09332705,
+         -0.39223224,
+         -0.31822300,
+         -0.26726127,
+         -0.23017406,
+         1.37281299,
+         1.35244739,
+         1.33630610,
+         1.32350123});
+    exec_aten::Tensor out1_expected =
+        tf.make({4}, {26.66666603, 35.66666794, 46.66666794, 59.66666794});
+    exec_aten::Tensor out2_expected =
+        tf.make({4}, {0.03677177, 0.02983340, 0.02505574, 0.02157882});
+    op_native_batch_norm_legit_no_stats_out(
+        input, weight, bias, training, momentum, eps, out0, out1, out2);
+    if (DTYPE == exec_aten::ScalarType::Half ||
+        DTYPE == exec_aten::ScalarType::BFloat16) {
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out0,
+          out0_expected,
+          2e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out1,
+          out1_expected,
+          1e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out2,
+          out2_expected,
+          2e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+    } else {
+      EXPECT_TENSOR_CLOSE(out0, out0_expected);
+      EXPECT_TENSOR_CLOSE(out1, out1_expected);
+      EXPECT_TENSOR_CLOSE(out2, out2_expected);
+    }
+  }
+};
+
+TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest2D){
+#define TEST_ENTRY(ctype, dtype) test_2d_dtype<exec_aten::ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY)
+#undef TEST_ENTRY
 }
 
 TEST_F(OpNativeBatchNormLegitNoTrainingOutTest, SampleAtomicTest3D) {
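Both rewritten TEST_F bodies use ExecuTorch's X-macro pattern: TEST_ENTRY instantiates the templated helper for one dtype (it receives both the C++ type and the ScalarType name but only uses the latter), and ET_FORALL_FLOATHBF16_TYPES applies it to each covered floating-point type. Assuming that macro covers Float, Double, Half, and BFloat16, as its name suggests, the preprocessed test body is roughly:

// Approximate expansion of the macro-driven TEST_F body above (sketch only).
test_2d_dtype<exec_aten::ScalarType::Float>();
test_2d_dtype<exec_aten::ScalarType::Double>();
test_2d_dtype<exec_aten::ScalarType::Half>();
test_2d_dtype<exec_aten::ScalarType::BFloat16>();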
@@ -977,44 +1063,10 @@ TEST_F(OpNativeBatchNormLegitOutTest, SampleAtomicTest2D) {
   EXPECT_TENSOR_CLOSE(out2, out2_expected);
 }
 
-TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest2D) {
-  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
-
-  exec_aten::Tensor input =
-      tfFloat.make({3, 4}, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
-  exec_aten::optional<exec_aten::Tensor> weight =
-      exec_aten::optional<exec_aten::Tensor>();
-  exec_aten::optional<exec_aten::Tensor> bias =
-      exec_aten::optional<exec_aten::Tensor>();
-  bool training = true;
-  double momentum = 1e-3;
-  double eps = 1e-5;
-  exec_aten::Tensor out0 = tfFloat.zeros({3, 4});
-  exec_aten::Tensor out1 = tfFloat.zeros({4});
-  exec_aten::Tensor out2 = tfFloat.zeros({4});
-  exec_aten::Tensor out0_expected = tfFloat.make(
-      {3, 4},
-      {-0.98058063,
-       -1.03422451,
-       -1.06904495,
-       -1.09332705,
-       -0.39223224,
-       -0.31822300,
-       -0.26726127,
-       -0.23017406,
-       1.37281299,
-       1.35244739,
-       1.33630610,
-       1.32350123});
-  exec_aten::Tensor out1_expected =
-      tfFloat.make({4}, {26.66666603, 35.66666794, 46.66666794, 59.66666794});
-  exec_aten::Tensor out2_expected =
-      tfFloat.make({4}, {0.03677177, 0.02983340, 0.02505574, 0.02157882});
-  op_native_batch_norm_legit_no_stats_out(
-      input, weight, bias, training, momentum, eps, out0, out1, out2);
-  EXPECT_TENSOR_CLOSE(out0, out0_expected);
-  EXPECT_TENSOR_CLOSE(out1, out1_expected);
-  EXPECT_TENSOR_CLOSE(out2, out2_expected);
+TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest2D){
+#define TEST_ENTRY(ctype, dtype) test_2d_dtype<exec_aten::ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY)
+#undef TEST_ENTRY
 }
 
 TEST_F(OpNativeBatchNormLegitNoStatsOutTest, SampleAtomicTest3D) {
