11/*******************************************************************************
2- * Copyright 2019-2022 Intel Corporation
2+ * Copyright 2019-2023 Intel Corporation
33*
44* Licensed under the Apache License, Version 2.0 (the "License");
55* you may not use this file except in compliance with the License.
@@ -6582,7 +6582,7 @@ void gemm_kernel_generator_t<hw>::outerProductSystolic(int h, int ha, int hb,
65826582// Decide whether to use the legacy post-op injector inside C update.
65836583// The legacy injector is needed when C cannot be converted to f32 in-place, but it does not support binary post-ops.
65846584static inline bool useEltwiseInjector(const GEMMProblem &problem) {
6585- return problem.hasPostOp () && (problem.Tc.size() < 4);
6585+ return problem.hasNonSum1PostOp () && (problem.Tc.size() < 4);
65866586}
65876587
65886588// Perform C update operation on C_acc, given original C data in C_load.
@@ -14546,6 +14546,12 @@ bool gemm_kernel_generator_t<hw>::gemmBodyInternal(
1454614546 subproblem.beta_real = 1;
1454714547 subproblem.beta_imag = 0;
1454814548
14549+ if (subproblem.postOps.len() > 0) {
14550+ auto &lastPO = subproblem.postOps
14551+ .entry_[subproblem.postOps.len() - 1];
14552+ if (lastPO.kind == primitive_kind::sum) lastPO.sum.scale = 1.0f;
14553+ }
14554+
1454914555 if (!gemmUpdateC(subproblem, strategy, substate)) return false;
1455014556
1455114557 if (checkBeta0) {
@@ -14566,6 +14572,14 @@ bool gemm_kernel_generator_t<hw>::gemmBodyInternal(
1456614572 subproblem.beta_real = 0;
1456714573 subproblem.beta_imag = 0;
1456814574
14575+ if (subproblem.postOps.len() > 0) {
14576+ auto &lastPO = subproblem.postOps
14577+ .entry_[subproblem.postOps.len() - 1];
14578+ if (lastPO.kind == primitive_kind::sum)
14579+ subproblem.postOps.entry_.resize(
14580+ subproblem.postOps.len() - 1);
14581+ }
14582+
1456914583 substrategy.C.atomic = false;
1457014584
1457114585 if (!gemmUpdateC(subproblem, substrategy, substate)) return false;
0 commit comments