
Commit 53a6f75

Support Half/BFloat16 in native_group_norm
ghstack-source-id: 938abf9
ghstack-comment-id: 2608331848
Pull Request resolved: #7846
1 parent 6fe6870 commit 53a6f75

File tree

2 files changed: +142 additions, -111 deletions


kernels/portable/cpu/op_native_group_norm.cpp

Lines changed: 8 additions & 7 deletions
@@ -78,8 +78,9 @@ void group_norm(
     // compute E[X] and Var[x] = E[x^2] - E[x]^2
     CTYPE sum = reduce_add(x, inner_size);
     CTYPE sq_sum = vec_powerf(x, inner_size);
-    CTYPE mean_value = sum / inner_size;
-    CTYPE variance = sq_sum / inner_size - mean_value * mean_value;
+    CTYPE mean_value = sum / static_cast<CTYPE>(inner_size);
+    CTYPE variance =
+        sq_sum / static_cast<CTYPE>(inner_size) - mean_value * mean_value;
     CTYPE std = std::sqrt(variance + eps);
     CTYPE rstd_value = 1.0 / std;

@@ -93,10 +94,10 @@ void group_norm(
     const size_t g = i % G;
     for (size_t j = 0; j < D; j++) {
       const size_t ch = g * D + j;
-      const CTYPE scale =
-          rstd_value * (weight_data == nullptr ? 1.0 : weight_data[ch]);
-      const CTYPE beta =
-          -scale * mean_value + (bias_data == nullptr ? 0.0 : bias_data[ch]);
+      const CTYPE scale = rstd_value *
+          (weight_data == nullptr ? CTYPE(1.0) : weight_data[ch]);
+      const CTYPE beta = -scale * mean_value +
+          (bias_data == nullptr ? CTYPE(0.0) : bias_data[ch]);
       x = input_data + (i * D + j) * HxW;
       CTYPE* y = out_data + (i * D + j) * HxW;
       for (size_t k = 0; k < HxW; k++) {

@@ -185,7 +186,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_group_norm_out(

   constexpr auto name = "native_group_norm.out";

-  ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
+  ET_SWITCH_FLOATHBF16_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
     group_norm<CTYPE>(
         input, weight, bias, N, C, HxW, group, eps, out, mean_out, rstd_out);
   });
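The kernel computes per-group statistics with the identity Var[x] = E[x^2] - E[x]^2, and the changes above route every intermediate through CTYPE so the arithmetic stays well-formed when CTYPE is Half or BFloat16: dividing by a raw size_t, or mixing a double literal like 1.0 into a ternary with a 16-bit operand, would otherwise drag the computation through an unintended promotion (or fail to compile, depending on the type's conversion set). Below is a minimal standalone sketch of the same statistic, using plain float as a stand-in for CTYPE and the first four input values from the test file as a toy slice; it is an illustration, not the ExecuTorch kernel.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Per-group mean, variance via E[x^2] - E[x]^2, and the reciprocal standard
// deviation used to normalize each element. Names mirror the diff above.
int main() {
  std::vector<float> x = {-0.8125f, 0.0625f, -2.7500f, -3.0625f};
  const std::size_t inner_size = x.size();
  const float eps = 1e-5f;

  float sum = 0.0f;
  float sq_sum = 0.0f;
  for (float v : x) {
    sum += v;
    sq_sum += v * v;
  }
  // The casts mirror the patch: with a 16-bit CTYPE, dividing by a raw
  // size_t is the step that needed an explicit static_cast.
  const float mean_value = sum / static_cast<float>(inner_size);
  const float variance =
      sq_sum / static_cast<float>(inner_size) - mean_value * mean_value;
  const float rstd_value = 1.0f / std::sqrt(variance + eps);

  for (float v : x) {
    std::printf("%f\n", (v - mean_value) * rstd_value);
  }
  return 0;
}

The last hunk then widens the dtype dispatch: ET_SWITCH_FLOATHBF16_TYPES generates CTYPE instantiations for Half and BFloat16 in addition to the float types that ET_SWITCH_FLOAT_TYPES covered.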

kernels/test/op_native_group_norm_test.cpp

Lines changed: 134 additions & 104 deletions
@@ -20,110 +20,140 @@ using exec_aten::ScalarType;
 using exec_aten::Tensor;
 using torch::executor::testing::TensorFactory;

-::std::tuple<Tensor&, Tensor&, Tensor&> op_native_group_norm_out(
-    const Tensor& input,
-    const optional<Tensor>& weight,
-    const optional<Tensor>& bias,
-    int64_t N,
-    int64_t C,
-    int64_t HxW,
-    int64_t group,
-    double eps,
-    Tensor& out0,
-    Tensor& out1,
-    Tensor& out2) {
-  executorch::runtime::KernelRuntimeContext context{};
-  return torch::executor::aten::native_group_norm_outf(
-      context, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2);
-}
+class OpNativeGroupNormOutTest : public OperatorTest {
+ protected:
+  ::std::tuple<Tensor&, Tensor&, Tensor&> op_native_group_norm_out(
+      const Tensor& input,
+      const optional<Tensor>& weight,
+      const optional<Tensor>& bias,
+      int64_t N,
+      int64_t C,
+      int64_t HxW,
+      int64_t group,
+      double eps,
+      Tensor& out0,
+      Tensor& out1,
+      Tensor& out2) {
+    executorch::runtime::KernelRuntimeContext context{};
+    return torch::executor::aten::native_group_norm_outf(
+        context, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2);
+  }
+  template <ScalarType DTYPE>
+  void test_dtype() {
+    TensorFactory<DTYPE> tf;

-TEST(OpNativeGroupNormOutTest, SmokeTest) {
-  TensorFactory<ScalarType::Float> tfFloat;
+    Tensor input = tf.make(
+        {5, 6, 2, 2},
+        {-0.8125, 0.0625, -2.7500, -3.0625, -1.1250, -2.1250, -1.3125,
+         -4.0625, 2.8125, -2.0625, 4.2500, 3.5000, -0.3750, 1.6250,
+         4.3125, -1.0625, -2.8750, 3.3750, 4.9375, 4.0625, -3.0625,
+         -1.8750, -2.7500, -2.5625, -0.1875, -3.0000, -2.7500, 0.6875,
+         -3.2500, -3.1875, 1.0000, -4.6250, -0.1875, -1.7500, 4.5000,
+         -1.8750, -2.6875, 4.8125, -3.8125, -2.9375, -1.1875, 2.8750,
+         0.7500, 2.8750, 1.1250, -0.6250, -2.2500, -3.7500, 3.2500,
+         -0.3750, -2.0625, -4.7500, 2.0625, 3.0000, -3.1875, -4.1250,
+         -3.7500, 1.2500, -2.3125, 1.5625, 3.1250, 0.3125, 3.2500,
+         -2.7500, -3.8125, -4.2500, -4.3125, -0.5625, -0.4375, 2.9375,
+         -1.3750, -0.6250, -2.5625, -4.5625, 0.1250, -3.5000, -5.0000,
+         -1.0000, -4.6875, -0.6875, 1.1250, 1.8750, -4.5000, 4.3125,
+         4.5625, 0.2500, -3.6250, 4.5625, -3.5000, -2.1250, -3.6250,
+         -2.9375, 3.6875, 3.9375, 4.3750, 3.0625, 2.4375, 2.0625,
+         -2.4375, -3.9375, 3.6875, 2.7500, -0.8750, -0.9375, 2.7500,
+         -2.4375, -2.3750, -0.9375, -4.8750, 0.1875, 3.5000, -2.0000,
+         -0.2500, -2.7500, 0.3125, 1.2500, -0.5625, 0.0000, 1.8125,
+         1.0625});
+    optional<Tensor> weight =
+        tf.make({6}, {4.5625, -2.8750, -0.6875, 0.5625, -2.0625, -2.7500});
+    optional<Tensor> bias =
+        tf.make({6}, {-0.5000, -2.7500, 1.1875, 3.6875, 3.8125, 4.6875});
+    double eps = 1e-5;
+    Tensor out0 = tf.zeros({5, 6, 2, 2});
+    Tensor out1 = tf.zeros({5, 3});
+    Tensor out2 = tf.zeros({5, 3});
+    Tensor out0_expected = tf.make(
+        {5, 6, 2, 2},
+        {3.419882, 6.578348, -3.573864, -4.701888, -4.509254, -2.234663,
+         -4.082768, 2.172355, 0.838826, 2.270225, 0.416747, 0.636962,
+         3.207030, 3.687500, 4.333131, 3.041869, 5.547079, 1.649148,
+         0.674665, 1.220376, 7.156189, 6.168714, 6.896327, 6.740410,
+         3.509863, -3.022041, -2.441427, 5.542011, -0.794903, -0.886369,
+         -7.014627, 1.217361, 1.120617, 1.463606, 0.091652, 1.491045,
+         3.293219, 4.640229, 3.091168, 3.248319, 4.895990, 1.114683,
+         3.092597, 1.114683, 3.262238, 5.434066, 7.450763, 9.312329,
+         5.570122, 0.101119, -2.444796, -6.499403, -5.446074, -6.337338,
+         -0.454995, 0.436269, 2.228491, 0.871598, 1.838385, 0.786793,
+         4.362284, 3.737805, 4.390039, 3.057817, 5.814659, 6.202621,
+         6.258044, 2.932658, 3.366583, -0.623879, 4.475045, 3.588276,
+         -0.082914, -4.936279, 6.438795, -2.357929, 0.714463, -5.402106,
+         0.236606, -5.879963, 1.176247, 1.021916, 2.333727, 0.520341,
+         4.275447, 3.549392, 2.896994, 4.275447, 6.120910, 5.298480,
+         6.195676, 5.784461, 2.033296, 1.833920, 1.485010, 2.531738,
+         3.193988, 2.532378, -5.406940, -8.053379, -6.467402, -5.425139,
+         -1.395059, -1.325575, 0.266062, 1.622680, 1.606336, 1.230405,
+         2.809896, 3.893110, 4.601880, 3.425055, 4.374411, 8.283354,
+         3.494898, 2.029045, 6.088204, 4.915522, 1.136877, 2.700454});
+    Tensor out1_expected = tf.make(
+        {5, 3},
+        {-1.89843750,
+         1.62500000,
+         -0.09375000,
+         -1.91406250,
+         -0.49218744,
+         -0.02343750,
+         -0.77343756,
+         0.08593753,
+         -1.55468738,
+         -2.73437500,
+         1.07031238,
+         0.35937503,
+         0.34374997,
+         -0.77343750,
+         0.10937499});
+    Tensor out2_expected = tf.make(
+        {5, 3},
+        {0.79116172,
+         0.42708409,
+         0.30238494,
+         0.50903118,
+         0.31929117,
+         0.45128885,
+         0.33067191,
+         0.39473253,
+         0.42994878,
+         0.53187561,
+         0.29930803,
+         0.29000264,
+         0.38669431,
+         0.38038814,
+         0.75809801});
+    op_native_group_norm_out(
+        input, weight, bias, 5, 6, 4, 3, eps, out0, out1, out2);
+    if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) {
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out0,
+          out0_expected,
+          2e-1,
+          executorch::runtime::testing::internal::kDefaultAtol);
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out1,
+          out1_expected,
+          1e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out2,
+          out2_expected,
+          1e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+    } else {
+      EXPECT_TENSOR_CLOSE(out0, out0_expected);
+      EXPECT_TENSOR_CLOSE(out1, out1_expected);
+      EXPECT_TENSOR_CLOSE(out2, out2_expected);
+    }
+  }
+};

-  Tensor input = tfFloat.make(
-      {5, 6, 2, 2},
-      {-0.8125, 0.0625, -2.7500, -3.0625, -1.1250, -2.1250, -1.3125, -4.0625,
-       2.8125, -2.0625, 4.2500, 3.5000, -0.3750, 1.6250, 4.3125, -1.0625,
-       -2.8750, 3.3750, 4.9375, 4.0625, -3.0625, -1.8750, -2.7500, -2.5625,
-       -0.1875, -3.0000, -2.7500, 0.6875, -3.2500, -3.1875, 1.0000, -4.6250,
-       -0.1875, -1.7500, 4.5000, -1.8750, -2.6875, 4.8125, -3.8125, -2.9375,
-       -1.1875, 2.8750, 0.7500, 2.8750, 1.1250, -0.6250, -2.2500, -3.7500,
-       3.2500, -0.3750, -2.0625, -4.7500, 2.0625, 3.0000, -3.1875, -4.1250,
-       -3.7500, 1.2500, -2.3125, 1.5625, 3.1250, 0.3125, 3.2500, -2.7500,
-       -3.8125, -4.2500, -4.3125, -0.5625, -0.4375, 2.9375, -1.3750, -0.6250,
-       -2.5625, -4.5625, 0.1250, -3.5000, -5.0000, -1.0000, -4.6875, -0.6875,
-       1.1250, 1.8750, -4.5000, 4.3125, 4.5625, 0.2500, -3.6250, 4.5625,
-       -3.5000, -2.1250, -3.6250, -2.9375, 3.6875, 3.9375, 4.3750, 3.0625,
-       2.4375, 2.0625, -2.4375, -3.9375, 3.6875, 2.7500, -0.8750, -0.9375,
-       2.7500, -2.4375, -2.3750, -0.9375, -4.8750, 0.1875, 3.5000, -2.0000,
-       -0.2500, -2.7500, 0.3125, 1.2500, -0.5625, 0.0000, 1.8125, 1.0625});
-  optional<Tensor> weight =
-      tfFloat.make({6}, {4.5625, -2.8750, -0.6875, 0.5625, -2.0625, -2.7500});
-  optional<Tensor> bias =
-      tfFloat.make({6}, {-0.5000, -2.7500, 1.1875, 3.6875, 3.8125, 4.6875});
-  double eps = 1e-5;
-  Tensor out0 = tfFloat.zeros({5, 6, 2, 2});
-  Tensor out1 = tfFloat.zeros({5, 3});
-  Tensor out2 = tfFloat.zeros({5, 3});
-  Tensor out0_expected = tfFloat.make(
-      {5, 6, 2, 2},
-      {3.419882, 6.578348, -3.573864, -4.701888, -4.509254, -2.234663,
-       -4.082768, 2.172355, 0.838826, 2.270225, 0.416747, 0.636962,
-       3.207030, 3.687500, 4.333131, 3.041869, 5.547079, 1.649148,
-       0.674665, 1.220376, 7.156189, 6.168714, 6.896327, 6.740410,
-       3.509863, -3.022041, -2.441427, 5.542011, -0.794903, -0.886369,
-       -7.014627, 1.217361, 1.120617, 1.463606, 0.091652, 1.491045,
-       3.293219, 4.640229, 3.091168, 3.248319, 4.895990, 1.114683,
-       3.092597, 1.114683, 3.262238, 5.434066, 7.450763, 9.312329,
-       5.570122, 0.101119, -2.444796, -6.499403, -5.446074, -6.337338,
-       -0.454995, 0.436269, 2.228491, 0.871598, 1.838385, 0.786793,
-       4.362284, 3.737805, 4.390039, 3.057817, 5.814659, 6.202621,
-       6.258044, 2.932658, 3.366583, -0.623879, 4.475045, 3.588276,
-       -0.082914, -4.936279, 6.438795, -2.357929, 0.714463, -5.402106,
-       0.236606, -5.879963, 1.176247, 1.021916, 2.333727, 0.520341,
-       4.275447, 3.549392, 2.896994, 4.275447, 6.120910, 5.298480,
-       6.195676, 5.784461, 2.033296, 1.833920, 1.485010, 2.531738,
-       3.193988, 2.532378, -5.406940, -8.053379, -6.467402, -5.425139,
-       -1.395059, -1.325575, 0.266062, 1.622680, 1.606336, 1.230405,
-       2.809896, 3.893110, 4.601880, 3.425055, 4.374411, 8.283354,
-       3.494898, 2.029045, 6.088204, 4.915522, 1.136877, 2.700454});
-  Tensor out1_expected = tfFloat.make(
-      {5, 3},
-      {-1.89843750,
-       1.62500000,
-       -0.09375000,
-       -1.91406250,
-       -0.49218744,
-       -0.02343750,
-       -0.77343756,
-       0.08593753,
-       -1.55468738,
-       -2.73437500,
-       1.07031238,
-       0.35937503,
-       0.34374997,
-       -0.77343750,
-       0.10937499});
-  Tensor out2_expected = tfFloat.make(
-      {5, 3},
-      {0.79116172,
-       0.42708409,
-       0.30238494,
-       0.50903118,
-       0.31929117,
-       0.45128885,
-       0.33067191,
-       0.39473253,
-       0.42994878,
-       0.53187561,
-       0.29930803,
-       0.29000264,
-       0.38669431,
-       0.38038814,
-       0.75809801});
-  op_native_group_norm_out(
-      input, weight, bias, 5, 6, 4, 3, eps, out0, out1, out2);
-  EXPECT_TENSOR_CLOSE(out0, out0_expected);
-  EXPECT_TENSOR_CLOSE(out1, out1_expected);
-  EXPECT_TENSOR_CLOSE(out2, out2_expected);
+TEST_F(OpNativeGroupNormOutTest, SmokeTest) {
+#define TEST_ENTRY(ctype, dtype) test_dtype<ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY)
+#undef TEST_ENTRY
 }
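The Half/BFloat16 branch compares against the same float-derived expected values but with looser relative tolerances (2e-1 for the normalized output, 1e-2 for mean and rstd), because the 16-bit formats quantize inputs, intermediates, and outputs: bfloat16 keeps only 7 stored mantissa bits, so values step in increments of roughly 2^-8 (about 0.4%) before any arithmetic error accumulates, and the normalization's divide by a 16-bit std compounds that. Below is a standalone sketch of the size of that quantization step; it is a bit-twiddling illustration of the storage format, not ExecuTorch's BFloat16 class, and to_bf16_trunc is a hypothetical helper.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Round-trip a float through bfloat16-style truncation: keep the top 16 bits
// of the IEEE-754 pattern (sign, 8 exponent bits, 7 mantissa bits) and zero
// the rest. Real conversions usually round-to-nearest; truncation is enough
// to show the magnitude of the quantization step.
static float to_bf16_trunc(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u; // drop the low 16 bits of the mantissa
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

int main() {
  // A few expected values from the smoke test above.
  const float vals[] = {3.419882f, -8.053379f, 0.101119f};
  for (float v : vals) {
    const float q = to_bf16_trunc(v);
    std::printf("%.6f -> %.6f (relative error %.4f)\n", v, q, (v - q) / v);
  }
  return 0;
}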
