Skip to content

Commit 45f75ec

Browse files
committed
c++ linter
spellcheck
1 parent cca2da8 commit 45f75ec

File tree

10 files changed

+177
-116
lines changed

10 files changed

+177
-116
lines changed

onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -44,67 +44,79 @@ FORCEINLINE uint8_t QuantizeOneFP4(float x) {
4444

4545
int sign = x < 0 ? 0b1000 : 0b0000;
4646
x = fabsf(x);
47-
if (x > 0.29166667f)
48-
if (x > 0.583333f)
49-
if (x > 0.8333333f)
47+
if (x > 0.29166667f) {
48+
if (x > 0.583333f) {
49+
if (x > 0.8333333f) {
5050
return 0b0011 + sign;
51-
else
51+
} else {
5252
return 0b0010 + sign;
53-
else if (x > 0.4166667f)
53+
}
54+
} else if (x > 0.4166667f) {
5455
return 0b101 + sign;
55-
else
56+
} else {
5657
return 0b100 + sign;
57-
else if (x > 0.0859375f)
58-
if (x > 0.20833333f)
58+
}
59+
} else if (x > 0.0859375f) {
60+
if (x > 0.20833333f) {
5961
return 0b0111 + sign;
60-
else
62+
} else {
6163
return 0b0110 + sign;
62-
else if (x > 0.00260417f)
64+
}
65+
} else if (x > 0.00260417f) {
6366
return 0b0001 + sign;
64-
else
67+
} else {
6568
return 0b0000 + sign;
69+
}
6670
}
6771

6872
FORCEINLINE uint8_t QuantizeOneNF4(float x) {
69-
if (x > 0.03979014977812767f)
70-
if (x > 0.3893125355243683f) // 1
71-
if (x > 0.6427869200706482f) // 11
72-
if (x > 0.8614784181118011f) // 111
73+
if (x > 0.03979014977812767f) {
74+
if (x > 0.3893125355243683f) { // 1
75+
if (x > 0.6427869200706482f) { // 11
76+
if (x > 0.8614784181118011f) { // 111
7377
return 0b1111;
74-
else
78+
} else {
7579
return 0b1110;
76-
else if (x > 0.5016634166240692f) // 110
80+
}
81+
} else if (x > 0.5016634166240692f) { // 110
7782
return 0b1101;
78-
else
83+
} else {
7984
return 0b1100;
80-
else if (x > 0.2035212516784668f) // 10
81-
if (x > 0.2920137718319893f) // 101
85+
}
86+
} else if (x > 0.2035212516784668f) { // 10
87+
if (x > 0.2920137718319893f) { // 101
8288
return 0b1011;
83-
else
89+
} else {
8490
return 0b1010;
85-
else if (x > 0.1202552504837513f) // 100
91+
}
92+
} else if (x > 0.1202552504837513f) { // 100
8693
return 0b1001;
87-
else
94+
} else {
8895
return 0b1000;
89-
else if (x > -0.33967943489551544f) // 0
90-
if (x > -0.13791173323988914f) // 01
91-
if (x > -0.045525018125772476f) // 011
96+
}
97+
} else if (x > -0.33967943489551544f) { // 0
98+
if (x > -0.13791173323988914f) { // 01
99+
if (x > -0.045525018125772476f) { // 011
92100
return 0b0111;
93-
else
101+
} else {
94102
return 0b0110;
95-
else if (x > -0.23460740596055984f) // 010
103+
}
104+
} else if (x > -0.23460740596055984f) { // 010
96105
return 0b0101;
97-
else
106+
} else {
98107
return 0b0100;
99-
else if (x > -0.6106329262256622f) // 00
100-
if (x > -0.4599952697753906f) // 001
108+
}
109+
} else if (x > -0.6106329262256622f) { // 00
110+
if (x > -0.4599952697753906f) { // 001
101111
return 0b0011;
102-
else
112+
} else {
103113
return 0b0010;
104-
else if (x > -0.8480964004993439f) // 000
114+
}
115+
} else if (x > -0.8480964004993439f) { // 000
105116
return 0b0001;
106-
else
117+
} else {
107118
return 0b0000;
119+
}
108120
}
109121

110122
template <int32_t DATA_TYPE>

onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ void QuantizeBlockwiseBnb4(
5151
int32_t N,
5252
int32_t K,
5353
onnxruntime::concurrency::ThreadPool* thread_pool) {
54-
ORT_ENFORCE(quant_type == FP4 || quant_type == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
54+
ORT_ENFORCE(
55+
quant_type == FP4 || quant_type == NF4,
56+
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
5557

5658
if (block_size == 16) {
5759
QuantizeBlockwiseBn4DataTyped(16, quant_type);
@@ -106,7 +108,9 @@ void DequantizeBlockwiseBnb4(
106108
int32_t N,
107109
int32_t K,
108110
onnxruntime::concurrency::ThreadPool* thread_pool) {
109-
ORT_ENFORCE(quant_type == FP4 || quant_type == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
111+
ORT_ENFORCE(
112+
quant_type == FP4 || quant_type == NF4,
113+
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
110114

111115
if (block_size == 16) {
112116
DequantizeBlockwiseBn4DataTyped(16, quant_type);

onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ class MatMulBnb4 final : public OpKernel {
1818
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
1919
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
2020
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("quant_type", &quant_type_));
21-
ORT_ENFORCE(quant_type_ == FP4 || quant_type_ == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
21+
ORT_ENFORCE(
22+
quant_type_ == FP4 || quant_type_ == NF4,
23+
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
2224
}
2325

2426
Status Compute(OpKernelContext* context) const override;

onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,19 @@ namespace contrib {
1212
namespace cuda {
1313

1414
template<class T>
15-
Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream)
16-
{
17-
ORT_ENFORCE(quant_type == FP4 || quant_type == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
18-
15+
Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream) {
16+
ORT_ENFORCE(
17+
quant_type == FP4 || quant_type == NF4,
18+
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
19+
1920
T host_quant_map[16];
2021
switch (quant_type) {
2122
case FP4:
22-
for(int i = 0; i < 16; i++)
23+
for (int i = 0; i < 16; i++)
2324
host_quant_map[i] = static_cast<T>(fp4_qaunt_map[i]);
2425
break;
2526
case NF4:
26-
for(int i = 0; i < 16; i++)
27+
for (int i = 0; i < 16; i++)
2728
host_quant_map[i] = static_cast<T>(nf4_qaunt_map[i]);
2829
break;
2930
}
@@ -38,25 +39,29 @@ template Status SetBnbQuantMap<half>(int quant_type, half* quant_map_buffer, cud
3839

3940

4041
template<typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH>
41-
__global__ void kDequantizeBlockwise(const T *quant_map, T *output, const unsigned char *quant_data, const T *absmax, const int block_size, const int n)
42-
{
42+
__global__ void kDequantizeBlockwise(
43+
const T *quant_map,
44+
T *output,
45+
const uint8_t *quant_data,
46+
const T *absmax,
47+
const int block_size,
48+
const int n) {
4349
const int n_load = (gridDim.x * TILE_SIZE);
4450
int valid_items_load = 0;
4551
int valid_items_store = 0;
4652
const int base_idx = (blockIdx.x * TILE_SIZE);
4753

4854
T vals[NUM_PER_TH*2];
49-
unsigned char qvals[NUM_PER_TH];
55+
uint8_t qvals[NUM_PER_TH];
5056
T local_abs_max = T(0.0f);
5157

52-
typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
58+
typedef cub::BlockLoad<uint8_t, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
5359
typedef cub::BlockStore<T, THREADS, NUM_PER_TH*2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
5460

5561
__shared__ typename LoadChar::TempStorage loadchar;
5662
__shared__ typename StoreT::TempStorage storet;
5763

58-
for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE)
59-
{
64+
for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) {
6065
valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i;
6166
valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2;
6267

@@ -66,8 +71,7 @@ __global__ void kDequantizeBlockwise(const T *quant_map, T *output, const unsign
6671
LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128);
6772

6873
#pragma unroll NUM_PER_TH
69-
for(int j = 0; j < NUM_PER_TH; j++)
70-
{
74+
for (int j = 0; j < NUM_PER_TH; j++) {
7175
vals[j*2] = quant_map[qvals[j] >> 4] * local_abs_max;
7276
vals[j*2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max;
7377
}
@@ -79,17 +83,43 @@ __global__ void kDequantizeBlockwise(const T *quant_map, T *output, const unsign
7983

8084

8185
template<class T>
82-
Status DequantizeBnb4(const T* quant_map, T *output, const unsigned char *quant_data, const T *absmax, int block_size, int numel, cudaStream_t stream)
83-
{
86+
Status DequantizeBnb4(
87+
const T* quant_map,
88+
T *output,
89+
const uint8_t *quant_data,
90+
const T *absmax,
91+
int block_size,
92+
int numel,
93+
cudaStream_t stream) {
8494
int tile_size = 1024;
85-
kDequantizeBlockwise<T, 512, 64, 8><<<(numel+tile_size-1)/tile_size, 64, 0, stream>>>(quant_map, output, quant_data, absmax, block_size/2, numel);
86-
95+
kDequantizeBlockwise<T, 512, 64, 8><<<(numel+tile_size-1)/tile_size, 64, 0, stream>>>(
96+
quant_map,
97+
output,
98+
quant_data,
99+
absmax,
100+
block_size/2,
101+
numel);
102+
87103
return Status::OK();
88104
}
89105

90-
template Status DequantizeBnb4<float>(const float* quant_map, float *output, const unsigned char *quant_data, const float *absmax, int block_size, int numel, cudaStream_t stream);
91-
92-
template Status DequantizeBnb4<half>(const half* quant_map, half *output, const unsigned char *quant_data, const half *absmax, int block_size, int numel, cudaStream_t stream);
106+
template Status DequantizeBnb4<float>(
107+
const float* quant_map,
108+
float *output,
109+
const uint8_t *quant_data,
110+
const float *absmax,
111+
int block_size,
112+
int numel,
113+
cudaStream_t stream);
114+
115+
template Status DequantizeBnb4<half>(
116+
const half* quant_map,
117+
half *output,
118+
const uint8_t *quant_data,
119+
const half *absmax,
120+
int block_size,
121+
int numel,
122+
cudaStream_t stream);
93123

94124
} // namespace cuda
95125
} // namespace contrib

onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ template <class T>
1515
Status DequantizeBnb4(
1616
const T* quant_map,
1717
T* output,
18-
const unsigned char* quant_data,
18+
const uint8_t* quant_data,
1919
const T* absmax,
2020
int block_size,
2121
int numel,

onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ class MatMulBnb4 final : public CudaKernel {
2222
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
2323
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
2424
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("quant_type", &quant_type_));
25-
ORT_ENFORCE(quant_type_ == FP4 || quant_type_ == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
25+
ORT_ENFORCE(
26+
quant_type_ == FP4 || quant_type_ == NF4,
27+
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
2628
}
2729

2830
Status ComputeInternal(OpKernelContext* context) const override;

0 commit comments

Comments
 (0)