Skip to content

Commit b637f4b

Browse files
Added support for bias in optimized linear operation
Signed-off-by: David Grigoryan <[email protected]>
1 parent 77c35f5 commit b637f4b

File tree

2 files changed

+105
-68
lines changed

2 files changed

+105
-68
lines changed

kernels/optimized/cpu/op_linear.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <executorch/kernels/optimized/blas/CPUBlas.h>
1010
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
1111
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <c10/util/irange.h>
1213

1314
#include <array>
1415

@@ -24,12 +25,6 @@ Tensor& opt_linear_out(
2425
const Tensor& mat2,
2526
const optional<Tensor>& bias,
2627
Tensor& out) {
27-
ET_KERNEL_CHECK_MSG(
28-
ctx,
29-
!bias.has_value(),
30-
InvalidArgument,
31-
out,
32-
"bias not supported yet in linear");
3328
ET_KERNEL_CHECK(ctx, check_linear_args(in, mat2, out), InvalidArgument, out);
3429

3530
size_t output_ndim = 0;
@@ -56,6 +51,33 @@ Tensor& opt_linear_out(
5651
size_t k = in.sizes()[in.dim() - 1];
5752
size_t m = mat2.size(0);
5853

54+
// If bias is provided, verify its shape and pre-fill the output tensor.
55+
if (bias.has_value()) {
56+
auto bias_value = bias.value();
57+
// Check that bias is 1D and its size matches m.
58+
ET_KERNEL_CHECK_MSG(
59+
ctx,
60+
bias_value.dim() == 1 && bias_value.size(0) == m,
61+
InvalidArgument,
62+
out,
63+
"Bias must be 1D and of size m. Got: ",
64+
bias_value.size(0),
65+
", expected: ",
66+
m
67+
);
68+
auto bias_ptr = bias_value.const_data_ptr<CTYPE>();
69+
CTYPE* out_ptr = out.mutable_data_ptr<CTYPE>();
70+
// Broadcast the bias across each of the n rows of the output.
71+
auto row_size = m * sizeof(CTYPE);
72+
for (const auto col : c10::irange(n)) {
73+
std::memcpy(out_ptr + col * m, bias_ptr, row_size);
74+
}
75+
}
76+
77+
// Set beta to 1 if bias was applied so that GEMM adds to the pre-filled bias,
78+
// otherwise beta remains 0 (i.e. the output is fully overwritten by GEMM).
79+
CTYPE beta_val = bias.has_value() ? static_cast<CTYPE>(1) : static_cast<CTYPE>(0);
80+
5981
executorch::cpublas::gemm(
6082
executorch::cpublas::TransposeType::Transpose,
6183
executorch::cpublas::TransposeType::NoTranspose,
@@ -67,7 +89,7 @@ Tensor& opt_linear_out(
6789
k,
6890
in.const_data_ptr<CTYPE>(),
6991
k,
70-
static_cast<CTYPE>(0),
92+
beta_val,
7193
out.mutable_data_ptr<CTYPE>(),
7294
m);
7395
});

kernels/test/op_linear_test.cpp

Lines changed: 76 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@ using executorch::aten::ArrayRef;
2323
using executorch::aten::Scalar;
2424
using executorch::aten::ScalarType;
2525
using executorch::aten::Tensor;
26+
using executorch::aten::optional;
2627
using torch::executor::testing::TensorFactory;
2728

2829
class OpLinearOutTest : public OperatorTest {
2930
protected:
30-
Tensor& op_linear_out(const Tensor& self, const Tensor& mat2, Tensor& out) {
31-
return torch::executor::aten::linear_outf(context_, self, mat2, {}, out);
31+
Tensor& op_linear_out(const Tensor& self, const Tensor& mat2, const optional<Tensor>& bias, Tensor& out) {
32+
return torch::executor::aten::linear_outf(context_, self, mat2, bias, out);
3233
}
3334

3435
template <class CTYPE, executorch::aten::ScalarType DTYPE>
@@ -47,14 +48,18 @@ class OpLinearOutTest : public OperatorTest {
4748
Tensor x = tf.full({3, 32}, 2);
4849
Tensor y = tf.full({5, 32}, 3);
4950

50-
// Output shape should be (3, 5)
51-
Tensor out = tf.zeros({3, 5});
52-
53-
op_linear_out(x, y, out);
54-
55-
Tensor expected = tf.full({3, 5}, 192);
56-
57-
EXPECT_TENSOR_EQ(out, expected);
51+
// without bias
52+
Tensor out_no_bias = tf.zeros({3, 5});
53+
op_linear_out(x, y, {}, out_no_bias);
54+
Tensor expected_no_bias = tf.full({3, 5}, 192);
55+
EXPECT_TENSOR_EQ(out_no_bias, expected_no_bias);
56+
57+
// with bias
58+
Tensor bias = tf.full({5}, 1);
59+
Tensor out_with_bias = tf.zeros({3, 5});
60+
op_linear_out(x, y, bias, out_with_bias);
61+
Tensor expected_with_bias = tf.full({3, 5}, 193);
62+
EXPECT_TENSOR_EQ(out_with_bias, expected_with_bias);
5863
}
5964
};
6065

@@ -66,13 +71,15 @@ TEST_F(OpLinearOutTest, OutputDim) {
6671
Tensor y = tf.ones({5, 4});
6772
Tensor out = tf.zeros({3, 5});
6873

69-
Tensor ret = op_linear_out(x, y, out);
74+
Tensor bias = tf.ones({5});
75+
76+
Tensor ret = op_linear_out(x, y, bias, out);
7077

7178
// Should always return the provided out Tensor.
7279
EXPECT_TENSOR_EQ(ret, out);
7380

74-
// Expected tensor, filled with 4.
75-
Tensor expected = tf.full({3, 5}, 4);
81+
// Expected tensor, filled with 5.
82+
Tensor expected = tf.full({3, 5}, 5);
7683

7784
EXPECT_TENSOR_EQ(out, expected);
7885
}
@@ -94,44 +101,47 @@ TEST_F(OpLinearOutTest, EmptyInputWithEmptyOutTensorPasses) {
94101
// Empty input matrices
95102
Tensor x = tf.make({0, 3}, {});
96103
Tensor y = tf.make({0, 3}, {});
104+
Tensor bias = tf.make({0}, {});
97105

98106
// Output matrix is also empty
99107
Tensor out = tf.make({0, 0}, {});
100108

101109
Tensor expected = tf.make({0, 0}, {});
102110

103-
EXPECT_TENSOR_EQ(op_linear_out(x, y, out), expected);
111+
EXPECT_TENSOR_EQ(op_linear_out(x, y, bias, out), expected);
104112
}
105113

106114
TEST_F(OpLinearOutTest, InfinityTensorPasses) {
107115
TensorFactory<ScalarType::Float> tff;
108116

109117
Tensor x = tff.full({3, 4}, std::numeric_limits<float>::infinity());
110118
Tensor y = tff.full({5, 4}, 3);
119+
Tensor bias = tff.full({5}, 1);
111120

112121
// Output shape should be (3, 5)
113122
Tensor out = tff.zeros({3, 5});
114123

115124
Tensor expected = tff.full({3, 5}, std::numeric_limits<float>::infinity());
116125

117-
EXPECT_TENSOR_EQ(op_linear_out(x, y, out), expected);
126+
EXPECT_TENSOR_EQ(op_linear_out(x, y, bias, out), expected);
118127
}
119128

120129
TEST_F(OpLinearOutTest, MismatchedDimensionsDies) {
121130
TensorFactory<ScalarType::Int> tf;
122131

123132
Tensor x = tf.full({2, 2}, 3);
133+
Tensor bias = tf.full({2}, 1);
124134

125135
Tensor wrong_y = tf.full({1, 3}, 1);
126136
Tensor right_y = tf.full({2, 2}, 1);
127137

128138
// Make an empty out tensor and demonstrate that it's empty.
129139
Tensor out = tf.full({2, 2}, 0);
130140

131-
Tensor expected = tf.full({2, 2}, 6);
132-
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, wrong_y, out));
141+
Tensor expected = tf.full({2, 2}, 7);
142+
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, wrong_y, bias, out));
133143

