
Commit 747b6b7

Add zero-padding support to the reference kernel (#571)
* Add zero-padding support to the reference kernel
* Cleanup asserts
1 parent 87fee7a commit 747b6b7

7 files changed: +83 −61 lines changed


larq_compute_engine/core/bconv2d/reference.h

Lines changed: 19 additions & 8 deletions
@@ -67,6 +67,15 @@ inline void BConv2DReference(
   const int output_height = output_shape.Dims(1);
   const int output_width = output_shape.Dims(2);
 
+  const bool zero_padding =
+      bconv2d_params->padding_type == kTfLitePaddingSame &&
+      bconv2d_params->pad_value == 0;
+
+  // For n channels, a popcount of n/2 of the {0,1} bits would correspond to 0
+  // in the {-1,1} representation. So n/2 can be considered the 'zero point'.
+  const int binary_zero_point =
+      (bconv2d_params->channels_in / bconv2d_params->groups) / 2;
+
   TFLITE_DCHECK_EQ(input_depth_per_group * bconv2d_params->groups,
                    packed_input_shape.Dims(3));
   TFLITE_DCHECK_EQ(output_depth_per_group * bconv2d_params->groups,
@@ -84,16 +93,18 @@ inline void BConv2DReference(
           AccumScalar accum = AccumScalar(0);
           for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
             for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              const int in_x = in_x_origin + dilation_width_factor * filter_x;
+              const int in_y = in_y_origin + dilation_height_factor * filter_y;
+              const bool inside = ((in_x >= 0) && (in_x < input_width) &&
+                                   (in_y >= 0) && (in_y < input_height));
+              if (zero_padding && !inside) {
+                accum += binary_zero_point;
+                continue;
+              }
               for (int in_channel = 0; in_channel < input_depth_per_group;
                    ++in_channel) {
-                const int in_x = in_x_origin + dilation_width_factor * filter_x;
-                const int in_y =
-                    in_y_origin + dilation_height_factor * filter_y;
-                // `pad_value=1`, which means the bitpacked value is 0, so we
-                // set `input_value=0`
-                TBitpacked input_value = 0;
-                if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
-                    (in_y < input_height)) {
+                TBitpacked input_value = 0;  // represents a +1
+                if (inside) {
                   input_value = packed_input_data[Offset(
                       packed_input_shape, batch, in_y, in_x,
                       group * input_depth_per_group + in_channel)];
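Why n/2 works as a zero point: in the {-1,+1} encoding a dot-product contribution is recovered from a {0,1} match count via 2*matches − n, so crediting a padded pixel with n/2 matches yields exactly 0 (hence the even-channel requirement enforced in bconv2d.cc below). The following standalone Python sketch only illustrates that arithmetic; it is not the LCE bitpacking code, and all names are invented for illustration.

# Illustrative only: why `channels / 2` acts as a binary 'zero point'.
def binary_dot(a_signs, b_signs):
    """Dot product of two {-1, +1} vectors, computed by counting matches."""
    n = len(a_signs)
    matches = sum(1 for a, b in zip(a_signs, b_signs) if a == b)  # the 'popcount'
    return 2 * matches - n  # back-transform from {0, 1} counts to a {-1, +1} result

filter_taps = [1, -1, 1, 1]          # one filter position across n = 4 channels
pixel = [1, -1, -1, 1]
print(binary_dot(filter_taps, pixel))  # -> 2, an ordinary in-bounds contribution

# A zero-padded pixel should contribute 0 to the accumulator. Adding a match
# count of n // 2 (the patch's `binary_zero_point`) achieves exactly that,
# provided n is even:
n = len(filter_taps)
print(2 * (n // 2) - n)                # -> 0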

larq_compute_engine/tests/end2end_test.py

Lines changed: 3 additions & 1 deletion
@@ -134,7 +134,9 @@ def preprocess(data):
 
 
 def assert_model_output(model_lce, inputs, outputs, rtol, atol):
-    interpreter = Interpreter(model_lce, num_threads=min(os.cpu_count(), 4))
+    interpreter = Interpreter(
+        model_lce, num_threads=min(os.cpu_count(), 4), use_reference_bconv=False
+    )
     actual_outputs = interpreter.predict(inputs)
     np.testing.assert_allclose(actual_outputs, outputs, rtol=rtol, atol=atol)
 

larq_compute_engine/tflite/kernels/bconv2d.cc

Lines changed: 17 additions & 27 deletions
@@ -124,14 +124,6 @@ void* Init(TfLiteContext* context, const char* buffer, std::size_t length) {
 
   op_data->fused_activation_function = ConvertActivation(
       (ActivationFunctionType)m["fused_activation_function"].AsInt32());
-  if (bconv2d_params->padding_type == kTfLitePaddingSame &&
-      bconv2d_params->pad_value != 1 &&
-      op_data->fused_activation_function != kTfLiteActNone) {
-    TF_LITE_KERNEL_LOG(
-        context,
-        "Fused activations are only supported with valid or one-padding.");
-    return op_data;
-  }
 
   // It's not possible to return an error code in this method. If we get to here
   // without returning early, initialisation has succeeded without error, and so
@@ -195,6 +187,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     bconv2d_params->groups = groups;
   }
 
+  if (bconv2d_params->padding_type == kTfLitePaddingSame &&
+      bconv2d_params->pad_value == 0) {
+    TF_LITE_ENSURE_MSG(
+        context,
+        (kernel_type == KernelType::kReference &&
+         bconv2d_params->channels_in % 2 == 0) ||
+            (kernel_type != KernelType::kReference &&
+             output->type == kTfLiteFloat32 &&
+             op_data->fused_activation_function == kTfLiteActNone),
+        "Zero-padding is only supported by the reference kernel with an even "
+        "number of input channels, or when using "
+        "float output with no fused activation function.");
+  }
+
   // Compute the padding and output values (height, width)
   int out_width, out_height;
   bconv2d_params->padding_values = ComputePaddingHeightWidth(
@@ -210,11 +216,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     TF_LITE_ENSURE_EQ(context, thresholds->type, kTfLiteInt32);
     TF_LITE_ENSURE_EQ(context, SizeOfDimension(thresholds, 0),
                       bconv2d_params->channels_out);
-    TF_LITE_ENSURE_MSG(context,
-                       bconv2d_params->padding_type != kTfLitePaddingSame ||
-                           bconv2d_params->pad_value == 1,
-                       "Writing bitpacked output is only supported with "
-                       "valid or one-padding.");
   } else {
     TF_LITE_ENSURE_EQ(context, post_activation_multiplier->type,
                       kTfLiteFloat32);
@@ -230,20 +231,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     if (output->type == kTfLiteInt8) {
       TF_LITE_ENSURE_EQ(context, output->quantization.type,
                         kTfLiteAffineQuantization);
-      TF_LITE_ENSURE_MSG(
-          context,
-          bconv2d_params->padding_type != kTfLitePaddingSame ||
-              bconv2d_params->pad_value == 1,
-          "8-bit quantization is only supported with valid or one-padding");
     }
 
-  if (kernel_type == KernelType::kReference) {
-    TF_LITE_ENSURE_MSG(
-        context,
-        bconv2d_params->padding_type != kTfLitePaddingSame ||
-            bconv2d_params->pad_value == 1,
-        "The reference kernel only supports valid or one-padding.");
-  } else if (kernel_type == KernelType::kOptimizedIndirectBGEMM) {
+  if (kernel_type == KernelType::kOptimizedIndirectBGEMM) {
     TF_LITE_ENSURE_MSG(
         context, input->allocation_type != kTfLiteDynamic,
         "The input tensor must not have dynamic allocation type");
@@ -374,9 +364,9 @@ void OneTimeSetup(TfLiteContext* context, TfLiteNode* node, OpData* op_data) {
   const std::int32_t backtransform_add =
       filter_shape.Dims(1) * filter_shape.Dims(2) * channels_in_per_group;
   const double output_scale =
-      output->type == kTfLiteInt8 ? output->params.scale : 1.0f;
+      output->type == kTfLiteInt8 ? output->params.scale : 1.0;
   const double output_zero_point =
-      output->type == kTfLiteInt8 ? output->params.zero_point : 0.0f;
+      output->type == kTfLiteInt8 ? output->params.zero_point : 0.0;
 
   for (int i = 0; i < bconv2d_params->channels_out; ++i) {
     const double post_mul =
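In words, the new Prepare-time check accepts zero-padding in two situations: with the reference kernel, provided the number of input channels is even (so that channels_in / 2 is an exact binary zero point), or with the optimised kernels under the previous restriction of float output and no fused activation. A small Python restatement of that condition, with invented parameter names, purely as a reading aid:

# Reading aid only: mirrors the condition inside the new TF_LITE_ENSURE_MSG.
def zero_padding_supported(
    is_reference_kernel: bool,
    channels_in: int,
    output_is_float32: bool,
    has_fused_activation: bool,
) -> bool:
    if is_reference_kernel:
        # channels_in / 2 must be exact to serve as the binary 'zero point'.
        return channels_in % 2 == 0
    # Optimised kernels keep the previous restriction for zero-padding.
    return output_is_float32 and not has_fused_activation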

larq_compute_engine/tflite/kernels/lce_ops_register.h

Lines changed: 10 additions & 3 deletions
@@ -13,17 +13,24 @@ namespace tflite {
 TfLiteRegistration* Register_QUANTIZE();
 TfLiteRegistration* Register_DEQUANTIZE();
 TfLiteRegistration* Register_BCONV_2D();
+TfLiteRegistration* Register_BCONV_2D_REF();
 TfLiteRegistration* Register_BMAXPOOL_2D();
 
 // By calling this function on TF lite mutable op resolver, all LCE custom ops
 // will be registerd to the op resolver.
-inline void RegisterLCECustomOps(::tflite::MutableOpResolver* resolver) {
+inline void RegisterLCECustomOps(::tflite::MutableOpResolver* resolver,
+                                 const bool use_reference_bconv = false) {
   resolver->AddCustom("LceQuantize",
                       compute_engine::tflite::Register_QUANTIZE());
   resolver->AddCustom("LceDequantize",
                       compute_engine::tflite::Register_DEQUANTIZE());
-  resolver->AddCustom("LceBconv2d",
-                      compute_engine::tflite::Register_BCONV_2D());
+  if (use_reference_bconv) {
+    resolver->AddCustom("LceBconv2d",
+                        compute_engine::tflite::Register_BCONV_2D_REF());
+  } else {
+    resolver->AddCustom("LceBconv2d",
+                        compute_engine::tflite::Register_BCONV_2D());
+  }
   resolver->AddCustom("LceBMaxPool2d",
                       compute_engine::tflite::Register_BMAXPOOL_2D());
 };

larq_compute_engine/tflite/python/interpreter.py

Lines changed: 8 additions & 2 deletions
@@ -43,6 +43,7 @@ class Interpreter:
     # Arguments
     flatbuffer_model: A serialized Larq Compute Engine model in the flatbuffer format.
     num_threads: The number of threads used by the interpreter.
+    use_reference_bconv: When True, uses the reference implementation of LceBconv2d.
 
     # Attributes
     input_types: Returns a list of input types.
@@ -51,9 +52,14 @@ class Interpreter:
     output_shapes: Returns a list of output shapes.
     """
 
-    def __init__(self, flatbuffer_model: bytes, num_threads: int = 1):
+    def __init__(
+        self,
+        flatbuffer_model: bytes,
+        num_threads: int = 1,
+        use_reference_bconv: bool = False,
+    ):
         self.interpreter = interpreter_wrapper_lite.LiteInterpreter(
-            flatbuffer_model, num_threads
+            flatbuffer_model, num_threads, use_reference_bconv
         )
 
     @property
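The new flag makes it straightforward to cross-check the optimised kernels against the reference implementation from Python. A minimal sketch; the import path is inferred from this file's location and the input shape is a placeholder, so treat both as assumptions rather than documented API:

import numpy as np
from larq_compute_engine.tflite.python.interpreter import Interpreter  # assumed path

with open("model.tflite", "rb") as f:  # any converted LCE flatbuffer model
    flatbuffer_model = f.read()

# Run the same model with the optimised kernels and with the reference kernel.
fast = Interpreter(flatbuffer_model, num_threads=4)
reference = Interpreter(flatbuffer_model, num_threads=4, use_reference_bconv=True)

inputs = np.zeros((1, 224, 224, 3), dtype=np.float32)  # placeholder; must match the model input
np.testing.assert_allclose(fast.predict(inputs), reference.predict(inputs),
                           rtol=1e-5, atol=1e-6)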

larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc

Lines changed: 7 additions & 4 deletions
@@ -9,7 +9,8 @@ class LiteInterpreterWrapper
     : public InterpreterWrapperBase<tflite::Interpreter> {
  public:
   LiteInterpreterWrapper(const pybind11::bytes& flatbuffer,
-                         const int num_threads);
+                         const int num_threads = 1,
+                         const bool use_reference_bconv = false);
   ~LiteInterpreterWrapper(){};
 
  private:
@@ -21,7 +22,8 @@ class LiteInterpreterWrapper
 };
 
 LiteInterpreterWrapper::LiteInterpreterWrapper(
-    const pybind11::bytes& flatbuffer, const int num_threads = 1) {
+    const pybind11::bytes& flatbuffer, const int num_threads,
+    const bool use_reference_bconv) {
   // Make a copy of the flatbuffer because it can get deallocated after the
   // constructor is done
   flatbuffer_ = static_cast<std::string>(flatbuffer);
@@ -34,7 +36,8 @@ LiteInterpreterWrapper::LiteInterpreterWrapper(
 
   // Build the interpreter
   resolver_ = std::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
-  compute_engine::tflite::RegisterLCECustomOps(resolver_.get());
+  compute_engine::tflite::RegisterLCECustomOps(resolver_.get(),
+                                               use_reference_bconv);
 
   tflite::InterpreterBuilder builder(*model_, *resolver_);
   builder(&interpreter_, num_threads);
@@ -46,7 +49,7 @@ LiteInterpreterWrapper::LiteInterpreterWrapper(
 
 PYBIND11_MODULE(interpreter_wrapper_lite, m) {
   pybind11::class_<LiteInterpreterWrapper>(m, "LiteInterpreter")
-      .def(pybind11::init<const pybind11::bytes&, const int>())
+      .def(pybind11::init<const pybind11::bytes&, const int, const bool>())
       .def_property("input_types", &LiteInterpreterWrapper::get_input_types,
                     nullptr)
       .def_property("output_types", &LiteInterpreterWrapper::get_output_types,

larq_compute_engine/tflite/tests/bconv2d_test.cc

Lines changed: 19 additions & 16 deletions
@@ -455,6 +455,7 @@ void runTest(const TestParam& param) {
   constexpr bool write_bitpacked_output =
       std::is_same<TOutput, TBitpacked>::value;
   constexpr bool int8_output = std::is_same<TOutput, std::int8_t>::value;
+  constexpr bool float_output = std::is_same<TOutput, float>::value;
 
   const Padding builtin_padding =
       (padding == Padding_ONE ? Padding_VALID : padding);
@@ -492,22 +493,24 @@ void runTest(const TestParam& param) {
   const int bitpacked_filters_num_elem =
       filter_count * filter_height * filter_width * packed_channels_per_group;
 
-  // the reference implementation only support one-padding
   const auto is_reference_registration =
       (registration == compute_engine::tflite::Register_BCONV_2D_REF);
 
-  if (padding == Padding_SAME &&
-      (is_reference_registration || activation == ActivationFunctionType_RELU ||
-       write_bitpacked_output || int8_output)) {
-    // Zero-padding is not supported in combination with:
-    // - The reference implementation
-    // - Fused ReLu
-    // - Writing bitpacked output
-    // - Int8 output
-    // We could use `EXPECT_DEATH` here but it is extremely slow. Therefore we
-    // have a separate test below, and here we just skip.
-    GTEST_SKIP();
-    return;
+  if (padding == Padding_SAME) {
+    if (is_reference_registration) {
+      if (input_depth % 2 != 0) {
+        GTEST_SKIP();
+        return;
+      }
+    } else {
+      if (!float_output || activation == ActivationFunctionType_RELU) {
+        // We could use `EXPECT_DEATH` here but it is
+        // extremely slow. Therefore we have a separate test below, and here we
+        // just skip.
+        GTEST_SKIP();
+        return;
+      }
+    }
   }
 
   std::random_device rd;
@@ -874,7 +877,7 @@ TEST(BConv2DTests, ReluErrorDeathTest) {
                       threshold_tensor, 64, 1, 1, Padding_SAME, 0,
                       ActivationFunctionType_RELU, 1, 1, 1);
       },
-      "Fused activations are only supported with valid or one-padding.");
+      "Zero-padding is only supported by");
 
   // Test if writing bitpacked output throws an error in combination with
   // zero-padding.
@@ -886,7 +889,7 @@ TEST(BConv2DTests, ReluErrorDeathTest) {
                       post_tensor, threshold_tensor, 64, 1, 1, Padding_SAME, 0,
                       ActivationFunctionType_NONE, 1, 1, 1);
       },
-      "Writing bitpacked output is only supported with valid or one-padding.");
+      "Zero-padding is only supported by");
 }
 
 TEST(BConv2DTests, Int8ErrorDeathTest) {
@@ -910,7 +913,7 @@ TEST(BConv2DTests, Int8ErrorDeathTest) {
                       threshold_tensor, 64, 1, 1, Padding_SAME, 0,
                       ActivationFunctionType_NONE, 1, 1, 1);
       },
-      "8-bit quantization is only supported with valid or one-padding.");
+      "Zero-padding is only supported by");
 }
 
 }  // namespace testing
