@@ -11,23 +11,22 @@ namespace onnxruntime {
1111namespace contrib {
1212namespace cuda {
1313
14- template <class T >
15- Status SetBnbQuantMap (int quant_type, T* quant_map_buffer, cudaStream_t stream)
16- {
17- ORT_ENFORCE (quant_type == FP4 || quant_type == NF4, " Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported." );
18-
14+ template <class T >
15+ Status SetBnbQuantMap (int quant_type, T* quant_map_buffer, cudaStream_t stream) {
16+ ORT_ENFORCE (
17+ quant_type == FP4 || quant_type == NF4,
18+ " Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported." );
19+
1920 T host_quant_map[16 ];
2021 switch (quant_type) {
2122 case FP4:
22- for (int i = 0 ; i < 16 ; i++)
23- host_quant_map[i] = static_cast <T>(fp4_qaunt_map[i]);
23+ for (int i = 0 ; i < 16 ; i++) host_quant_map[i] = static_cast <T>(fp4_qaunt_map[i]);
2424 break ;
2525 case NF4:
26- for (int i = 0 ; i < 16 ; i++)
27- host_quant_map[i] = static_cast <T>(nf4_qaunt_map[i]);
26+ for (int i = 0 ; i < 16 ; i++) host_quant_map[i] = static_cast <T>(nf4_qaunt_map[i]);
2827 break ;
2928 }
30- CUDA_CALL_THROW (cudaMemcpyAsync (quant_map_buffer, host_quant_map, sizeof (T)* 16 , cudaMemcpyHostToDevice, stream));
29+ CUDA_CALL_THROW (cudaMemcpyAsync (quant_map_buffer, host_quant_map, sizeof (T) * 16 , cudaMemcpyHostToDevice, stream));
3130
3231 return Status::OK ();
3332}
@@ -36,60 +35,82 @@ template Status SetBnbQuantMap<float>(int quant_type, float* quant_map_buffer, c
3635
3736template Status SetBnbQuantMap<half>(int quant_type, half* quant_map_buffer, cudaStream_t stream);
3837
39-
40- template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH>
41- __global__ void kDequantizeBlockwise (const T *quant_map, T *output, const unsigned char *quant_data, const T *absmax, const int block_size, const int n)
42- {
38+ template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH>
39+ __global__ void kDequantizeBlockwise (
40+ const T* quant_map,
41+ T* output,
42+ const uint8_t * quant_data,
43+ const T* absmax,
44+ const int block_size,
45+ const int n) {
4346 const int n_load = (gridDim .x * TILE_SIZE);
4447 int valid_items_load = 0 ;
4548 int valid_items_store = 0 ;
4649 const int base_idx = (blockIdx .x * TILE_SIZE);
4750
48- T vals[NUM_PER_TH* 2 ];
49- unsigned char qvals[NUM_PER_TH];
51+ T vals[NUM_PER_TH * 2 ];
52+ uint8_t qvals[NUM_PER_TH];
5053 T local_abs_max = T (0 .0f );
5154
52- typedef cub::BlockLoad<unsigned char , THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
53- typedef cub::BlockStore<T, THREADS, NUM_PER_TH* 2 , cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
55+ typedef cub::BlockLoad<uint8_t , THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
56+ typedef cub::BlockStore<T, THREADS, NUM_PER_TH * 2 , cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
5457
5558 __shared__ typename LoadChar::TempStorage loadchar;
5659 __shared__ typename StoreT::TempStorage storet;
5760
58- for (unsigned int i = base_idx; i < n_load; i += gridDim .x *TILE_SIZE)
59- {
60- valid_items_load = (n+1 )/2 - i > TILE_SIZE ? TILE_SIZE : (n+1 )/2 - i;
61- valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2 ;
61+ for (unsigned int i = base_idx; i < n_load; i += gridDim .x * TILE_SIZE) {
62+ valid_items_load = (n + 1 ) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1 ) / 2 - i;
63+ valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2 ;
6264
63- local_abs_max = __ldg (&absmax[(i+ threadIdx .x * NUM_PER_TH)/ (block_size)]);
65+ local_abs_max = __ldg (&absmax[(i + threadIdx .x * NUM_PER_TH) / (block_size)]);
6466
6567 __syncthreads ();
6668 LoadChar (loadchar).Load (&(quant_data[i]), qvals, valid_items_load, 128 );
6769
6870 #pragma unroll NUM_PER_TH
69- for (int j = 0 ; j < NUM_PER_TH; j++)
70- {
71- vals[j*2 ] = quant_map[qvals[j] >> 4 ] * local_abs_max;
72- vals[j*2 + 1 ] = quant_map[qvals[j] & 0x0F ] * local_abs_max;
71+ for (int j = 0 ; j < NUM_PER_TH; j++) {
72+ vals[j * 2 ] = quant_map[qvals[j] >> 4 ] * local_abs_max;
73+ vals[j * 2 + 1 ] = quant_map[qvals[j] & 0x0F ] * local_abs_max;
7374 }
7475
7576 __syncthreads ();
76- StoreT (storet).Store (&(output[i* 2 ]), vals, valid_items_store);
77+ StoreT (storet).Store (&(output[i * 2 ]), vals, valid_items_store);
7778 }
7879}
7980
80-
81- template <class T >
82- Status DequantizeBnb4 (const T* quant_map, T *output, const unsigned char *quant_data, const T *absmax, int block_size, int numel, cudaStream_t stream)
83- {
81+ template <class T >
82+ Status DequantizeBnb4 (
83+ const T* quant_map,
84+ T* output,
85+ const uint8_t * quant_data,
86+ const T* absmax,
87+ int block_size,
88+ int numel,
89+ cudaStream_t stream) {
8490 int tile_size = 1024 ;
85- kDequantizeBlockwise <T, 512 , 64 , 8 ><<<(numel+tile_size-1 )/tile_size, 64 , 0 , stream>>> (quant_map, output, quant_data, absmax, block_size/2 , numel);
86-
91+ kDequantizeBlockwise <T, 512 , 64 , 8 ><<<(numel + tile_size - 1 ) / tile_size, 64 , 0 , stream>>> (
92+ quant_map, output, quant_data, absmax, block_size / 2 , numel);
93+
8794 return Status::OK ();
8895}
8996
90- template Status DequantizeBnb4<float >(const float * quant_map, float *output, const unsigned char *quant_data, const float *absmax, int block_size, int numel, cudaStream_t stream);
91-
92- template Status DequantizeBnb4<half>(const half* quant_map, half *output, const unsigned char *quant_data, const half *absmax, int block_size, int numel, cudaStream_t stream);
97+ template Status DequantizeBnb4<float >(
98+ const float * quant_map,
99+ float * output,
100+ const uint8_t * quant_data,
101+ const float * absmax,
102+ int block_size,
103+ int numel,
104+ cudaStream_t stream);
105+
106+ template Status DequantizeBnb4<half>(
107+ const half* quant_map,
108+ half* output,
109+ const uint8_t * quant_data,
110+ const half *absmax,
111+ int block_size,
112+ int numel,
113+ cudaStream_t stream);
93114
94115} // namespace cuda
95116} // namespace contrib
0 commit comments