Skip to content

Commit a1e6bc5

Browse files
committed
gpu: jit: gemm: remove unnecessary type conversions with sum post-ops
1 parent dbb7c28 commit a1e6bc5

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2019-2022 Intel Corporation
2+
* Copyright 2019-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -6582,7 +6582,7 @@ void gemm_kernel_generator_t<hw>::outerProductSystolic(int h, int ha, int hb,
65826582
// Decide whether to use the legacy post-op injector inside C update.
65836583
// Needed if we can't convert C to f32 in-place; note that the legacy injector doesn't support binary post-ops.
65846584
static inline bool useEltwiseInjector(const GEMMProblem &problem) {
6585-
return problem.hasPostOp() && (problem.Tc.size() < 4);
6585+
return problem.hasNonSum1PostOp() && (problem.Tc.size() < 4);
65866586
}
65876587

65886588
// Perform C update operation on C_acc, given original C data in C_load.
@@ -14546,6 +14546,12 @@ bool gemm_kernel_generator_t<hw>::gemmBodyInternal(
1454614546
subproblem.beta_real = 1;
1454714547
subproblem.beta_imag = 0;
1454814548

14549+
if (subproblem.postOps.len() > 0) {
14550+
auto &lastPO = subproblem.postOps
14551+
.entry_[subproblem.postOps.len() - 1];
14552+
if (lastPO.kind == primitive_kind::sum) lastPO.sum.scale = 1.0f;
14553+
}
14554+
1454914555
if (!gemmUpdateC(subproblem, strategy, substate)) return false;
1455014556

1455114557
if (checkBeta0) {
@@ -14566,6 +14572,14 @@ bool gemm_kernel_generator_t<hw>::gemmBodyInternal(
1456614572
subproblem.beta_real = 0;
1456714573
subproblem.beta_imag = 0;
1456814574

14575+
if (subproblem.postOps.len() > 0) {
14576+
auto &lastPO = subproblem.postOps
14577+
.entry_[subproblem.postOps.len() - 1];
14578+
if (lastPO.kind == primitive_kind::sum)
14579+
subproblem.postOps.entry_.resize(
14580+
subproblem.postOps.len() - 1);
14581+
}
14582+
1456914583
substrategy.C.atomic = false;
1457014584

1457114585
if (!gemmUpdateC(subproblem, substrategy, substate)) return false;

src/gpu/jit/gemm/gen_gemm_kernel_generator.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2019-2022 Intel Corporation
2+
* Copyright 2019-2023 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -817,6 +817,11 @@ struct GEMMProblem : public CommonProblem {
817817
std::vector<bool> binaryBatch;
818818

819819
bool hasPostOp() const { return postOps.len() > 0; }
820+
// Returns true if at least one entry in postOps.entry_ fails e.is_sum();
// returns false when the list is empty or every entry passes that check.
// NOTE(review): the name suggests is_sum() with default arguments accepts
// only sum post-ops with scale 1 -- confirm against the post-op entry type.
bool hasNonSum1PostOp() const {
821+
for (const auto &e : postOps.entry_)
822+
if (!e.is_sum()) return true;
823+
return false;
824+
}
820825
bool hasBinaryPostOp() const {
821826
for (int idx = 0; idx < postOps.len(); idx++)
822827
if (postOps.entry_[idx].is_binary()) return true;
@@ -840,7 +845,7 @@ struct GEMMProblem : public CommonProblem {
840845
if (!(alpha1() || alphaM1())) return true;
841846
if (!(beta0() || beta1())) return true;
842847
if (beta1() && !Tc_ext.isSubsetOf(Tc)) return true;
843-
if (hasPostOp()) return true;
848+
if (hasNonSum1PostOp()) return true;
844849
return false;
845850
}
846851

0 commit comments

Comments
 (0)