Added support for bias in optimized linear operation #9527

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into base: main
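
This change removes the old "bias not supported yet in linear" check from the optimized opt_linear_out kernel. When a bias is provided, the kernel now validates that it is 1-D with length m, pre-fills the output with it, and runs the GEMM with beta = 1 so the matmul accumulates on top of the bias (a standalone sketch of this trick follows the op_linear.cpp diff below). The operator tests are updated to exercise both the bias and no-bias paths.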
36 changes: 29 additions & 7 deletions kernels/optimized/cpu/op_linear.cpp
@@ -9,6 +9,8 @@
#include <executorch/kernels/optimized/blas/CPUBlas.h>
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <c10/util/irange.h>

#include <array>
#include <cstring> // std::memcpy, used for the bias pre-fill

@@ -24,12 +25,6 @@ Tensor& opt_linear_out(
const Tensor& mat2,
const optional<Tensor>& bias,
Tensor& out) {
ET_KERNEL_CHECK_MSG(
ctx,
!bias.has_value(),
InvalidArgument,
out,
"bias not supported yet in linear");
ET_KERNEL_CHECK(ctx, check_linear_args(in, mat2, out), InvalidArgument, out);

size_t output_ndim = 0;
@@ -56,6 +51,33 @@ Tensor& opt_linear_out(
size_t k = in.sizes()[in.dim() - 1];
size_t m = mat2.size(0);

// If bias is provided, verify its shape and pre-fill the output tensor.
if (bias.has_value()) {
const auto& bias_value = bias.value();
// Check that bias is 1D and that its size matches m.
ET_KERNEL_CHECK_MSG(
ctx,
bias_value.dim() == 1 && bias_value.size(0) == m,
InvalidArgument,
out,
"Bias must be 1D and of size m. Got: %zd, expected: %zu",
static_cast<ssize_t>(bias_value.size(0)),
m);
auto bias_ptr = bias_value.const_data_ptr<CTYPE>();
CTYPE* out_ptr = out.mutable_data_ptr<CTYPE>();
// Broadcast the bias to every column of the output: in the column-major
// view used by the GEMM call below, the output has n columns of length m,
// and each one gets its own copy of the bias.
const auto bias_bytes = m * sizeof(CTYPE);
for (const auto col : c10::irange(n)) {
std::memcpy(out_ptr + col * m, bias_ptr, bias_bytes);
}
}

// Set beta to 1 if bias was applied so that GEMM adds to the pre-filled bias;
// otherwise beta remains 0 (i.e. the output is fully overwritten by GEMM).
const CTYPE beta_val =
bias.has_value() ? static_cast<CTYPE>(1) : static_cast<CTYPE>(0);

executorch::cpublas::gemm(
executorch::cpublas::TransposeType::Transpose,
executorch::cpublas::TransposeType::NoTranspose,
@@ -67,7 +89,7 @@
k,
in.const_data_ptr<CTYPE>(),
k,
static_cast<CTYPE>(0),
beta_val,
out.mutable_data_ptr<CTYPE>(),
m);
});
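The bias path above relies on the standard GEMM contract C = alpha * op(A) * op(B) + beta * C: pre-fill C with the bias, then run the matmul with beta = 1. A minimal self-contained sketch of the same idea, using a hand-rolled reference gemm instead of executorch::cpublas (dimensions and values chosen to mirror the tests; all names here are illustrative):

#include <cstdio>
#include <vector>

// Reference GEMM for out = alpha * (x @ w^T) + beta * out, with
// x: (n, k) row-major, w: (m, k) row-major, out: (n, m) row-major.
// This mirrors the Transpose/NoTranspose call in op_linear.cpp.
void gemm_xwt(
    int n, int m, int k,
    float alpha, const float* x, const float* w,
    float beta, float* out) {
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < m; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        acc += x[i * k + p] * w[j * k + p];
      }
      out[i * m + j] = alpha * acc + beta * out[i * m + j];
    }
  }
}

int main() {
  const int n = 2, m = 3, k = 4;
  std::vector<float> x(n * k, 2.f); // like tf.full({n, k}, 2) in the tests
  std::vector<float> w(m * k, 3.f); // like tf.full({m, k}, 3)
  std::vector<float> bias = {1.f, 2.f, 3.f};

  // Step 1: broadcast the bias into every row of the (n, m) output,
  // which is what the memcpy loop in opt_linear_out does.
  std::vector<float> out(n * m);
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < m; ++j) {
      out[i * m + j] = bias[j];
    }
  }

  // Step 2: GEMM with beta = 1 accumulates the matmul on top of the bias.
  gemm_xwt(n, m, k, 1.f, x.data(), w.data(), /*beta=*/1.f, out.data());

  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < m; ++j) {
      std::printf("%g ", out[i * m + j]); // prints 25 26 27 per row
    }
    std::printf("\n");
  }
  return 0;
}

Running this prints "25 26 27" on both rows: each entry is 2 * 3 * 4 = 24 from the matmul plus the per-column bias, the same arithmetic the updated tests below check with 192 + 1 = 193.
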
137 changes: 76 additions & 61 deletions kernels/test/op_linear_test.cpp
@@ -23,12 +23,13 @@ using executorch::aten::ArrayRef;
using executorch::aten::Scalar;
using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using executorch::aten::optional;
using torch::executor::testing::TensorFactory;

class OpLinearOutTest : public OperatorTest {
protected:
Tensor& op_linear_out(const Tensor& self, const Tensor& mat2, Tensor& out) {
return torch::executor::aten::linear_outf(context_, self, mat2, {}, out);
Tensor& op_linear_out(
const Tensor& self,
const Tensor& mat2,
const optional<Tensor>& bias,
Tensor& out) {
return torch::executor::aten::linear_outf(context_, self, mat2, bias, out);
}

template <class CTYPE, executorch::aten::ScalarType DTYPE>
Expand All @@ -47,14 +48,18 @@ class OpLinearOutTest : public OperatorTest {
Tensor x = tf.full({3, 32}, 2);
Tensor y = tf.full({5, 32}, 3);

// Output shape should be (3, 5)
Tensor out = tf.zeros({3, 5});

op_linear_out(x, y, out);

Tensor expected = tf.full({3, 5}, 192);

EXPECT_TENSOR_EQ(out, expected);
// without bias
Tensor out_no_bias = tf.zeros({3, 5});
op_linear_out(x, y, {}, out_no_bias);
Tensor expected_no_bias = tf.full({3, 5}, 192);
EXPECT_TENSOR_EQ(out_no_bias, expected_no_bias);

// with bias
Tensor bias = tf.full({5}, 1);
Tensor out_with_bias = tf.zeros({3, 5});
op_linear_out(x, y, bias, out_with_bias);
Tensor expected_with_bias = tf.full({3, 5}, 193);
EXPECT_TENSOR_EQ(out_with_bias, expected_with_bias);
}
};

@@ -66,13 +71,15 @@ TEST_F(OpLinearOutTest, OutputDim) {
Tensor y = tf.ones({5, 4});
Tensor out = tf.zeros({3, 5});

Tensor ret = op_linear_out(x, y, out);
Tensor bias = tf.ones({5});

Tensor ret = op_linear_out(x, y, bias, out);

// Should always return the provided out Tensor.
EXPECT_TENSOR_EQ(ret, out);

// Expected tensor, filled with 4.
Tensor expected = tf.full({3, 5}, 4);
// Expected tensor, filled with 5.
Tensor expected = tf.full({3, 5}, 5);

EXPECT_TENSOR_EQ(out, expected);
}
@@ -94,44 +101,47 @@ TEST_F(OpLinearOutTest, EmptyInputWithEmptyOutTensorPasses) {
// Empty input matrices
Tensor x = tf.make({0, 3}, {});
Tensor y = tf.make({0, 3}, {});
Tensor bias = tf.make({0}, {});

// Output matrix is also empty
Tensor out = tf.make({0, 0}, {});

Tensor expected = tf.make({0, 0}, {});

EXPECT_TENSOR_EQ(op_linear_out(x, y, out), expected);
EXPECT_TENSOR_EQ(op_linear_out(x, y, bias, out), expected);
}

TEST_F(OpLinearOutTest, InfinityTensorPasses) {
TensorFactory<ScalarType::Float> tff;

Tensor x = tff.full({3, 4}, std::numeric_limits<float>::infinity());
Tensor y = tff.full({5, 4}, 3);
Tensor bias = tff.full({5}, 1);

// Output shape should be (3, 5)
Tensor out = tff.zeros({3, 5});

Tensor expected = tff.full({3, 5}, std::numeric_limits<float>::infinity());

EXPECT_TENSOR_EQ(op_linear_out(x, y, out), expected);
EXPECT_TENSOR_EQ(op_linear_out(x, y, bias, out), expected);
}

TEST_F(OpLinearOutTest, MismatchedDimensionsDies) {
TensorFactory<ScalarType::Int> tf;

Tensor x = tf.full({2, 2}, 3);
Tensor bias = tf.full({2}, 1);

Tensor wrong_y = tf.full({1, 3}, 1);
Tensor right_y = tf.full({2, 2}, 1);

// Make a zero-filled out tensor to write into.
Tensor out = tf.full({2, 2}, 0);

Tensor expected = tf.full({2, 2}, 6);
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, wrong_y, out));
Tensor expected = tf.full({2, 2}, 7);
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, wrong_y, bias, out));

EXPECT_TENSOR_EQ(op_linear_out(x, right_y, out), expected);
EXPECT_TENSOR_EQ(op_linear_out(x, right_y, bias, out), expected);
}

TEST_F(OpLinearOutTest, MismatchedDimensionSizeDies) {
@@ -140,6 +150,7 @@ TEST_F(OpLinearOutTest, MismatchedDimensionSizeDies) {
}
TensorFactory<ScalarType::Int> tf;
Tensor x = tf.full({2, 2}, 3);
Tensor bias = tf.full({2}, 1);

// wrong_y has incompatible dim
Tensor wrong_y = tf.full({2, 2, 2}, 1);
@@ -149,8 +160,8 @@ TEST_F(OpLinearOutTest, MismatchedDimensionSizeDies) {
Tensor right_out = tf.ones({2, 2});
Tensor wrong_out = tf.ones({2, 2, 3});

ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, right_y, wrong_out));
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, wrong_y, right_out));
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, right_y, bias, wrong_out));
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, wrong_y, bias, right_out));
}

TEST_F(OpLinearOutTest, WrongOutShapeDies) {
@@ -161,14 +172,15 @@ TEST_F(OpLinearOutTest, WrongOutShapeDies) {
Tensor x = tf.ones({10, 3});

Tensor y = tf.ones({4, 3});
Tensor bias = tf.ones({4});

// wrong_out has incompatible shape
Tensor right_out = tf.ones({10, 4});
Tensor wrong_out = tf.ones({7, 5});

ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, y, wrong_out));
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, y, bias, wrong_out));

EXPECT_TENSOR_EQ(op_linear_out(x, y, right_out), tf.full({10, 4}, 3));
EXPECT_TENSOR_EQ(op_linear_out(x, y, bias, right_out), tf.full({10, 4}, 4));
}

TEST_F(OpLinearOutTest, DynamicShapeUpperBoundSameAsExpected) {
@@ -192,24 +204,25 @@ TEST_F(OpLinearOutTest, DynamicShapeUpperBoundSameAsExpected) {
0.09420186281204224,
0.9070476293563843,
0.9310881495475769});
Tensor bias = tf.ones({4});
Tensor expected_result = tf.make(
{3, 4},
{0.2506277561187744,
0.15225356817245483,
0.18952149152755737,
0.48189279437065125,
0.976661741733551,
0.480360746383667,
0.8310978412628174,
1.6718982458114624,
0.703657865524292,
0.2534688115119934,
0.6746801733970642,
1.0356627702713013});
{1.2506277561187744,
1.15225356817245483,
1.18952149152755737,
1.48189279437065125,
1.976661741733551,
1.480360746383667,
1.8310978412628174,
2.6718982458114624,
1.703657865524292,
1.2534688115119934,
1.6746801733970642,
2.0356627702713013});

Tensor out =
tf.zeros({3, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
Tensor ret = op_linear_out(x, y, out);
Tensor ret = op_linear_out(x, y, bias, out);
EXPECT_TENSOR_CLOSE(out, expected_result);
}

@@ -234,24 +247,25 @@ TEST_F(OpLinearOutTest, DynamicShapeUpperBoundLargerThanExpected) {
0.09420186281204224,
0.9070476293563843,
0.9310881495475769});
Tensor bias = tf.ones({4});
Tensor expected_result = tf.make(
{3, 4},
{0.2506277561187744,
0.15225356817245483,
0.18952149152755737,
0.48189279437065125,
0.976661741733551,
0.480360746383667,
0.8310978412628174,
1.6718982458114624,
0.703657865524292,
0.2534688115119934,
0.6746801733970642,
1.0356627702713013});
{1.2506277561187744,
1.15225356817245483,
1.18952149152755737,
1.48189279437065125,
1.976661741733551,
1.480360746383667,
1.8310978412628174,
2.6718982458114624,
1.703657865524292,
1.2534688115119934,
1.6746801733970642,
2.0356627702713013});

Tensor out =
tf.zeros({10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
Tensor ret = op_linear_out(x, y, out);
Tensor ret = op_linear_out(x, y, bias, out);
EXPECT_TENSOR_CLOSE(out, expected_result);
}

@@ -277,24 +291,25 @@ TEST_F(OpLinearOutTest, DynamicShapeUnbound) {
0.09420186281204224,
0.9070476293563843,
0.9310881495475769});
Tensor bias = tf.ones({4});
Tensor expected_result = tf.make(
{3, 4},
{0.2506277561187744,
0.15225356817245483,
0.18952149152755737,
0.48189279437065125,
0.976661741733551,
0.480360746383667,
0.8310978412628174,
1.6718982458114624,
0.703657865524292,
0.2534688115119934,
0.6746801733970642,
1.0356627702713013});
{1.2506277561187744,
1.15225356817245483,
1.18952149152755737,
1.48189279437065125,
1.976661741733551,
1.480360746383667,
1.8310978412628174,
2.6718982458114624,
1.703657865524292,
1.2534688115119934,
1.6746801733970642,
2.0356627702713013});

Tensor out =
tf.zeros({1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
Tensor ret = op_linear_out(x, y, out);
Tensor ret = op_linear_out(x, y, bias, out);
EXPECT_TENSOR_CLOSE(out, expected_result);
}
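
One case the updated tests do not appear to cover is a mis-sized bias, which the new shape check in opt_linear_out should reject. A possible addition in the same style (the test name and shapes here are a suggestion, not part of the PR):

TEST_F(OpLinearOutTest, WrongBiasSizeDies) {
TensorFactory<ScalarType::Int> tf;

Tensor x = tf.ones({3, 4});
Tensor y = tf.ones({5, 4});
// m = y.size(0) = 5, so a bias of length 4 should trip the new check.
Tensor wrong_bias = tf.ones({4});
Tensor out = tf.zeros({3, 5});

ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, y, wrong_bias, out));
}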
