Skip to content

Commit 8d309ae

Browse files
committed
add MatMul Bnb4 support
cleanup clean up
1 parent 2a17d5c commit 8d309ae

24 files changed

Lines changed: 2145 additions & 0 deletions

cmake/onnxruntime_rocm_hipify.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ set(contrib_ops_excluded_files
5454
"quantization/attention_quantization_impl.cuh"
5555
"quantization/dequantize_blockwise.cuh"
5656
"quantization/dequantize_blockwise.cu"
57+
"quantization/dequantize_blockwise_bnb4.cuh"
58+
"quantization/dequantize_blockwise_bnb4.cu"
59+
"quantization/matmul_bnb4.cc"
60+
"quantization/matmul_bnb4.cuh"
61+
"quantization/matmul_bnb4.cu"
5762
"quantization/matmul_nbits.cc"
5863
"quantization/matmul_nbits.cuh"
5964
"quantization/matmul_nbits.cu"

docs/ContribOperators.md

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ Do not modify directly.*
4747
* <a href="#com.microsoft.Inverse">com.microsoft.Inverse</a>
4848
* <a href="#com.microsoft.Irfft">com.microsoft.Irfft</a>
4949
* <a href="#com.microsoft.LongformerAttention">com.microsoft.LongformerAttention</a>
50+
* <a href="#com.microsoft.MatMulNBits">com.microsoft.MatMulBnb4</a>
5051
* <a href="#com.microsoft.MatMulFpQ4">com.microsoft.MatMulFpQ4</a>
5152
* <a href="#com.microsoft.MatMulInteger16">com.microsoft.MatMulInteger16</a>
5253
* <a href="#com.microsoft.MatMulIntegerToFloat">com.microsoft.MatMulIntegerToFloat</a>
@@ -2504,6 +2505,62 @@ This version of the operator has been available since version 1 of the 'com.micr
25042505
</dl>
25052506

25062507

2508+
### <a name="com.microsoft.MatMulBnb4"></a><a name="com.microsoft.matmulbnb4">**com.microsoft.MatMulBnb4**</a>
2509+
2510+
MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences:
2511+
1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'.
2512+
2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'.
2513+
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
2514+
3. Input B's quantization constants or scales are specified by input 'absmax'.
2515+
2516+
Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
2517+
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
2518+
2519+
#### Version
2520+
2521+
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
2522+
2523+
#### Attributes
2524+
2525+
<dl>
2526+
<dt><tt>K</tt> : int (required)</dt>
2527+
<dd>size of each input feature</dd>
2528+
<dt><tt>N</tt> : int (required)</dt>
2529+
<dd>size of each output feature</dd>
2530+
<dt><tt>block_size</tt> : int (required)</dt>
2531+
<dd>number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.</dd>
2532+
<dt><tt>quant_type</tt> : int (required)</dt>
2533+
<dd>Quantization data type. 0 for FP4, 1 for NF4.</dd>
2534+
</dl>
2535+
2536+
#### Inputs
2537+
2538+
<dl>
2539+
<dt><tt>A</tt> : T1</dt>
2540+
<dd>The input tensor, not quantized</dd>
2541+
<dt><tt>B</tt> : T2</dt>
2542+
<dd>1-dimensional quantized data for weight</dd>
2543+
<dt><tt>absmax</tt> : T1</dt>
2544+
<dd>Quantization constants</dd>
2545+
</dl>
2546+
2547+
#### Outputs
2548+
2549+
<dl>
2550+
<dt><tt>Y</tt> : T1</dt>
2551+
<dd>tensor. The output tensor has the same rank as the input. </dd>
2552+
</dl>
2553+
2554+
#### Type Constraints
2555+
2556+
<dl>
2557+
<dt><tt>T1</tt> : tensor(float), tensor(float16)</dt>
2558+
<dd>Constrain input and output types to float/half_float tensors.</dd>
2559+
<dt><tt>T2</tt> : tensor(uint8)</dt>
2560+
<dd>Constrain quantized weight types to uint8.</dd>
2561+
</dl>
2562+
2563+
25072564
### <a name="com.microsoft.MatMulFpQ4"></a><a name="com.microsoft.matmulfpq4">**com.microsoft.MatMulFpQ4**</a>
25082565

25092566
Matrix product with right hand matrix being pre-packed and quantized int4 data blob.

docs/OperatorKernels.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,7 @@ Do not modify directly.*
454454
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
455455
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
456456
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
457+
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
457458
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
458459
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
459460
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
@@ -849,6 +850,7 @@ Do not modify directly.*
849850
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
850851
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
851852
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
853+
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
852854
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
853855
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
854856
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|

onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe
3030
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
3131
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul);
3232
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits);
33+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4);
3334
#ifndef ORT_MINIMAL_BUILD
3435
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4);
3536
#endif
@@ -270,6 +271,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
270271
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul)>, // backward compatibility
271272
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul)>,
272273
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
274+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4)>,
273275
#ifndef ORT_MINIMAL_BUILD
274276
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4)>,
275277
#endif
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include <cstdint>
7+
#include <algorithm>
8+
#include <cmath>
9+
10+
namespace onnxruntime {
11+
namespace contrib {
12+
13+
#if defined(_MSC_VER)
14+
#define FORCEINLINE __forceinline
15+
#else
16+
#define FORCEINLINE __attribute__((always_inline)) inline
17+
#endif
18+
19+
typedef enum Bnb_DataType_t {
20+
FP4 = 0,
21+
NF4 = 1,
22+
} Bnb_DataType_t;
23+
24+
FORCEINLINE uint8_t QuantizeOneFP4(float x) {
25+
// FP4 with bias of 3
26+
// first bit is a sign
27+
// subnormals
28+
// 0b000 = 0
29+
// 0b001 = 0.0625
30+
// 0b110 = 2
31+
// 0b111 = 3
32+
// 0b100 = 4
33+
// 0b101 = 6
34+
// 0b010 = 8
35+
// 0b011 = 12
36+
37+
// we do a binary search
38+
// the pivots are divided by 12 (the FP4 absmax)
39+
// since we assum input data is in [-1.0, 1.0]
40+
41+
// !be careful here, its easy to make a mistake
42+
// that is difficult to noice if you add an extra
43+
// zero somewhere!
44+
45+
int sign = x < 0 ? 0b1000 : 0b0000;
46+
x = fabsf(x);
47+
if (x > 0.29166667f)
48+
if (x > 0.583333f)
49+
if (x > 0.8333333f)
50+
return 0b0011 + sign;
51+
else
52+
return 0b0010 + sign;
53+
else if (x > 0.4166667f)
54+
return 0b101 + sign;
55+
else
56+
return 0b100 + sign;
57+
else if (x > 0.0859375f)
58+
if (x > 0.20833333f)
59+
return 0b0111 + sign;
60+
else
61+
return 0b0110 + sign;
62+
else if (x > 0.00260417f)
63+
return 0b0001 + sign;
64+
else
65+
return 0b0000 + sign;
66+
}
67+
68+
FORCEINLINE uint8_t QuantizeOneNF4(float x) {
69+
if (x > 0.03979014977812767f)
70+
if (x > 0.3893125355243683f) // 1
71+
if (x > 0.6427869200706482f) // 11
72+
if (x > 0.8614784181118011f) // 111
73+
return 0b1111;
74+
else
75+
return 0b1110;
76+
else if (x > 0.5016634166240692f) // 110
77+
return 0b1101;
78+
else
79+
return 0b1100;
80+
else if (x > 0.2035212516784668f) // 10
81+
if (x > 0.2920137718319893f) // 101
82+
return 0b1011;
83+
else
84+
return 0b1010;
85+
else if (x > 0.1202552504837513f) // 100
86+
return 0b1001;
87+
else
88+
return 0b1000;
89+
else if (x > -0.33967943489551544f) // 0
90+
if (x > -0.13791173323988914f) // 01
91+
if (x > -0.045525018125772476f) // 011
92+
return 0b0111;
93+
else
94+
return 0b0110;
95+
else if (x > -0.23460740596055984f) // 010
96+
return 0b0101;
97+
else
98+
return 0b0100;
99+
else if (x > -0.6106329262256622f) // 00
100+
if (x > -0.4599952697753906f) // 001
101+
return 0b0011;
102+
else
103+
return 0b0010;
104+
else if (x > -0.8480964004993439f) // 000
105+
return 0b0001;
106+
else
107+
return 0b0000;
108+
}
109+
110+
template <int32_t DATA_TYPE>
111+
FORCEINLINE uint8_t QuantizeOneBnb4(float x) {
112+
if constexpr (DATA_TYPE == FP4)
113+
return QuantizeOneFP4(x);
114+
else
115+
return QuantizeOneNF4(x);
116+
}
117+
118+
template <typename T, int32_t block_size, int32_t DATA_TYPE>
119+
FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) {
120+
float local_absmax = 0.0f;
121+
122+
int32_t block_len = std::min(block_size, numel - block_idx * block_size);
123+
int32_t src_offset = block_idx * block_size;
124+
int32_t dst_offset = block_idx * block_size / 2;
125+
126+
for (int32_t idx = 0; idx < block_len; idx++) {
127+
const float v = static_cast<float>(src[src_offset + idx]);
128+
local_absmax = fmaxf(local_absmax, fabsf(v));
129+
}
130+
131+
absmax_block = static_cast<T>(local_absmax);
132+
const float reciprocal_absmax = local_absmax ? 1.0f / local_absmax : 0.0f;
133+
134+
for (int32_t idx = 0; idx < block_len; idx += 2) {
135+
const float v0 = static_cast<float>(src[src_offset + idx]) * reciprocal_absmax;
136+
const uint8_t vi0 = QuantizeOneBnb4<DATA_TYPE>(v0);
137+
138+
const float v1 = (idx + 1 < block_len) ? static_cast<float>(src[src_offset + idx + 1]) * reciprocal_absmax : 0;
139+
const uint8_t vi1 = QuantizeOneBnb4<DATA_TYPE>(v1);
140+
141+
dst[dst_offset + idx / 2] = (vi0 << 4) | vi1;
142+
}
143+
}
144+
145+
static float fp4_qaunt_map[16] = {
146+
0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f,
147+
0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f,
148+
-0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f,
149+
-0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f};
150+
151+
static float nf4_qaunt_map[16] = {
152+
-1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
153+
-0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
154+
0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f, 0.33791524171829224f,
155+
0.44070982933044434f, 0.5626170039176941f, 0.7229568362236023f, 1.0f};
156+
157+
template <typename T, int32_t DATA_TYPE>
158+
FORCEINLINE T DequantizeOneBnb4(uint8_t x) {
159+
if constexpr (DATA_TYPE == FP4)
160+
return static_cast<T>(fp4_qaunt_map[x]);
161+
else
162+
return static_cast<T>(nf4_qaunt_map[x]);
163+
}
164+
165+
template <typename T, int32_t block_size, int32_t DATA_TYPE>
166+
FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) {
167+
int32_t block_len = std::min(block_size, numel - block_idx * block_size);
168+
int32_t src_offset = block_idx * block_size / 2;
169+
int32_t dst_offset = block_idx * block_size;
170+
171+
for (int32_t idx = 0; idx < block_len; idx += 2) {
172+
const uint8_t val = src[src_offset + idx / 2];
173+
174+
dst[dst_offset + idx] = DequantizeOneBnb4<T, DATA_TYPE>(val >> 4) * absmax_block;
175+
if (idx + 1 < block_len)
176+
dst[dst_offset + idx + 1] = DequantizeOneBnb4<T, DATA_TYPE>(val & 0xF) * absmax_block;
177+
}
178+
}
179+
180+
} // namespace contrib
181+
} // namespace onnxruntime

0 commit comments

Comments
 (0)