Skip to content

Commit 7a326ed

Browse files
YUNQIUGUOfacebook-github-bot
authored andcommitted
update to tune for small ms and quantized gemv (#3712)
Summary: X-link: facebookresearch/FBGEMM#794 as title Reviewed By: ipiszy Differential Revision: D69819701
1 parent 7eeaee8 commit 7a326ed

File tree

1 file changed

+81
-43
lines changed

1 file changed

+81
-43
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/sweep_utils.py

Lines changed: 81 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import fbgemm_gpu.experimental.gen_ai # noqa: F401
78
import torch
89
from torch._inductor.utils import do_bench_using_profiling
910

1011

1112
class SweepHeuristics:
1213
def __init__(self):
13-
self.m = 1
14+
self.ms = [1, 2, 3, 4]
1415
self.block_dims = [
1516
(32, 1),
1617
(32, 4),
@@ -35,52 +36,89 @@ def __init__(self):
3536
]
3637
self.nks = [(1280, 8192), (8192, 1024), (7168, 8192), (8192, 3584)]
3738

38-
def sweep_heuristics(self, fn) -> None:
39-
for n, k in self.nks:
40-
x = torch.randn(size=(self.m, k), dtype=torch.half, device="cuda")
41-
w = torch.randn(size=(n, k), dtype=torch.half, device="cuda")
42-
best_elapsed_time, best_block_dim_x, best_block_dim_y = None, None, None
39+
def sweep_heuristics(self, fn, quantize_w=False, quantize_x=False) -> None:
40+
for m in self.ms:
41+
for n, k in self.nks:
42+
x = torch.randn(size=(m, k), dtype=torch.bfloat16, device="cuda")
43+
w = torch.randn(size=(n, k), dtype=torch.bfloat16, device="cuda")
4344

44-
for block_dim_x, block_dim_y in self.block_dims:
45-
if (
46-
(k % block_dim_x != 0)
47-
or (n % block_dim_x != 0)
48-
or ((k / block_dim_x) % 8 != 0)
49-
):
50-
continue
51-
# This requires
52-
# 1. update for "testing purpose" the pytorch custom gemv op to accept additional params block_dim_x and block_dim_y
53-
# 2. modify the corresponding `{precision}_fast_gemv.cu` kernel signature to reflect the block_dim_x and block_dim_y heuristics
54-
# e.g. https://www.internalfb.com/code/fbsource/[208a27f25373]/fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py?lines=375
55-
res = do_bench_using_profiling(
56-
lambda func=fn, x=x, w=w, block_dim_x=block_dim_x, block_dim_y=block_dim_y: func(
57-
x, w, block_dim_x, block_dim_y
58-
)
59-
)
45+
best_elapsed_time, best_block_dim_x, best_block_dim_y = None, None, None
46+
47+
for block_dim_x, block_dim_y in self.block_dims:
48+
if (
49+
(k % block_dim_x != 0)
50+
or (n % block_dim_x != 0)
51+
or ((k / block_dim_x) % 8 != 0)
52+
):
53+
continue
54+
55+
res = 0.0
56+
if quantize_w and quantize_x:
57+
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
58+
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
59+
res = do_bench_using_profiling(
60+
lambda func=fn, x=xq, w=wq, scale=x_scale * w_scale, block_dim_x=block_dim_x, block_dim_y=block_dim_y: func(
61+
x, w, scale, block_dim_x, block_dim_y
62+
)
63+
)
64+
elif quantize_w:
65+
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
66+
res = do_bench_using_profiling(
67+
lambda func=fn, x=x, w=wq, scale=w_scale, block_dim_x=block_dim_x, block_dim_y=block_dim_y: func(
68+
x, w, scale, block_dim_x, block_dim_y
69+
)
70+
)
71+
else:
72+
# This requires
73+
# 1. update for "testing purpose" the pytorch custom gemv op to accept additional params block_dim_x and block_dim_y
74+
# 2. modify the corresponding `{precision}_fast_gemv.cu` kernel signature to reflect the block_dim_x and block_dim_y heuristics
75+
# e.g. https://www.internalfb.com/code/fbsource/[208a27f25373]/fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py?lines=375
76+
res = do_bench_using_profiling(
77+
lambda func=fn, x=x, w=w, block_dim_x=block_dim_x, block_dim_y=block_dim_y: func(
78+
x, w, block_dim_x, block_dim_y
79+
)
80+
)
6081

61-
if best_elapsed_time is None or res < best_elapsed_time:
62-
best_elapsed_time, best_block_dim_x, best_block_dim_y = (
63-
res,
64-
block_dim_x,
65-
block_dim_y,
82+
if best_elapsed_time is None or res < best_elapsed_time:
83+
best_elapsed_time, best_block_dim_x, best_block_dim_y = (
84+
res,
85+
block_dim_x,
86+
block_dim_y,
87+
)
88+
if best_elapsed_time is None:
89+
print("Error: No valid elapsed time found. Exiting the function.")
90+
return
91+
if fn == torch.ops.fbgemm.bf16_fast_gemv:
92+
bw = (
93+
(m * k * 2 + n * k * 2 + m * n * 2)
94+
/ (best_elapsed_time / 1000)
95+
/ (1024**3)
96+
)
97+
elif fn == torch.ops.fbgemm.bf16fp8bf16_fast_gemv:
98+
bw = (
99+
(m * k * 2 + n * k + m * n * 2)
100+
/ (best_elapsed_time / 1000)
101+
/ (1024**3)
102+
)
103+
else: # Assuming fn is torch.ops.fbgemm.fp8fp8bf16_fast_gemv
104+
bw = (
105+
(m * k + n * k + m * n * 2)
106+
/ (best_elapsed_time / 1000)
107+
/ (1024**3)
66108
)
67-
if best_elapsed_time is None:
68-
print("Error: No valid elapsed time found. Exiting the function.")
69-
return
70-
bw = (
71-
(self.m * k * 2 + n * k * 2 + self.m * n * 2)
72-
/ (best_elapsed_time / 1000)
73-
/ (1024**3)
74-
)
75-
print(f"m: {self.m}, n: {n}, k: {k}")
76-
print(f"tuning heuristics for kernel: {fn.__name__}")
77-
print(f"best elapsed time: {best_elapsed_time} ms")
78-
print(f"best block_dim_x: {best_block_dim_x}")
79-
print(f"best block_dim_y: {best_block_dim_y}")
80-
print(f"best bw: {bw} GB/s")
109+
print(f"m: {m}, n: {n}, k: {k}")
110+
print(f"tuning heuristics for kernel: {fn.__name__}")
111+
print(f"best elapsed time: {best_elapsed_time} ms")
112+
print(f"best block_dim_x: {best_block_dim_x}")
113+
print(f"best block_dim_y: {best_block_dim_y}")
114+
print(f"best bw: {bw} GB/s")
81115

82116

83117
sweep_instance = SweepHeuristics()
84118
sweep_instance.sweep_heuristics(fn=torch.ops.fbgemm.bf16_fast_gemv)
85-
sweep_instance.sweep_heuristics(fn=torch.ops.fbgemm.bf16fp8bf16_fast_gemv)
86-
sweep_instance.sweep_heuristics(fn=torch.ops.fbgemm.fp8fp8bf16_fast_gemv)
119+
sweep_instance.sweep_heuristics(
120+
fn=torch.ops.fbgemm.bf16fp8bf16_fast_gemv, quantize_w=True
121+
)
122+
sweep_instance.sweep_heuristics(
123+
fn=torch.ops.fbgemm.fp8fp8bf16_fast_gemv, quantize_w=True, quantize_x=True
124+
)

0 commit comments

Comments
 (0)