
Commit eb67387

YUNQIUGUO authored and facebook-github-bot committed
Add sweep_utils.py script to tune heuristics
Summary: As title. Heuristics tuning script for `fp16_fast_gemv`. Currently the script needs a manual hack to update the kernel so that block dims can be passed in; see comments in the code.

Reviewed By: ipiszy

Differential Revision: D68786295
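For context, the sweep below relies on `do_bench_using_profiling` from `torch._inductor.utils`, which runs a callable under the PyTorch profiler and returns an elapsed time in milliseconds. A minimal standalone sketch of that benchmarking pattern, using a plain `torch.matmul` as a stand-in for the hacked `fp16_fast_gemv` custom op (the shapes here are illustrative, not from the commit):

import torch
from torch._inductor.utils import do_bench_using_profiling

# Illustrative GEMV-like shapes only; the actual sweep uses m=1 and (n, k) pairs from self.nks.
x = torch.randn(1, 8192, dtype=torch.half, device="cuda")
w = torch.randn(1280, 8192, dtype=torch.half, device="cuda")

# The callable is invoked repeatedly under the profiler; the result is a time in ms.
elapsed_ms = do_bench_using_profiling(lambda: torch.matmul(x, w.T))
print(f"elapsed: {elapsed_ms} ms")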
1 parent 00224fe commit eb67387

1 file changed: 83 additions, 0 deletions
@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch._inductor.utils import do_bench_using_profiling


class SweepHeuristics:
    def __init__(self):
        self.m = 1
        self.block_dims = [
            (32, 1),
            (32, 4),
            (32, 8),
            (32, 16),
            (32, 32),
            (64, 1),
            (64, 2),
            (64, 4),
            (64, 8),
            (64, 16),
            (128, 1),
            (128, 2),
            (128, 4),
            (128, 8),
            (256, 1),
            (256, 2),
            (256, 4),
            (512, 1),
            (512, 2),
            (1024, 1),
        ]
        self.nks = [(1280, 8192), (8192, 1024), (7168, 8192), (8192, 3584)]

    def sweep_heuristics(self) -> None:
        for n, k in self.nks:
            x = torch.randn(size=(self.m, k), dtype=torch.half, device="cuda")
            w = torch.randn(size=(n, k), dtype=torch.half, device="cuda")
            best_elapsed_time, best_block_dim_x, best_block_dim_y = None, None, None

            for block_dim_x, block_dim_y in self.block_dims:
                if (
                    (k % block_dim_x != 0)
                    or (n % block_dim_x != 0)
                    or ((k / block_dim_x) % 8 != 0)
                ):
                    continue
                # This requires
                # 1. update for testing purpose for `fp16_fast_gemv` pytorch custom op to pass in block_dim_x and block_dim_y
                # 2. modify the fp16_fast_gemv.cu kernel signature to reflect the block_dim heuristics
                # https://www.internalfb.com/code/fbsource/[bafd6390bc8c842b46d81be1a27dafd384503a53]/fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py?lines=365
                res = do_bench_using_profiling(
                    lambda: torch.ops.fbgemm.fp16_fast_gemv(
                        x.T, w, block_dim_x=block_dim_x, block_dim_y=block_dim_y
                    )
                )

                if best_elapsed_time is None or res < best_elapsed_time:
                    best_elapsed_time, best_block_dim_x, best_block_dim_y = (
                        res,
                        block_dim_x,
                        block_dim_y,
                    )
            if best_elapsed_time is None:
                print("Error: No valid elapsed time found. Exiting the function.")
                return
            bw = (
                (self.m * k * 2 + n * k * 2 + self.m * n * 2)
                / (best_elapsed_time / 1000)
                / (1024**3)
            )

            print(f"best elapsed time: {best_elapsed_time} ms")
            print(f"best block_dim_x: {best_block_dim_x}")
            print(f"best block_dim_y: {best_block_dim_y}")
            print(f"best bw: {bw} GB/s")


sweep_instance = SweepHeuristics()
sweep_instance.sweep_heuristics()
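The bandwidth printed at the end counts the fp16 bytes moved for the activation, the weight, and the output (2 bytes per element) and divides by the best elapsed time. A quick worked example, assuming the (n, k) = (1280, 8192) shape from self.nks and a purely hypothetical best elapsed time of 0.02 ms; only the formula matches the script above:

# Hypothetical numbers for illustration of the bandwidth formula.
m, n, k = 1, 1280, 8192
best_elapsed_time = 0.02  # ms, assumed for this example

bytes_moved = m * k * 2 + n * k * 2 + m * n * 2  # x + w + output in fp16 (2 bytes each)
# = 16_384 + 20_971_520 + 2_560 = 20_990_464 bytes

bw = bytes_moved / (best_elapsed_time / 1000) / (1024**3)
print(f"{bw:.1f} GB/s")  # ~977 GB/s for these assumed numbers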
