|
| 1 | +/////////////////////////////////////////////////////////////////////// |
| 2 | +// File: intsimdmatrixrvv.cpp |
| 3 | +// Description: matrix-vector product for 8-bit data on rvv. |
| 4 | +// Author: sunyuechi |
| 5 | +// |
| 6 | +// Copyright (c) 2024 Institute of Software Chinese Academy of Sciences (ISCAS). |
| 7 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 8 | +// you may not use this file except in compliance with the License. |
| 9 | +// You may obtain a copy of the License at |
| 10 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | +// Unless required by applicable law or agreed to in writing, software |
| 12 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +// See the License for the specific language governing permissions and |
| 15 | +// limitations under the License. |
| 16 | +/////////////////////////////////////////////////////////////////////// |
| 17 | + |
| 18 | +#ifdef HAVE_CONFIG_H |
| 19 | +# include "config_auto.h" // for HAVE_RVV, ... |
| 20 | +#endif |
| 21 | + |
| 22 | +#if HAVE_RVV |
| 23 | +# include "intsimdmatrix.h" |
| 24 | +# include "tesstypes.h" |
| 25 | + |
| 26 | +namespace tesseract { |
| 27 | + |
| 28 | +static int DotProduct(const int8_t *u, const int8_t *v, int num) { |
| 29 | + int total = 0; |
| 30 | + |
| 31 | + asm __volatile__ ( |
| 32 | + " .option arch, +v \n\t" |
| 33 | + " vsetvli t0,zero,e32,m8,ta,ma \n\t" |
| 34 | + " vmv.v.i v0,0 \n\t" |
| 35 | + "1: \n\t" |
| 36 | + " vsetvli t0,%[num],e8,m2,ta,ma \n\t" |
| 37 | + " vle8.v v16,0(%[u]) \n\t" |
| 38 | + " vle8.v v24,0(%[v]) \n\t" |
| 39 | + " sub %[num],%[num],t0 \n\t" |
| 40 | + " vwmul.vv v8,v24,v16 \n\t" |
| 41 | + " add %[u],%[u],t0 \n\t" |
| 42 | + " add %[v],%[v],t0 \n\t" |
| 43 | + " vsetvli zero,zero,e16,m4,tu,ma \n\t" |
| 44 | + " vwadd.wv v0,v0,v8 \n\t" |
| 45 | + " bnez %[num],1b \n\t" |
| 46 | + " vsetvli t0,zero,e32,m8,ta,ma \n\t" |
| 47 | + " vmv.s.x v8,zero \n\t" |
| 48 | + " vredsum.vs v0,v0,v8 \n\t" |
| 49 | + " vmv.x.s %[total],v0 \n\t" |
| 50 | + : [u] "+r" (u), |
| 51 | + [v] "+r" (v), |
| 52 | + [num] "+r" (num), |
| 53 | + [total] "+r" (total) |
| 54 | + : |
| 55 | + : "cc", "memory" |
| 56 | + ); |
| 57 | + |
| 58 | + return total; |
| 59 | +} |
| 60 | + |
| 61 | +static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const TFloat *scales, |
| 62 | + const int8_t *u, TFloat *v) { |
| 63 | + int num_out = dim1; |
| 64 | + int num_in = dim2 - 1; |
| 65 | + for (int i = 0; i < num_out; ++i) { |
| 66 | + const int8_t *wi_start = wi + i * dim2; |
| 67 | + int total = DotProduct(wi_start, u, num_in); |
| 68 | + // Add in the bias and apply scaling. |
| 69 | + v[i] = (total + wi_start[num_in] * INT8_MAX) * scales[i]; |
| 70 | + } |
| 71 | +} |
| 72 | + |
| 73 | +const IntSimdMatrix IntSimdMatrix::intSimdMatrixRVV = { |
| 74 | + // Function. |
| 75 | + matrixDotVector, |
| 76 | + // Number of 32 bit outputs held in each register. |
| 77 | + 1, |
| 78 | + // Maximum number of registers that we will use to hold outputs. |
| 79 | + 1, |
| 80 | + // Number of 8 bit inputs in the inputs register. |
| 81 | + 1, |
| 82 | + // Number of inputs in each weight group. |
| 83 | + 1 |
| 84 | +}; |
| 85 | + |
| 86 | +} // namespace tesseract. |
| 87 | + |
| 88 | +#endif /* HAVE_RVV */ |
0 commit comments