134-
EXPECT_TENSOR_EQ(op_linear_out(x, right_y, out), expected);
144+
EXPECT_TENSOR_EQ(op_linear_out(x, right_y, bias, out), expected);
135145
}
136146

137147
TEST_F(OpLinearOutTest, MismatchedDimensionSizeDies) {
@@ -140,6 +150,7 @@ TEST_F(OpLinearOutTest, MismatchedDimensionSizeDies) {
140150
}
141151
TensorFactory<ScalarType::Int> tf;
142152
Tensor x = tf.full({2, 2}, 3);
153+
Tensor bias = tf.full({2}, 1);
143154

144155
// wrong_y has incompatible dim
145156
Tensor wrong_y = tf.full({2, 2, 2}, 1);
@@ -149,8 +160,8 @@ TEST_F(OpLinearOutTest, MismatchedDimensionSizeDies) {
149160
Tensor right_out = tf.ones({2, 2});
150161
Tensor wrong_out = tf.ones({2, 2, 3});
151162

152-
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, right_y, wrong_out));
153-
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, wrong_y, right_out));
163+
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, right_y, bias, wrong_out));
164+
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, wrong_y, bias, right_out));
154165
}
155166

156167
TEST_F(OpLinearOutTest, WrongOutShapeDies) {
@@ -161,14 +172,15 @@ TEST_F(OpLinearOutTest, WrongOutShapeDies) {
161172
Tensor x = tf.ones({10, 3});
162173

163174
Tensor y = tf.ones({4, 3});
175+
Tensor bias = tf.ones({4});
164176

165177
// wrong_out has incompatible shape
166178
Tensor right_out = tf.ones({10, 4});
167179
Tensor wrong_out = tf.ones({7, 5});
168180

169-
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, y, wrong_out));
181+
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, y, bias, wrong_out));
170182

