@@ -12,18 +12,19 @@ namespace contrib {
1212namespace cuda {
1313
1414template <class T >
15- Status SetBnbQuantMap (int quant_type, T* quant_map_buffer, cudaStream_t stream)
16- {
17- ORT_ENFORCE (quant_type == FP4 || quant_type == NF4, " Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported." );
18-
15+ Status SetBnbQuantMap (int quant_type, T* quant_map_buffer, cudaStream_t stream) {
16+ ORT_ENFORCE (
17+ quant_type == FP4 || quant_type == NF4,
18+ " Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported." );
19+
1920 T host_quant_map[16 ];
2021 switch (quant_type) {
2122 case FP4:
22- for (int i = 0 ; i < 16 ; i++)
23+ for (int i = 0 ; i < 16 ; i++)
2324 host_quant_map[i] = static_cast <T>(fp4_qaunt_map[i]);
2425 break ;
2526 case NF4:
26- for (int i = 0 ; i < 16 ; i++)
27+ for (int i = 0 ; i < 16 ; i++)
2728 host_quant_map[i] = static_cast <T>(nf4_qaunt_map[i]);
2829 break ;
2930 }
@@ -38,25 +39,29 @@ template Status SetBnbQuantMap<half>(int quant_type, half* quant_map_buffer, cud
3839
3940
4041template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH>
41- __global__ void kDequantizeBlockwise (const T *quant_map, T *output, const unsigned char *quant_data, const T *absmax, const int block_size, const int n)
42- {
42+ __global__ void kDequantizeBlockwise (
43+ const T *quant_map,
44+ T *output,
45+ const uint8_t *quant_data,
46+ const T *absmax,
47+ const int block_size,
48+ const int n) {
4349 const int n_load = (gridDim .x * TILE_SIZE);
4450 int valid_items_load = 0 ;
4551 int valid_items_store = 0 ;
4652 const int base_idx = (blockIdx .x * TILE_SIZE);
4753
4854 T vals[NUM_PER_TH*2 ];
49- unsigned char qvals[NUM_PER_TH];
55+ uint8_t qvals[NUM_PER_TH];
5056 T local_abs_max = T (0 .0f );
5157
52- typedef cub::BlockLoad<unsigned char , THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
58+ typedef cub::BlockLoad<uint8_t , THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
5359 typedef cub::BlockStore<T, THREADS, NUM_PER_TH*2 , cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
5460
5561 __shared__ typename LoadChar::TempStorage loadchar;
5662 __shared__ typename StoreT::TempStorage storet;
5763
58- for (unsigned int i = base_idx; i < n_load; i += gridDim .x *TILE_SIZE)
59- {
64+ for (unsigned int i = base_idx; i < n_load; i += gridDim .x *TILE_SIZE) {
6065 valid_items_load = (n+1 )/2 - i > TILE_SIZE ? TILE_SIZE : (n+1 )/2 - i;
6166 valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2 ;
6267
@@ -66,8 +71,7 @@ __global__ void kDequantizeBlockwise(const T *quant_map, T *output, const unsign
6671 LoadChar (loadchar).Load (&(quant_data[i]), qvals, valid_items_load, 128 );
6772
6873 #pragma unroll NUM_PER_TH
69- for (int j = 0 ; j < NUM_PER_TH; j++)
70- {
74+ for (int j = 0 ; j < NUM_PER_TH; j++) {
7175 vals[j*2 ] = quant_map[qvals[j] >> 4 ] * local_abs_max;
7276 vals[j*2 + 1 ] = quant_map[qvals[j] & 0x0F ] * local_abs_max;
7377 }
@@ -79,17 +83,43 @@ __global__ void kDequantizeBlockwise(const T *quant_map, T *output, const unsign
7983
8084
// DequantizeBnb4: host-side launcher for the kDequantizeBlockwise kernel.
//
// NOTE: this region is a diff. Lines marked '-' are the pre-change text and
// lines marked '+' are the post-change text. The two forms differ in line
// wrapping and in spelling the quant_data element type as uint8_t instead of
// unsigned char.
//
// Parameters (forwarded to the kernel unless stated otherwise):
//   quant_map  - lookup table indexed by the kernel with 4-bit values.
//   output     - destination buffer of type T.
//   quant_data - packed input bytes.
//   absmax     - scale values of type T.
//   block_size - the kernel receives block_size / 2, not block_size
//                (presumably because two 4-bit values share one byte -
//                TODO confirm against the kernel body).
//   numel      - element count; also determines the grid size.
//   stream     - CUDA stream the kernel is launched on.
//
// Launch configuration: (numel + 1023) / 1024 blocks of 64 threads, no
// dynamic shared memory. Kernel template arguments are
// <T, TILE_SIZE = 512, THREADS = 64, NUM_PER_TH = 8>.
//
// Returns Status::OK() unconditionally.
// NOTE(review): no launch error check (e.g. cudaGetLastError) is visible
// here, and numel == 0 would yield a grid dimension of 0 - confirm whether
// callers guard against that.
8185template <class T >
82- Status DequantizeBnb4 (const T* quant_map, T *output, const unsigned char *quant_data, const T *absmax, int block_size, int numel, cudaStream_t stream)
83- {
86+ Status DequantizeBnb4 (
87+ const T* quant_map,
88+ T *output,
89+ const uint8_t *quant_data,
90+ const T *absmax,
91+ int block_size,
92+ int numel,
93+ cudaStream_t stream) {
// Number of output elements covered by one thread block; used only for the
// ceil-division that computes the grid size below.
8494 int tile_size = 1024 ;
85- kDequantizeBlockwise <T, 512 , 64 , 8 ><<<(numel+tile_size-1 )/tile_size, 64 , 0 , stream>>> (quant_map, output, quant_data, absmax, block_size/2 , numel);
86-
95+ kDequantizeBlockwise <T, 512 , 64 , 8 ><<<(numel+tile_size-1 )/tile_size, 64 , 0 , stream>>> (
96+ quant_map,
97+ output,
98+ quant_data,
99+ absmax,
100+ block_size/2 ,
101+ numel);
102+
87103 return Status::OK ();
88104}
89105
// Explicit template instantiations of DequantizeBnb4 for T = float and
// T = half. In this diff the '-' lines are the pre-change single-line forms
// and the '+' lines are the post-change wrapped forms; besides wrapping, the
// only change is that quant_data is declared as const uint8_t* instead of
// const unsigned char*.
90- template Status DequantizeBnb4<float >(const float * quant_map, float *output, const unsigned char *quant_data, const float *absmax, int block_size, int numel, cudaStream_t stream);
91-
92- template Status DequantizeBnb4<half>(const half* quant_map, half *output, const unsigned char *quant_data, const half *absmax, int block_size, int numel, cudaStream_t stream);
106+ template Status DequantizeBnb4<float >(
107+ const float * quant_map,
108+ float *output,
109+ const uint8_t *quant_data,
110+ const float *absmax,
111+ int block_size,
112+ int numel,
113+ cudaStream_t stream);
114+
115+ template Status DequantizeBnb4<half>(
116+ const half* quant_map,
117+ half *output,
118+ const uint8_t *quant_data,
119+ const half *absmax,
120+ int block_size,
121+ int numel,
122+ cudaStream_t stream);
93123
94124} // namespace cuda
95125} // namespace contrib
0 commit comments