Skip to content

Commit 47f88d2

Browse files
committed
c++ linter and clang-format
1 parent cca2da8 commit 47f88d2

10 files changed

Lines changed: 228 additions & 179 deletions

File tree

onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h

Lines changed: 68 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -44,67 +44,79 @@ FORCEINLINE uint8_t QuantizeOneFP4(float x) {
4444

4545
int sign = x < 0 ? 0b1000 : 0b0000;
4646
x = fabsf(x);
47-
if (x > 0.29166667f)
48-
if (x > 0.583333f)
49-
if (x > 0.8333333f)
47+
if (x > 0.29166667f) {
48+
if (x > 0.583333f) {
49+
if (x > 0.8333333f) {
5050
return 0b0011 + sign;
51-
else
51+
} else {
5252
return 0b0010 + sign;
53-
else if (x > 0.4166667f)
53+
}
54+
} else if (x > 0.4166667f) {
5455
return 0b101 + sign;
55-
else
56+
} else {
5657
return 0b100 + sign;
57-
else if (x > 0.0859375f)
58-
if (x > 0.20833333f)
58+
}
59+
} else if (x > 0.0859375f) {
60+
if (x > 0.20833333f) {
5961
return 0b0111 + sign;
60-
else
62+
} else {
6163
return 0b0110 + sign;
62-
else if (x > 0.00260417f)
64+
}
65+
} else if (x > 0.00260417f) {
6366
return 0b0001 + sign;
64-
else
67+
} else {
6568
return 0b0000 + sign;
69+
}
6670
}
6771

6872
FORCEINLINE uint8_t QuantizeOneNF4(float x) {
69-
if (x > 0.03979014977812767f)
70-
if (x > 0.3893125355243683f) // 1
71-
if (x > 0.6427869200706482f) // 11
72-
if (x > 0.8614784181118011f) // 111
73+
if (x > 0.03979014977812767f) {
74+
if (x > 0.3893125355243683f) { // 1
75+
if (x > 0.6427869200706482f) { // 11
76+
if (x > 0.8614784181118011f) { // 111
7377
return 0b1111;
74-
else
78+
} else {
7579
return 0b1110;
76-
else if (x > 0.5016634166240692f) // 110
80+
}
81+
} else if (x > 0.5016634166240692f) { // 110
7782
return 0b1101;
78-
else
83+
} else {
7984
return 0b1100;
80-
else if (x > 0.2035212516784668f) // 10
81-
if (x > 0.2920137718319893f) // 101
85+
}
86+
} else if (x > 0.2035212516784668f) { // 10
87+
if (x > 0.2920137718319893f) { // 101
8288
return 0b1011;
83-
else
89+
} else {
8490
return 0b1010;
85-
else if (x > 0.1202552504837513f) // 100
91+
}
92+
} else if (x > 0.1202552504837513f) { // 100
8693
return 0b1001;
87-
else
94+
} else {
8895
return 0b1000;
89-
else if (x > -0.33967943489551544f) // 0
90-
if (x > -0.13791173323988914f) // 01
91-
if (x > -0.045525018125772476f) // 011
96+
}
97+
} else if (x > -0.33967943489551544f) { // 0
98+
if (x > -0.13791173323988914f) { // 01
99+
if (x > -0.045525018125772476f) { // 011
92100
return 0b0111;
93-
else
101+
} else {
94102
return 0b0110;
95-
else if (x > -0.23460740596055984f) // 010
103+
}
104+
} else if (x > -0.23460740596055984f) { // 010
96105
return 0b0101;
97-
else
106+
} else {
98107
return 0b0100;
99-
else if (x > -0.6106329262256622f) // 00
100-
if (x > -0.4599952697753906f) // 001
108+
}
109+
} else if (x > -0.6106329262256622f) { // 00
110+
if (x > -0.4599952697753906f) { // 001
101111
return 0b0011;
102-
else
112+
} else {
103113
return 0b0010;
104-
else if (x > -0.8480964004993439f) // 000
114+
}
115+
} else if (x > -0.8480964004993439f) { // 000
105116
return 0b0001;
106-
else
117+
} else {
107118
return 0b0000;
119+
}
108120
}
109121

