Commit 94b1bfd

jambaykkleiti
authored and committed
Add MatMul FP4 and NF4 Support (microsoft#18066)
### Description

Add a contrib op MatMulBnb4 (FP4 and NF4) and related toolchain to support quantization on weight.

This PR adds:
- schema for contrib op MatMulBnb4 which can support FP4 (4-bit floating point) and NF4 (4-bit NormalFloat) quantization on weight.
- a naive implementation for MatMulBnb4 on CPU and GPU, i.e., implemented like MatMul(A, Dequantize(B)).
- a special implementation for GemV for MatMulBnb4 and related benchmark tool.
- tool to quantize model to FP4 or NF4.
1 parent 236a995 commit 94b1bfd
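
The "naive implementation" mentioned in the description, MatMul(A, Dequantize(B)), can be sketched roughly as below. This is an illustration only, not the kernel from this commit: the function names `DequantizeFp4` and `NaiveMatMulFp4` are hypothetical, the FP4 table is copied from the header added by this commit, and the sketch assumes the dequantized weight is laid out as [N, K] (transposed), with two codes per byte and the high nibble first.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// FP4 code values (index = 4-bit code), copied from the header in this commit.
static const float kFp4Map[16] = {
    0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f,
    0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f,
    -0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f,
    -0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f};

// Hypothetical helper: expand packed 4-bit codes into floats.
// `packed` holds two codes per byte (high nibble first);
// `absmax` holds one scale per block of `block_size` elements.
std::vector<float> DequantizeFp4(const std::vector<uint8_t>& packed,
                                 const std::vector<float>& absmax,
                                 int numel, int block_size) {
  assert(static_cast<int>(packed.size()) == (numel + 1) / 2);
  assert(static_cast<int>(absmax.size()) == (numel + block_size - 1) / block_size);
  std::vector<float> out(numel);
  for (int i = 0; i < numel; ++i) {
    const uint8_t byte = packed[i / 2];
    const uint8_t code = (i % 2 == 0) ? (byte >> 4) : (byte & 0xF);
    out[i] = kFp4Map[code] * absmax[i / block_size];
  }
  return out;
}

// Hypothetical reference: Y[M, N] = A[M, K] * Bt^T, where Bt is the
// dequantized weight, assumed to be stored row-major as [N, K].
std::vector<float> NaiveMatMulFp4(const std::vector<float>& a, int M, int K, int N,
                                  const std::vector<uint8_t>& packed_b,
                                  const std::vector<float>& absmax, int block_size) {
  const std::vector<float> bt = DequantizeFp4(packed_b, absmax, N * K, block_size);
  std::vector<float> y(static_cast<std::size_t>(M) * N, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) acc += a[m * K + k] * bt[n * K + k];
      y[m * N + n] = acc;
    }
  }
  return y;
}
```

For example, with M = 1, K = 2, N = 1, a single packed byte `0x30` (codes 3 and 0, i.e. 1.0 and 0.0) and an absmax of 2.0, the dequantized weight is [2.0, 0.0], so A = [3, 5] yields Y = [6].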

23 files changed

Lines changed: 2236 additions & 0 deletions

cmake/onnxruntime_rocm_hipify.cmake

Lines changed: 5 additions & 0 deletions
@@ -54,6 +54,11 @@ set(contrib_ops_excluded_files
   "quantization/attention_quantization_impl.cuh"
   "quantization/dequantize_blockwise.cuh"
   "quantization/dequantize_blockwise.cu"
+  "quantization/dequantize_blockwise_bnb4.cuh"
+  "quantization/dequantize_blockwise_bnb4.cu"
+  "quantization/matmul_bnb4.cc"
+  "quantization/matmul_bnb4.cuh"
+  "quantization/matmul_bnb4.cu"
   "quantization/matmul_nbits.cc"
   "quantization/matmul_nbits.cuh"
   "quantization/matmul_nbits.cu"

docs/ContribOperators.md

Lines changed: 57 additions & 0 deletions
@@ -47,6 +47,7 @@ Do not modify directly.*
   * <a href="#com.microsoft.Inverse">com.microsoft.Inverse</a>
   * <a href="#com.microsoft.Irfft">com.microsoft.Irfft</a>
   * <a href="#com.microsoft.LongformerAttention">com.microsoft.LongformerAttention</a>
+  * <a href="#com.microsoft.MatMulBnb4">com.microsoft.MatMulBnb4</a>
   * <a href="#com.microsoft.MatMulFpQ4">com.microsoft.MatMulFpQ4</a>
   * <a href="#com.microsoft.MatMulInteger16">com.microsoft.MatMulInteger16</a>
   * <a href="#com.microsoft.MatMulIntegerToFloat">com.microsoft.MatMulIntegerToFloat</a>
@@ -2504,6 +2505,62 @@ This version of the operator has been available since version 1 of the 'com.micr
 </dl>
 
 
+### <a name="com.microsoft.MatMulBnb4"></a><a name="com.microsoft.matmulbnb4">**com.microsoft.MatMulBnb4**</a>
+
+  MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences:
+    1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'.
+    2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwise with block size specified by attribute 'block_size'.
+       block_size is not an arbitrary number: it must be a power of 2 and not smaller than 16, like 16, 32, 64, 128, ...
+    3. Input B's quantization constants or scales are specified by input 'absmax'.
+
+  Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
+  Input absmax is stored in the same type as the original type of B (float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+<dl>
+<dt><tt>K</tt> : int (required)</dt>
+<dd>size of each input feature</dd>
+<dt><tt>N</tt> : int (required)</dt>
+<dd>size of each output feature</dd>
+<dt><tt>block_size</tt> : int (required)</dt>
+<dd>group size used for weight quantization. It needs to be a power of 2 and not smaller than 16.</dd>
+<dt><tt>quant_type</tt> : int (required)</dt>
+<dd>quantization data type. 0 for FP4, 1 for NF4.</dd>
+</dl>
+
+#### Inputs
+
+<dl>
+<dt><tt>A</tt> : T1</dt>
+<dd>The input tensor, not quantized</dd>
+<dt><tt>B</tt> : T2</dt>
+<dd>1-dimensional quantized data for weight</dd>
+<dt><tt>absmax</tt> : T1</dt>
+<dd>quantization constants</dd>
+</dl>
+
+#### Outputs
+
+<dl>
+<dt><tt>Y</tt> : T1</dt>
+<dd>tensor. The output tensor has the same rank as the input. </dd>
+</dl>
+
+#### Type Constraints
+
+<dl>
+<dt><tt>T1</tt> : tensor(float), tensor(float16)</dt>
+<dd>Constrain input and output types to float/half_float tensors.</dd>
+<dt><tt>T2</tt> : tensor(uint8)</dt>
+<dd>Constrain quantized weight types to uint8.</dd>
+</dl>
+
+
 ### <a name="com.microsoft.MatMulFpQ4"></a><a name="com.microsoft.matmulfpq4">**com.microsoft.MatMulFpQ4**</a>
 
   Matrix product with right hand matrix being pre-packed and quantized int4 data blob.
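
The shape rules stated in the MatMulBnb4 documentation above (packed B of length (N * K + 1) / 2, absmax of length (N * K + block_size - 1) / block_size, and block_size a power of 2 no smaller than 16) can be checked with a small helper. This is a sketch under those stated rules only; `Bnb4Shapes`, `IsValidBlockSize` and `ComputeBnb4Shapes` are hypothetical names and not part of ONNX Runtime.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical container for the two 1-D input lengths.
struct Bnb4Shapes {
  int64_t packed_b_bytes;  // length of input B (uint8, two 4-bit codes per byte)
  int64_t absmax_count;    // length of input absmax (one scale per block)
};

// block_size must be a power of 2 and not smaller than 16.
bool IsValidBlockSize(int64_t block_size) {
  return block_size >= 16 && (block_size & (block_size - 1)) == 0;
}

// Applies the ceiling divisions given in the operator documentation.
Bnb4Shapes ComputeBnb4Shapes(int64_t N, int64_t K, int64_t block_size) {
  assert(N > 0 && K > 0);
  assert(IsValidBlockSize(block_size));
  const int64_t numel = N * K;
  return {(numel + 1) / 2, (numel + block_size - 1) / block_size};
}
```

For N = 4, K = 33 and block_size = 16 there are 132 weights, so B holds 66 bytes and absmax holds 9 scales; the last block covers only the remaining 4 weights.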

docs/OperatorKernels.md

Lines changed: 2 additions & 0 deletions
@@ -457,6 +457,7 @@ Do not modify directly.*
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
 |MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
 |MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
 |MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
@@ -852,6 +853,7 @@ Do not modify directly.*
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
 |MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
 |NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|

onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4);
 #ifndef ORT_MINIMAL_BUILD
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4);
 #endif
@@ -270,6 +271,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul)>, // backward compatibility
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4)>,
 #ifndef ORT_MINIMAL_BUILD
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4)>,
 #endif
Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <cstdint>
+#include <algorithm>
+#include <cmath>
+
+namespace onnxruntime {
+namespace contrib {
+
+#if defined(_MSC_VER)
+#define FORCEINLINE __forceinline
+#else
+#define FORCEINLINE __attribute__((always_inline)) inline
+#endif
+
+typedef enum Bnb_DataType_t {
+  FP4 = 0,
+  NF4 = 1,
+} Bnb_DataType_t;
+
+FORCEINLINE uint8_t QuantizeOneFP4(float x) {
+  // FP4 with bias of 3
+  // first bit is a sign
+  // subnormals
+  // 0b000 = 0
+  // 0b001 = 0.0625
+  // 0b110 = 2
+  // 0b111 = 3
+  // 0b100 = 4
+  // 0b101 = 6
+  // 0b010 = 8
+  // 0b011 = 12
+
+  // we do a binary search
+  // the pivots are divided by 12 (the FP4 absmax)
+  // since we assume input data is in [-1.0, 1.0]
+
+  // !be careful here, it's easy to make a mistake
+  // that is difficult to notice if you add an extra
+  // zero somewhere!
+
+  uint8_t sign = x < 0 ? 0b1000 : 0b0000;
+  x = fabsf(x);
+  if (x > 0.29166667f) {
+    if (x > 0.583333f) {
+      if (x > 0.8333333f) {
+        return 0b0011 + sign;
+      } else {
+        return 0b0010 + sign;
+      }
+    } else if (x > 0.4166667f) {
+      return 0b101 + sign;
+    } else {
+      return 0b100 + sign;
+    }
+  } else if (x > 0.0859375f) {
+    if (x > 0.20833333f) {
+      return 0b0111 + sign;
+    } else {
+      return 0b0110 + sign;
+    }
+  } else if (x > 0.00260417f) {
+    return 0b0001 + sign;
+  } else {
+    return 0b0000 + sign;
+  }
+}
+
+FORCEINLINE uint8_t QuantizeOneNF4(float x) {
+  if (x > 0.03979014977812767f) {
+    if (x > 0.3893125355243683f) {      // 1
+      if (x > 0.6427869200706482f) {    // 11
+        if (x > 0.8614784181118011f) {  // 111
+          return 0b1111;
+        } else {
+          return 0b1110;
+        }
+      } else if (x > 0.5016634166240692f) {  // 110
+        return 0b1101;
+      } else {
+        return 0b1100;
+      }
+    } else if (x > 0.2035212516784668f) {  // 10
+      if (x > 0.2920137718319893f) {       // 101
+        return 0b1011;
+      } else {
+        return 0b1010;
+      }
+    } else if (x > 0.1202552504837513f) {  // 100
+      return 0b1001;
+    } else {
+      return 0b1000;
+    }
+  } else if (x > -0.33967943489551544f) {     // 0
+    if (x > -0.13791173323988914f) {          // 01
+      if (x > -0.045525018125772476f) {       // 011
+        return 0b0111;
+      } else {
+        return 0b0110;
+      }
+    } else if (x > -0.23460740596055984f) {  // 010
+      return 0b0101;
+    } else {
+      return 0b0100;
+    }
+  } else if (x > -0.6106329262256622f) {  // 00
+    if (x > -0.4599952697753906f) {       // 001
+      return 0b0011;
+    } else {
+      return 0b0010;
+    }
+  } else if (x > -0.8480964004993439f) {  // 000
+    return 0b0001;
+  } else {
+    return 0b0000;
+  }
+}
+
+template <int32_t DATA_TYPE>
+FORCEINLINE uint8_t QuantizeOneBnb4(float x) {
+  if constexpr (DATA_TYPE == FP4)
+    return QuantizeOneFP4(x);
+  else
+    return QuantizeOneNF4(x);
+}
+
+template <typename T, int32_t block_size, int32_t DATA_TYPE>
+FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) {
+  float local_absmax = 0.0f;
+
+  int32_t block_len = std::min(block_size, numel - block_idx * block_size);
+  int32_t src_offset = block_idx * block_size;
+  int32_t dst_offset = block_idx * block_size / 2;
+
+  for (int32_t idx = 0; idx < block_len; idx++) {
+    const float v = static_cast<float>(src[src_offset + idx]);
+    local_absmax = fmaxf(local_absmax, fabsf(v));
+  }
+
+  absmax_block = static_cast<T>(local_absmax);
+  const float reciprocal_absmax = local_absmax ? 1.0f / local_absmax : 0.0f;
+
+  for (int32_t idx = 0; idx < block_len; idx += 2) {
+    const float v0 = static_cast<float>(src[src_offset + idx]) * reciprocal_absmax;
+    const uint8_t vi0 = QuantizeOneBnb4<DATA_TYPE>(v0);
+
+    const float v1 = (idx + 1 < block_len) ? static_cast<float>(src[src_offset + idx + 1]) * reciprocal_absmax : 0;
+    const uint8_t vi1 = QuantizeOneBnb4<DATA_TYPE>(v1);
+
+    dst[dst_offset + idx / 2] = (vi0 << 4) | vi1;
+  }
+}
+
+static float fp4_qaunt_map[16] = {0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f,
+                                  0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f,
+                                  -0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f,
+                                  -0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f};
+
+static float nf4_qaunt_map[16] = {-1.0f,
+                                  -0.6961928009986877f,
+                                  -0.5250730514526367f,
+                                  -0.39491748809814453f,
+                                  -0.28444138169288635f,
+                                  -0.18477343022823334f,
+                                  -0.09105003625154495f,
+                                  0.0f,
+                                  0.07958029955625534f,
+                                  0.16093020141124725f,
+                                  0.24611230194568634f,
+                                  0.33791524171829224f,
+                                  0.44070982933044434f,
+                                  0.5626170039176941f,
+                                  0.7229568362236023f,
+                                  1.0f};
+
+template <typename T, int32_t DATA_TYPE>
+FORCEINLINE T DequantizeOneBnb4(uint8_t x) {
+  if constexpr (DATA_TYPE == FP4)
+    return static_cast<T>(fp4_qaunt_map[x]);
+  else
+    return static_cast<T>(nf4_qaunt_map[x]);
+}
+
+template <typename T, int32_t block_size, int32_t DATA_TYPE>
+FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) {
+  int32_t block_len = std::min(block_size, numel - block_idx * block_size);
+  int32_t src_offset = block_idx * block_size / 2;
+  int32_t dst_offset = block_idx * block_size;
+
+  for (int32_t idx = 0; idx < block_len; idx += 2) {
+    const uint8_t val = src[src_offset + idx / 2];
+
+    dst[dst_offset + idx] = DequantizeOneBnb4<T, DATA_TYPE>(val >> 4) * absmax_block;
+    if (idx + 1 < block_len) dst[dst_offset + idx + 1] = DequantizeOneBnb4<T, DATA_TYPE>(val & 0xF) * absmax_block;
+  }
+}
+
+} // namespace contrib
+} // namespace onnxruntime
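
The pivots hard-coded in `QuantizeOneFP4` in the header above appear to be the midpoints between adjacent representable FP4 magnitudes, each divided by 12. The sketch below recomputes them from the values listed in that function's comment table so the constants can be cross-checked. `MidpointPivots`, `Fp4Magnitudes` and `AllClose` are hypothetical helper names introduced here for illustration; they are not part of the commit.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Decision thresholds for nearest-value quantization: the midpoints
// between adjacent entries of the sorted list of representable values.
std::vector<float> MidpointPivots(std::vector<float> values) {
  assert(values.size() >= 2);
  std::sort(values.begin(), values.end());
  std::vector<float> pivots;
  for (std::size_t i = 0; i + 1 < values.size(); ++i) {
    pivots.push_back(0.5f * (values[i] + values[i + 1]));
  }
  return pivots;
}

// Magnitudes from the FP4 comment table, scaled by 1/12 (the FP4 absmax)
// so that they lie in [0, 1].
std::vector<float> Fp4Magnitudes() {
  const float raw[] = {0.0f, 0.0625f, 2.0f, 3.0f, 4.0f, 6.0f, 8.0f, 12.0f};
  std::vector<float> scaled;
  for (float r : raw) scaled.push_back(r / 12.0f);
  return scaled;
}

// Element-wise comparison with an absolute tolerance.
bool AllClose(const std::vector<float>& a, const std::vector<float>& b, float tol) {
  if (a.size() != b.size()) return false;
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (std::fabs(a[i] - b[i]) > tol) return false;
  }
  return true;
}
```

Calling `MidpointPivots(Fp4Magnitudes())` gives seven thresholds that agree, to within rounding, with the seven constants compared against `x` in `QuantizeOneFP4`; for instance (3 + 4) / 2 / 12 = 0.29166667 is the first branch.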

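
The block-wise scheme in `QuantizeBlockBnb4` and `DequantizeBlockBnb4` above (scale by the block's absmax, map each value to a 4-bit code, pack two codes per byte with the first value in the high nibble) can be traced on a tiny block. The sketch below is a simplified stand-in, not the commit's code: it handles a single block, and it replaces the hard-coded decision tree of `QuantizeOneNF4` with a linear nearest-value search over the NF4 table, which is an assumption made for brevity. `NearestNf4Code`, `QuantizeBlock` and `DequantizeBlock` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// NF4 code values (index = 4-bit code), copied from the header in this commit.
static const float kNf4Values[16] = {
    -1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
    -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
    0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f,
    0.33791524171829224f, 0.44070982933044434f, 0.5626170039176941f,
    0.7229568362236023f, 1.0f};

// Stand-in for the decision tree: pick the code whose value is nearest to x.
uint8_t NearestNf4Code(float x) {
  uint8_t best = 0;
  for (uint8_t c = 1; c < 16; ++c) {
    if (std::fabs(x - kNf4Values[c]) < std::fabs(x - kNf4Values[best])) best = c;
  }
  return best;
}

// Quantizes one block; returns its absmax and fills `packed`
// (two codes per byte, first value in the high nibble, odd tail padded with 0).
float QuantizeBlock(const std::vector<float>& src, std::vector<uint8_t>& packed) {
  float absmax = 0.0f;
  for (float v : src) absmax = std::fmax(absmax, std::fabs(v));
  const float inv = absmax != 0.0f ? 1.0f / absmax : 0.0f;
  packed.assign((src.size() + 1) / 2, 0);
  for (std::size_t i = 0; i < src.size(); i += 2) {
    const uint8_t hi = NearestNf4Code(src[i] * inv);
    const uint8_t lo = NearestNf4Code(i + 1 < src.size() ? src[i + 1] * inv : 0.0f);
    packed[i / 2] = static_cast<uint8_t>((hi << 4) | lo);
  }
  return absmax;
}

// Inverse of QuantizeBlock for a block of n values.
std::vector<float> DequantizeBlock(const std::vector<uint8_t>& packed, float absmax,
                                   std::size_t n) {
  std::vector<float> out(n);
  for (std::size_t i = 0; i < n; ++i) {
    const uint8_t byte = packed[i / 2];
    const uint8_t code = (i % 2 == 0) ? (byte >> 4) : (byte & 0xF);
    out[i] = kNf4Values[code] * absmax;
  }
  return out;
}
```

For the block {-2, 0, 2} the absmax is 2, the normalized values -1, 0 and 1 map to codes 0, 7 and 15, and the padded tail maps to code 7, so the packed bytes are 0x07 and 0xF7. Dequantizing recovers {-2, 0, 2} exactly because all three inputs land on representable NF4 values; other inputs would round to the nearest table entry.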