Commit f0229f0

fadara01 authored and dzarukin committed
src: cpu: aarch64: re-enable fp16 post-ops for aarch64
Perform the eltwise post-ops in fp32 instead of fp16 by casting up (to fp32) before executing the eltwise op and then casting back down (to fp16) after the operation completes. With this change, all fp16 benchdnn tests pass on aarch64.
1 parent 9d2a5e5 commit f0229f0
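
For illustration, here is a minimal standalone sketch of the cast-up/cast-down pattern this commit applies. It is not oneDNN code: _Float16 stands in for dnnl::impl::float16_t, plain loops stand in for cvt_float16_to_float / cvt_float_to_float16, and tanh stands in for the eltwise post-op. It assumes a compiler with _Float16 support, such as GCC or Clang on aarch64.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Apply an eltwise op (tanh here) to an fp16 buffer, computing in fp32:
// cast up, run the op in fp32, cast back down -- the same round trip the
// commit performs through a temporary fp32 ACL tensor.
static void eltwise_tanh_f16_via_f32(_Float16 *src, std::size_t num_elements) {
    std::vector<float> src_f32(num_elements);
    // cast src up to fp32
    for (std::size_t i = 0; i < num_elements; ++i)
        src_f32[i] = static_cast<float>(src[i]);
    // perform the operation in fp32
    for (std::size_t i = 0; i < num_elements; ++i)
        src_f32[i] = std::tanh(src_f32[i]);
    // cast back down, storing the final result in src
    for (std::size_t i = 0; i < num_elements; ++i)
        src[i] = static_cast<_Float16>(src_f32[i]);
}

int main() {
    const float init[] = {-1.0f, -0.5f, 0.5f, 1.0f};
    _Float16 data[4];
    for (int i = 0; i < 4; ++i)
        data[i] = static_cast<_Float16>(init[i]);
    eltwise_tanh_f16_via_f32(data, 4);
    for (_Float16 v : data)
        std::printf("%g\n", static_cast<double>(v));
    return 0;
}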

File tree

2 files changed: +47 -20 lines


src/cpu/aarch64/acl_post_ops.cpp

Lines changed: 31 additions & 2 deletions
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2022 Arm Ltd. and affiliates
+* Copyright 2022-2023 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
 * limitations under the License.
 *******************************************************************************/
 
+#include "common/float16.hpp"
 #include "cpu/aarch64/acl_gemm_convolution.hpp"
 
 namespace dnnl {
@@ -59,7 +60,35 @@ status_t acl_post_ops_t::execute(const exec_ctx_t &ctx, void *src_orig) const {
                     = dynamic_cast<acl_eltwise_fwd_t *>(post_op.get());
             if (eltwise_post_op == nullptr) return status::runtime_error;
 
-            CHECK(eltwise_post_op->execute_forward(ctx, src, src));
+            if (dst_data_type == data_type::f16) {
+                // in this case we want to cast the src tensor up to fp32
+                arm_compute::TensorInfo src_info
+                        = eltwise_post_op->pd()->aep.data_info;
+                // new src tensor with fp32 datatype
+                arm_compute::Tensor src_tensor;
+                src_tensor.allocator()->init(src_info);
+                src_tensor.allocator()->allocate();
+                float *src_f32 = (float *)src_tensor.buffer();
+                // total_size gives the size in bytes, we divide by 4 because the src_tensor is fp32
+                size_t num_elements = src_tensor.info()->total_size() / 4;
+                // cast src up to fp32 and store the result in src_f32
+                cvt_float16_to_float(
+                        src_f32, (dnnl::impl::float16_t *)src, num_elements);
+                // perform the operation in fp32
+                status_t eltwise_status = eltwise_post_op->execute_forward(
+                        ctx, src_f32, src_f32);
+                if (eltwise_status == status::success) {
+                    // cast src_f32 down and store final result in src
+                    cvt_float_to_float16((dnnl::impl::float16_t *)src, src_f32,
+                            num_elements);
+                }
+                src_tensor.allocator()->free();
+                CHECK(eltwise_status);
+
+            } else {
+                CHECK(eltwise_post_op->execute_forward(ctx, src, src));
+            }
+
         } else {
             return status::runtime_error;
         }

src/cpu/aarch64/acl_post_ops.hpp

Lines changed: 16 additions & 18 deletions
@@ -34,13 +34,8 @@ struct acl_post_ops_t {
     status_t init(engine_t *engine, post_ops_t &post_ops,
             const memory_desc_t &dst_md) {
 
-        // Disable ACL post ops when in f16 mode. This is because the oneDNN reference runs
-        // the post op in f32 and then casts down to f16 while ACL runs the post op in f16
-        // leading to a loss of accuracy compared to ref.
-        ACL_CHECK_SUPPORT(
-                post_ops.len() >= 1 && dst_md.data_type == data_type::f16,
-                "post ops cannot be executed in fp16");
         CHECK(post_ops.set_default_formats(&dst_md));
+        dst_data_type = dst_md.data_type;
 
         // Reset properties derived from post_ops
         sum_index = -1;
@@ -105,8 +100,15 @@ struct acl_post_ops_t {
                 eltwise_desc.alg_kind = po.eltwise.alg;
                 eltwise_desc.alpha = po.eltwise.alpha;
                 eltwise_desc.beta = po.eltwise.beta;
-                eltwise_desc.src_desc = dst_md;
-                eltwise_desc.dst_desc = dst_md;
+                memory_desc_t temp_dst = dst_md;
+                // pass eltwise a desc with f32 datatype to perform the operation in fp32 rather than fp16
+                // since oneDNN requires all post-ops to run in fp32.
+                // we don't need to do that to the other post-ops as executing them in fp16 yields the same result.
+                if (dst_data_type == data_type::f16) {
+                    temp_dst.data_type = data_type::f32;
+                }
+                eltwise_desc.src_desc = temp_dst;
+                eltwise_desc.dst_desc = temp_dst;
                 eltwise_desc.prop_kind = prop_kind_t::dnnl_forward;
                 auto empty_attr = dnnl_primitive_attr();
                 typename acl_eltwise_fwd_t::pd_t acl_eltwise_pd(
@@ -135,16 +137,12 @@ struct acl_post_ops_t {
             const memory_desc_t &dst_md,
             arm_compute::ActivationLayerInfo &act_info_to_fuse) {
 
-        // Disable ACL post ops when in f16 mode. This is because the oneDNN reference runs
-        // the post op in f32 and then casts down to f16 while ACL runs the post op in f16
-        // leading to a loss of accuracy compared to ref.
-        ACL_CHECK_SUPPORT(
-                base_post_ops.len() >= 1 && dst_md.data_type == data_type::f16,
-                "post ops cannot be executed in fp16");
         CHECK(base_post_ops.set_default_formats(&dst_md));
-
-        // If the first entry is eltwise, we fuse it
-        if (base_post_ops.len() >= 1 && base_post_ops.entry_[0].is_eltwise()) {
+        dst_data_type = dst_md.data_type;
+        // If the first entry is eltwise, we fuse it, except when the datatype
+        // is fp16 because in this case we want to execute the eltwise in fp32.
+        if (base_post_ops.len() >= 1 && base_post_ops.entry_[0].is_eltwise()
+                && dst_data_type != data_type::f16) {
 
             const auto &first_po = base_post_ops.entry_[0].eltwise;
             ACL_CHECK_SUPPORT(first_po.scale != 1.0f,
@@ -181,7 +179,7 @@ struct acl_post_ops_t {
 private:
     // Index of the sum post op if there is one, < 0 means no sum
     int sum_index = -1;
-
+    data_type_t dst_data_type;
     // Vector of primitives used to execute the post ops. They are constructed
     // in init to be either acl_binary_t (for sum, add, sub, div, mul, min and
     // max) or acl_eltwise_fwd_t (for relu, elu, tanh, square, abs etc)