110122
template <int32_t DATA_TYPE>
@@ -142,17 +154,27 @@ FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block,
142154
}
143155
}
144156

145-
static float fp4_qaunt_map[16] = {
146-
0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f,
147-
0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f,
148-
-0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f,
149-
-0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f};
150-
151-
static float nf4_qaunt_map[16] = {
152-
-1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
153-
-0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
154-
0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f, 0.33791524171829224f,
155-
0.44070982933044434f, 0.5626170039176941f, 0.7229568362236023f, 1.0f};
157+
static float fp4_qaunt_map[16] = {0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f,
158+
0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f,
159+
-0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f,
160+
-0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f};
161+
162+
static float nf4_qaunt_map[16] = {-1.0f,
163+
-0.6961928009986877f,
164+
-0.5250730514526367f,
165+
-0.39491748809814453f,
166+
-0.28444138169288635f,
167+
-0.18477343022823334f,
168+
-0.09105003625154495f,
169+
0.0f,
170+
0.07958029955625534f,
171+
0.16093020141124725f,
172+
0.24611230194568634f,
173+
0.33791524171829224f,
174+
0.44070982933044434f,
175+
0.5626170039176941f,
176+
0.7229568362236023f,
177+
1.0f};
156178

157179
template <typename T, int32_t DATA_TYPE>
158180
FORCEINLINE T DequantizeOneBnb4(uint8_t x) {
@@ -172,8 +194,7 @@ FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block,
172194
const uint8_t val = src[src_offset + idx / 2];
173195

174196
dst[dst_offset + idx] = DequantizeOneBnb4<T, DATA_TYPE>(val >> 4) * absmax_block;
175-
if (idx + 1 < block_len)
176-
dst[dst_offset + idx + 1] = DequantizeOneBnb4<T, DATA_TYPE>(val & 0xF) * absmax_block;
197+
if (idx + 1 < block_len) dst[dst_offset + idx + 1] = DequantizeOneBnb4<T, DATA_TYPE>(val & 0xF) * absmax_block;
177198
}
178199
}
179200

onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ void QuantizeBlockwiseBnb4(
5151
int32_t N,
5252
int32_t K,
5353
onnxruntime::concurrency::ThreadPool* thread_pool) {
54-
ORT_ENFORCE(quant_type == FP4 || quant_type == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
54+
ORT_ENFORCE(
55+
quant_type == FP4 || quant_type == NF4,
56+
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
5557

5658
if (block_size == 16) {
5759
QuantizeBlockwiseBn4DataTyped(16, quant_type);
@@ -106,7 +108,9 @@ void DequantizeBlockwiseBnb4(
106108
int32_t N,
107109
int32_t K,
108110
onnxruntime::concurrency::ThreadPool* thread_pool) {
109-
ORT_ENFORCE(quant_type == FP4 || quant_type == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
111+
ORT_ENFORCE(
112+
quant_type == FP4 || quant_type == NF4,
113+
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
110114

111115
if (block_size == 16) {
112116
DequantizeBlockwiseBn4DataTyped(16, quant_type);

onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ class MatMulBnb4 final : public OpKernel {
1818
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
1919
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
2020
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("quant_type", &quant_type_));
21-
ORT_ENFORCE(quant_type_ == FP4 || quant_type_ == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
21+
ORT_ENFORCE(
22+
quant_type_ == FP4 || quant_type_ == NF4,
23+
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
2224
}
2325

2426
Status Compute(OpKernelContext* context) const override;
@@ -45,14 +47,15 @@ Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
4547
auto status = ctx->GetTempSpaceAllocator(&allocator);
4648
ORT_RETURN_IF_ERROR(status);
4749
auto tmp_b_data_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
48-
DequantizeBlockwiseBnb4<float>(tmp_b_data_ptr.get(),
49-
b_quant_data,
50-
absmax_data,
51-
static_cast<int32_t>(block_size_),
52-
static_cast<int32_t>(quant_type_),
53-
static_cast<int32_t>(N_),
54-
static_cast<int32_t>(K_),
55-
thread_pool);
50+
DequantizeBlockwiseBnb4<float>(
51+
tmp_b_data_ptr.get(),
52+
b_quant_data,
53+
absmax_data,
54+
static_cast<int32_t>(block_size_),
55+
static_cast<int32_t>(quant_type_),
56+
static_cast<int32_t>(N_),
57+
static_cast<int32_t>(K_),
58+
thread_pool);
5659

5760
constexpr bool transa = false;
5861
constexpr bool transb = true;
@@ -63,8 +66,7 @@ Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
6366
Tensor* y = ctx->Output(0, helper.OutputShape());
6467

6568
// Bail out early if the output is going to be empty
66-
if (y->Shape().Size() == 0)
67-
return Status::OK();
69+
if (y->Shape().Size() == 0) return Status::OK();
6870

6971
auto* y_data = y->MutableData<float>();
7072

@@ -88,8 +90,7 @@ Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
8890
data[i].alpha = 1.f;
8991
data[i].beta = 0.0f;
9092
}
91-
MlasGemmBatch(CblasNoTrans, CblasTrans,
92-
M, N, K, data.data(), max_len, thread_pool);
93+
MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data.data(), max_len, thread_pool);
9394

9495
return Status::OK();
9596
}

onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu

Lines changed: 58 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,22 @@ namespace onnxruntime {
1111
namespace contrib {
1212
namespace cuda {
1313

14-
template<class T>
15-
Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream)
16-
{
17-
ORT_ENFORCE(quant_type == FP4 || quant_type == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
18-
14+
template <class T>
15+
Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream) {
16+
ORT_ENFORCE(
17+
quant_type == FP4 || quant_type == NF4,
18+
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
19+
1920
T host_quant_map[16];
2021
switch (quant_type) {
2122
case FP4:
22-
for(int i = 0; i < 16; i++)
23-
host_quant_map[i] = static_cast<T>(fp4_qaunt_map[i]);
23+
for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast<T>(fp4_qaunt_map[i]);
2424
break;
2525
case NF4:
26-
for(int i = 0; i < 16; i++)
27-
host_quant_map[i] = static_cast<T>(nf4_qaunt_map[i]);
26+
for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast<T>(nf4_qaunt_map[i]);
2827
break;
2928
}
30-
CUDA_CALL_THROW(cudaMemcpyAsync(quant_map_buffer, host_quant_map, sizeof(T)*16, cudaMemcpyHostToDevice, stream));
29+
CUDA_CALL_THROW(cudaMemcpyAsync(quant_map_buffer, host_quant_map, sizeof(T) * 16, cudaMemcpyHostToDevice, stream));
3130

3231
return Status::OK();
3332
}
@@ -36,60 +35,82 @@ template Status SetBnbQuantMap<float>(int quant_type, float* quant_map_buffer, c
3635

3736
template Status SetBnbQuantMap<half>(int quant_type, half* quant_map_buffer, cudaStream_t stream);
3837

39-
40-
template<typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH>
41-
__global__ void kDequantizeBlockwise(const T *quant_map, T *output, const unsigned char *quant_data, const T *absmax, const int block_size, const int n)
42-
{
38+
template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH>
39+
__global__ void kDequantizeBlockwise(
40+
const T* quant_map,
41+
T* output,
42+
const uint8_t* quant_data,
43+
const T* absmax,
44+
const int block_size,
45+
const int n) {
4346
const int n_load = (gridDim.x * TILE_SIZE);
4447
int valid_items_load = 0;
4548
int valid_items_store = 0;
4649
const int base_idx = (blockIdx.x * TILE_SIZE);
4750

48-
T vals[NUM_PER_TH*2];
49-
unsigned char qvals[NUM_PER_TH];
51+
T vals[NUM_PER_TH * 2];
52+
uint8_t qvals[NUM_PER_TH];
5053
T local_abs_max = T(0.0f);
5154

52-
typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
53-
typedef cub::BlockStore<T, THREADS, NUM_PER_TH*2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
55+
typedef cub::BlockLoad<uint8_t, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
56+
typedef cub::BlockStore<T, THREADS, NUM_PER_TH * 2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
5457

5558
__shared__ typename LoadChar::TempStorage loadchar;
5659
__shared__ typename StoreT::TempStorage storet;
5760

58-
for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE)
59-
{
60-
valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i;
61-
valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2;
61+
for (unsigned int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) {
62+
valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i;
63+
valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2;
6264

63-
local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(block_size)]);
65+
local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]);
6466

6567
__syncthreads();
6668
LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128);
6769

6870
#pragma unroll NUM_PER_TH
69-
for(int j = 0; j < NUM_PER_TH; j++)
70-
{
71-
vals[j*2] = quant_map[qvals[j] >> 4] * local_abs_max;
72-
vals[j*2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max;
71+
for (int j = 0; j < NUM_PER_TH; j++) {
72+
vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max;
73+
vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max;
7374
}
7475

7576
__syncthreads();
76-
StoreT(storet).Store(&(output[i*2]), vals, valid_items_store);
77+
StoreT(storet).Store(&(output[i * 2]), vals, valid_items_store);
7778
}
7879
}
7980

80-
81-
template<class T>
82-
Status DequantizeBnb4(const T* quant_map, T *output, const unsigned char *quant_data, const T *absmax, int block_size, int numel, cudaStream_t stream)
83-
{
81+
template <class T>
82+
Status DequantizeBnb4(
83+
const T* quant_map,
84+
T* output,
85+
const uint8_t* quant_data,
86+
const T* absmax,
87+
int block_size,
88+
int numel,
89+
cudaStream_t stream) {
8490
int tile_size = 1024;
85-
kDequantizeBlockwise<T, 512, 64, 8><<<(numel+tile_size-1)/tile_size, 64, 0, stream>>>(quant_map, output, quant_data, absmax, block_size/2, numel);
86-
91+
kDequantizeBlockwise<T, 512, 64, 8><<<(numel + tile_size - 1) / tile_size, 64, 0, stream>>>(
92+
quant_map, output, quant_data, absmax, block_size / 2, numel);
93+
8794
return Status::OK();
8895
}
8996

90-
template Status DequantizeBnb4<float>(const float* quant_map, float *output, const unsigned char *quant_data, const float *absmax, int block_size, int numel, cudaStream_t stream);
91-
92-
template Status DequantizeBnb4<half>(const half* quant_map, half *output, const unsigned char *quant_data, const half *absmax, int block_size, int numel, cudaStream_t stream);
97+
template Status DequantizeBnb4<float>(
98+
const float* quant_map,
99+
float* output,
100+
const uint8_t* quant_data,
101+
const float* absmax,
102+
int block_size,
103+
int numel,
104+
cudaStream_t stream);
105+
106+
template Status DequantizeBnb4<half>(
107+
const half* quant_map,
108+
half* output,
109+
const uint8_t* quant_data,
110+
const half *absmax,
111+
int block_size,
112+
int numel,
113+
cudaStream_t stream);
93114

94115
} // namespace cuda
95116
} // namespace contrib

onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ template <class T>
1515
Status DequantizeBnb4(
1616
const T* quant_map,
1717
T* output,
18-
const unsigned char* quant_data,
18+
const uint8_t* quant_data,
1919
const T* absmax,
2020
int block_size,
2121
int numel,

0 commit comments

Comments
 (0)