Commit b4e8bd8

mxz297 authored and facebook-github-bot committed
adding an option to skip zeroing output tensor for f8f8bf16_rowwise_grouped_dynamic (#3685)
Summary: X-link: facebookresearch/FBGEMM#761. In certain use cases, the caller of this API does not need the padded area of the output tensor zeroed out, so this diff adds an option to skip the zeroing. Note that currently the actual skipping is only done for AMD. Differential Revision: D69380351
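For context, here is a minimal sketch of how a caller might use the new flag. The declaration matches the one in quantize.cpp after this change; the tensor shapes in the comments follow the size() calls in the CUTLASS path, while everything else (how the FP8 inputs were quantized, scale layouts, the wrapper name) is an assumption for illustration, not taken from this commit.

#include <ATen/ATen.h>

// Declaration as it appears in quantize.cpp after this change.
at::Tensor f8f8bf16_rowwise_grouped_dynamic(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor zero_start_index_M,
    bool zeroing_output_tensor = true);

// Hypothetical call site: the caller only consumes the valid rows of each
// group (per zero_start_index_M), so zeroing the padded area is wasted work.
at::Tensor grouped_gemm_skip_zeroing(
    at::Tensor XQ,                 // [group_count, M, K], FP8
    at::Tensor WQ,                 // [group_count, N, K], FP8
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor zero_start_index_M) {  // per-group index where padding begins
  // Passing false skips the zero fill; the default (true) keeps the old
  // behavior for existing callers.
  return f8f8bf16_rowwise_grouped_dynamic(
      XQ, WQ, x_scale, w_scale, zero_start_index_M,
      /*zeroing_output_tensor=*/false);
}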
1 parent: 3c0cd95 · commit: b4e8bd8

File tree: 3 files changed, +23 −10 lines

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip

Lines changed: 10 additions & 5 deletions

@@ -206,7 +206,8 @@ __global__ void set_kernel_args_fixed_nk_kernel(
     int M,
     int N,
     int K,
-    int group_count) {
+    int group_count,
+    bool zeroing_output_tensor) {
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   // Each thread is responsible for setting up the arguments for one group.
   if (thread_idx < group_count) {
@@ -227,6 +228,7 @@ __global__ void set_kernel_args_fixed_nk_kernel(
     // Write kernel args to memory.
     kernel_args[thread_idx] = kernel_group_args;
   }
+  if (!zeroing_output_tensor) return;

   // Figure out where in memory we are.
   // Each thread sets one float 4 which corresponds to 8 bf16 values.
@@ -252,7 +254,8 @@ void set_dynamic_kernel_args(
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor output,
-    at::Tensor zero_start_index_M) {
+    at::Tensor zero_start_index_M,
+    bool zeroing_output_tensor) {
   // Get current cuda stream.
   auto stream = at::cuda::getCurrentHIPStream().stream();
   int group_count = XQ.size(0);
@@ -292,7 +295,8 @@ void set_dynamic_kernel_args(
       M,
       N,
       K,
-      group_count);
+      group_count,
+      zeroing_output_tensor);
 }

 template <typename OutputType>
@@ -433,7 +437,8 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
     at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor zero_start_index_M) {
+    at::Tensor zero_start_index_M,
+    bool zeroing_output_tensor = true) {
   // Check that input datatypes are valid.
   // First confirm that there are the same number of groups in all inputs.
   int group_count = XQ.size(0);
@@ -473,7 +478,7 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
       {static_cast<long>(group_count * sizeof(KernelArguments))},
       XQ.options().dtype(at::kByte));
   set_dynamic_kernel_args(
-      kernel_args, XQ, WQ, x_scale, w_scale, Y, zero_start_index_M);
+      kernel_args, XQ, WQ, x_scale, w_scale, Y, zero_start_index_M, zeroing_output_tensor);

   RowwiseGroupedKernel<at::Tensor, at::Tensor> selected_kernel =
       rowwise_grouped_heuristic_dispatch<at::Tensor, at::Tensor>(M, N, K);
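The early return added above bypasses the tail of set_kernel_args_fixed_nk_kernel, where each thread clears one float4 (8 bf16 values) of the output buffer. A rough sketch of that zeroing pattern follows, paraphrased for illustration only: the real kernel reuses thread_idx within a grid already sized to cover both the groups and the output, and names like zero_output_sketch and n_float4 are assumptions, not FBGEMM identifiers.

// Illustrative only: zeroing an output buffer with float4 stores, where
// each store clears 8 bf16 elements at once.
__global__ void zero_output_sketch(float4* output, int n_float4) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n_float4;
       i += blockDim.x * gridDim.x) {
    output[i] = make_float4(0.f, 0.f, 0.f, 0.f);
  }
}

When zeroing_output_tensor is false, this entire pass is skipped, saving a full write over the output buffer.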

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu

Lines changed: 10 additions & 3 deletions

@@ -682,14 +682,20 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
     at::Tensor WQ, // FP8
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor zero_start_index_M) {
+    at::Tensor zero_start_index_M,
+    bool zeroing_output_tensor = true) {
   at::Tensor Y;
   int group_count = XQ.size(0);
   int M = XQ.size(1);
   int N = WQ.size(1);
   int K = XQ.size(0);
   int total_output_size = group_count * M * N;
-  Y = at::zeros(total_output_size, XQ.options().dtype(at::kBFloat16));
+  if (zeroing_output_tensor) {
+    Y = at::zeros(total_output_size, XQ.options().dtype(at::kBFloat16));
+  } else {
+    Y = at::empty(total_output_size, XQ.options().dtype(at::kBFloat16));
+  }

   // Return continuous view of output.
   at::Tensor output = dispatch_fp8_grouped_kernel<at::Tensor>(
       XQ, WQ, x_scale, w_scale, Y, zero_start_index_M);
@@ -724,7 +730,8 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
     at::Tensor WQ, // FP8
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor zero_start_index_M) {
+    at::Tensor zero_start_index_M,
+    bool zeroing_output_tensor = true) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp

Lines changed: 3 additions & 2 deletions

@@ -104,7 +104,8 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
     at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor zero_start_index_M);
+    at::Tensor zero_start_index_M,
+    bool zeroing_output_tensor = true);
 at::Tensor f8f8bf16_blockwise(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -221,7 +222,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "f8f8bf16_rowwise_grouped_stacked(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor(a!)? output=None) -> Tensor");
   m.def(
-      "f8f8bf16_rowwise_grouped_dynamic(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor zero_start_index_M) -> Tensor");
+      "f8f8bf16_rowwise_grouped_dynamic(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor zero_start_index_M, bool zeroing_output_tensor=True) -> Tensor");
   m.def(
       "f8f8bf16_tensorwise(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor");
   m.def("per_tensor_quantize_i8(Tensor X, float scale) -> Tensor");