171-
EXPECT_TENSOR_EQ(op_linear_out(x, y, right_out), tf.full({10, 4}, 3));
183+
EXPECT_TENSOR_EQ(op_linear_out(x, y, bias, right_out), tf.full({10, 4}, 4));
172184
}
173185

174186
TEST_F(OpLinearOutTest, DynamicShapeUpperBoundSameAsExpected) {
@@ -192,24 +204,25 @@ TEST_F(OpLinearOutTest, DynamicShapeUpperBoundSameAsExpected) {
192204
0.09420186281204224,
193205
0.9070476293563843,
194206
0.9310881495475769});
207+
Tensor bias = tf.ones({4});
195208
Tensor expected_result = tf.make(
196209
{3, 4},
197-
{0.2506277561187744,
198-
0.15225356817245483,
199-
0.18952149152755737,
200-
0.48189279437065125,
201-
0.976661741733551,
202-
0.480360746383667,
203-
0.8310978412628174,
204-
1.6718982458114624,
205-
0.703657865524292,
206-
0.2534688115119934,
207-
0.6746801733970642,
208-
1.0356627702713013});
210+
{1.2506277561187744,
211+
1.15225356817245483,
212+
1.18952149152755737,
213+
1.48189279437065125,
214+
1.976661741733551,
215+
1.480360746383667,
216+
1.8310978412628174,
217+
2.6718982458114624,
218+
1.703657865524292,
219+
1.2534688115119934,
220+
1.6746801733970642,
221+
2.0356627702713013});
209222

210223
Tensor out =
211224
tf.zeros({3, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
212-
Tensor ret = op_linear_out(x, y, out);
225+
Tensor ret = op_linear_out(x, y, bias, out);
213226
EXPECT_TENSOR_CLOSE(out, expected_result);
214227
}
215228

@@ -234,24 +247,25 @@ TEST_F(OpLinearOutTest, DynamicShapeUpperBoundLargerThanExpected) {
234247
0.09420186281204224,
235248
0.9070476293563843,
236249
0.9310881495475769});
250+
Tensor bias = tf.ones({4});
237251
Tensor expected_result = tf.make(
238252
{3, 4},
239-
{0.2506277561187744,
240-
0.15225356817245483,
241-
0.18952149152755737,
242-
0.48189279437065125,
243-
0.976661741733551,
244-
0.480360746383667,
245-
0.8310978412628174,
246-
1.6718982458114624,
247-
0.703657865524292,
248-
0.2534688115119934,
249-
0.6746801733970642,
250-
1.0356627702713013});
253+
{1.2506277561187744,
254+
1.15225356817245483,
255+
1.18952149152755737,
256+
1.48189279437065125,
257+
1.976661741733551,
258+
1.480360746383667,
259+
1.8310978412628174,
260+
2.6718982458114624,
261+
1.703657865524292,
262+
1.2534688115119934,
263+
1.6746801733970642,
264+
2.0356627702713013});
251265

252266
Tensor out =
253267
tf.zeros({10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
254-
Tensor ret = op_linear_out(x, y, out);
268+
Tensor ret = op_linear_out(x, y, bias, out);
255269
EXPECT_TENSOR_CLOSE(out, expected_result);
256270
}
257271

@@ -277,24 +291,25 @@ TEST_F(OpLinearOutTest, DynamicShapeUnbound) {
277291
0.09420186281204224,
278292
0.9070476293563843,
279293
0.9310881495475769});
294+
Tensor bias = tf.ones({4});
280295
Tensor expected_result = tf.make(
281296
{3, 4},
282-
{0.2506277561187744,
283-
0.15225356817245483,
284-
0.18952149152755737,
285-
0.48189279437065125,
286-
0.976661741733551,
287-
0.480360746383667,
288-
0.8310978412628174,
289-
1.6718982458114624,
290-
0.703657865524292,
291-
0.2534688115119934,
292-
0.6746801733970642,
293-
1.0356627702713013});
297+
{1.2506277561187744,
298+
1.15225356817245483,
299+
1.18952149152755737,
300+
1.48189279437065125,
301+
1.976661741733551,
302+
1.480360746383667,
303+
1.8310978412628174,
304+
2.6718982458114624,
305+
1.703657865524292,
306+
1.2534688115119934,
307+
1.6746801733970642,
308+
2.0356627702713013});
294309

295310
Tensor out =
296311
tf.zeros({1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
297-
Tensor ret = op_linear_out(x, y, out);
312+
Tensor ret = op_linear_out(x, y, bias, out);
298313
EXPECT_TENSOR_CLOSE(out, expected_result);
299314
}
300315

0 commit comments

Comments
 (0)