@@ -34,13 +34,8 @@ struct acl_post_ops_t {
3434 status_t init (engine_t *engine, post_ops_t &post_ops,
3535 const memory_desc_t &dst_md) {
3636
37- // Disable ACL post ops when in f16 mode. This is because the oneDNN reference runs
38- // the post op in f32 and then casts down to f16 while ACL runs the post op in f16
39- // leading to a loss of accuracy compared to ref.
40- ACL_CHECK_SUPPORT (
41- post_ops.len () >= 1 && dst_md.data_type == data_type::f16 ,
42- " post ops cannot be executed in fp16" );
4337 CHECK (post_ops.set_default_formats (&dst_md));
38+ dst_data_type = dst_md.data_type ;
4439
4540 // Reset properties derived from post_ops
4641 sum_index = -1 ;
@@ -105,8 +100,15 @@ struct acl_post_ops_t {
105100 eltwise_desc.alg_kind = po.eltwise .alg ;
106101 eltwise_desc.alpha = po.eltwise .alpha ;
107102 eltwise_desc.beta = po.eltwise .beta ;
108- eltwise_desc.src_desc = dst_md;
109- eltwise_desc.dst_desc = dst_md;
103+ memory_desc_t temp_dst = dst_md;
104+ // Pass eltwise a descriptor with the f32 data type so that the operation runs in fp32 rather than fp16,
105+ // since oneDNN requires all post-ops to run in fp32.
106+ // We don't need to do this for the other post-ops, as executing them in fp16 yields the same result.
107+ if (dst_data_type == data_type::f16 ) {
108+ temp_dst.data_type = data_type::f32 ;
109+ }
110+ eltwise_desc.src_desc = temp_dst;
111+ eltwise_desc.dst_desc = temp_dst;
110112 eltwise_desc.prop_kind = prop_kind_t ::dnnl_forward;
111113 auto empty_attr = dnnl_primitive_attr ();
112114 typename acl_eltwise_fwd_t ::pd_t acl_eltwise_pd (
@@ -135,16 +137,12 @@ struct acl_post_ops_t {
135137 const memory_desc_t &dst_md,
136138 arm_compute::ActivationLayerInfo &act_info_to_fuse) {
137139
138- // Disable ACL post ops when in f16 mode. This is because the oneDNN reference runs
139- // the post op in f32 and then casts down to f16 while ACL runs the post op in f16
140- // leading to a loss of accuracy compared to ref.
141- ACL_CHECK_SUPPORT (
142- base_post_ops.len () >= 1 && dst_md.data_type == data_type::f16 ,
143- " post ops cannot be executed in fp16" );
144140 CHECK (base_post_ops.set_default_formats (&dst_md));
145-
146- // If the first entry is eltwise, we fuse it
147- if (base_post_ops.len () >= 1 && base_post_ops.entry_ [0 ].is_eltwise ()) {
141+ dst_data_type = dst_md.data_type ;
142+ // If the first entry is eltwise, we fuse it, except when the datatype
143+ // is fp16 because in this case we want to execute the eltwise in fp32.
144+ if (base_post_ops.len () >= 1 && base_post_ops.entry_ [0 ].is_eltwise ()
145+ && dst_data_type != data_type::f16 ) {
148146
149147 const auto &first_po = base_post_ops.entry_ [0 ].eltwise ;
150148 ACL_CHECK_SUPPORT (first_po.scale != 1 .0f ,
@@ -181,7 +179,7 @@ struct acl_post_ops_t {
181179private:
182180 // Index of the sum post op if there is one, < 0 means no sum
183181 int sum_index = -1 ;
184-
182+ data_type_t dst_data_type;
185183 // Vector of primitives used to execute the post ops. They are constructed
186184 // in init to be either acl_binary_t (for sum, add, sub, div, mul, min and
187185 // max) or acl_eltwise_fwd_t (for relu, elu, tanh, square, abs etc)
0 commit comments