Skip to content

Commit 6f5621a

Browse files
committed
cpu: gemm: remove a templated function
1 parent 406a079 commit 6f5621a

File tree

3 files changed

+24
-23
lines changed

3 files changed

+24
-23
lines changed

src/common/gemm.cpp

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2021-2023 Intel Corporation
2+
* Copyright 2021-2024 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -116,10 +116,9 @@ dnnl_status_t dnnl_gemm_u8s8s32(char transa, char transb, char offsetc, dim_t M,
116116
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
117117
status_t status = dnnl_success;
118118
MAYBE_VERBOSE(status, "u8", "s8", "s32",
119-
MAYBE_RUN_STACK_CHECKER(dnnl_gemm_u8s8s32,
120-
cpu::gemm_s8x8s32<uint8_t>, &transb, &transa,
121-
c2f_offsetC(&offsetc), &N, &M, &K, &alpha, B, &ldb, &bo, A,
122-
&lda, &ao, &beta, C, &ldc, co));
119+
MAYBE_RUN_STACK_CHECKER(dnnl_gemm_u8s8s32, cpu::gemm_s8x8s32,
120+
&transb, &transa, c2f_offsetC(&offsetc), &N, &M, &K, &alpha,
121+
B, &ldb, &bo, A, &lda, &ao, &beta, C, &ldc, co));
123122
return status;
124123
#else
125124
return dnnl::impl::status::unimplemented;
@@ -133,10 +132,9 @@ dnnl_status_t dnnl_gemm_s8s8s32(char transa, char transb, char offsetc, dim_t M,
133132
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
134133
status_t status = dnnl_success;
135134
MAYBE_VERBOSE(status, "s8", "s8", "s32",
136-
MAYBE_RUN_STACK_CHECKER(dnnl_gemm_s8s8s32,
137-
cpu::gemm_s8x8s32<int8_t>, &transb, &transa,
138-
c2f_offsetC(&offsetc), &N, &M, &K, &alpha, B, &ldb, &bo, A,
139-
&lda, &ao, &beta, C, &ldc, co));
135+
MAYBE_RUN_STACK_CHECKER(dnnl_gemm_s8s8s32, cpu::gemm_s8x8s32,
136+
&transb, &transa, c2f_offsetC(&offsetc), &N, &M, &K, &alpha,
137+
B, &ldb, &bo, A, &lda, &ao, &beta, C, &ldc, co));
140138
return status;
141139
#else
142140
return dnnl::impl::status::unimplemented;
@@ -184,9 +182,9 @@ dnnl_status_t dnnl_threadpool_interop_gemm_u8s8s32(char transa, char transb,
184182
status_t status = dnnl_success;
185183
MAYBE_VERBOSE(status, "u8", "s8", "s32",
186184
MAYBE_RUN_STACK_CHECKER(dnnl_threadpool_interop_gemm_u8s8s32,
187-
cpu::gemm_s8x8s32<uint8_t>, &transb, &transa,
188-
c2f_offsetC(&offsetc), &N, &M, &K, &alpha, B, &ldb, &bo, A,
189-
&lda, &ao, &beta, C, &ldc, co));
185+
cpu::gemm_s8x8s32, &transb, &transa, c2f_offsetC(&offsetc),
186+
&N, &M, &K, &alpha, B, &ldb, &bo, A, &lda, &ao, &beta, C,
187+
&ldc, co));
190188
threadpool_utils::deactivate_threadpool();
191189
return status;
192190
}
@@ -200,9 +198,9 @@ dnnl_status_t dnnl_threadpool_interop_gemm_s8s8s32(char transa, char transb,
200198
status_t status = dnnl_success;
201199
MAYBE_VERBOSE(status, "s8", "s8", "s32",
202200
MAYBE_RUN_STACK_CHECKER(dnnl_threadpool_interop_gemm_s8s8s32,
203-
cpu::gemm_s8x8s32<int8_t>, &transb, &transa,
204-
c2f_offsetC(&offsetc), &N, &M, &K, &alpha, B, &ldb, &bo, A,
205-
&lda, &ao, &beta, C, &ldc, co));
201+
cpu::gemm_s8x8s32, &transb, &transa, c2f_offsetC(&offsetc),
202+
&N, &M, &K, &alpha, B, &ldb, &bo, A, &lda, &ao, &beta, C,
203+
&ldc, co));
206204
threadpool_utils::deactivate_threadpool();
207205
return status;
208206
}

src/cpu/gemm/gemm.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2018-2023 Intel Corporation
2+
* Copyright 2018-2024 Intel Corporation
33
* Copyright 2022 IBM Corporation
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -185,7 +185,6 @@ dnnl_status_t try_cblas_gemm_s8u8s32(const char *transa, const char *transb,
185185
#endif
186186
}
187187

188-
template <>
189188
dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
190189
const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
191190
const float *alpha, const int8_t *A, const dim_t *LDA, const int8_t *ao,
@@ -227,7 +226,6 @@ dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
227226
B, LDB, bo, beta, C, LDC, co);
228227
}
229228

230-
template <>
231229
dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
232230
const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
233231
const float *alpha, const int8_t *A, const dim_t *LDA, const int8_t *ao,

src/cpu/gemm/gemm.hpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2018-2023 Intel Corporation
2+
* Copyright 2018-2024 Intel Corporation
33
* Copyright 2022 Arm Ltd. and affiliates
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -57,11 +57,16 @@ dnnl_status_t extended_sgemm(const char *transa, const char *transb,
5757
const float *beta, float *C, const dim_t *ldc,
5858
const float *bias = nullptr, bool force_jit_gemm = false);
5959

60-
template <typename b_dt>
6160
dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
62-
const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
63-
const float *alpha, const int8_t *A, const dim_t *lda, const int8_t *ao,
64-
const b_dt *B, const dim_t *ldb, const b_dt *bo, const float *beta,
61+
const char *offsetc, const dim_t *m, const dim_t *n, const dim_t *k,
62+
const float *alpha, const int8_t *a, const dim_t *lda, const int8_t *ao,
63+
const uint8_t *b, const dim_t *ldb, const uint8_t *bo,
64+
const float *beta, int32_t *c, const dim_t *ldc, const int32_t *co);
65+
66+
dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
67+
const char *offsetc, const dim_t *m, const dim_t *n, const dim_t *k,
68+
const float *alpha, const int8_t *a, const dim_t *lda, const int8_t *ao,
69+
const int8_t *b, const dim_t *ldb, const int8_t *bo, const float *beta,
6570
int32_t *c, const dim_t *ldc, const int32_t *co);
6671

6772
dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb,

0 commit comments

Comments
 (0)