Skip to content

Commit fe5c115

Browse files
gmagogsfm, ichbinblau, and claude
authored
[vLLM IR] Add IR op testing and benchmarking infrastructure (vllm-project#40167)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Theresa Shan <Theresa.Shan@amd.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6867bcd commit fe5c115

12 files changed

Lines changed: 673 additions & 45 deletions

File tree

benchmarks/__init__.py

Whitespace-only changes.

benchmarks/kernels/__init__.py

Whitespace-only changes.

benchmarks/kernels/ir/__init__.py

Whitespace-only changes.
Lines changed: 378 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,378 @@
1+
#!/usr/bin/env python3
2+
# SPDX-License-Identifier: Apache-2.0
3+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
4+
"""
5+
Generic benchmark harness for vLLM IR ops.
6+
7+
Usage:
8+
python benchmarks/kernels/ir/bench_ir_ops.py
9+
python benchmarks/kernels/ir/bench_ir_ops.py --ops rms_norm
10+
python benchmarks/kernels/ir/bench_ir_ops.py --ops rms_norm,silu_mul
11+
python benchmarks/kernels/ir/bench_ir_ops.py --no-cuda-graph
12+
python benchmarks/kernels/ir/bench_ir_ops.py --ops rms_norm --save-path ./results/
13+
"""
14+
15+
import argparse
16+
import contextlib
17+
import csv
18+
import dataclasses
19+
import datetime
20+
import math
21+
import os
22+
import subprocess
23+
import sys
24+
import tempfile
25+
26+
# Put the repository root (three directories above this file) at the front of
# sys.path so that `benchmarks` can be imported as a package when this file is
# run directly as a script.
_REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir)
)
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)
30+
31+
# Silence the noisy C++ warnings emitted while vllm registers its kernels.
# They are written straight to file descriptor 2 by the dynamic linker, so
# replacing Python's sys.stderr object would not catch them; instead the
# descriptor itself is pointed at os.devnull for the duration of the imports
# and restored afterwards, even if an import fails.
_saved_fd = os.dup(2)
try:
    with open(os.devnull, "w") as _devnull:
        # fd 2 becomes an independent duplicate of the devnull descriptor, so
        # it stays valid after the `with` block closes `_devnull`.
        os.dup2(_devnull.fileno(), 2)
    import torch

    import vllm.kernels  # noqa: E402, F401
finally:
    os.dup2(_saved_fd, 2)
    os.close(_saved_fd)
44+
45+
from tqdm import tqdm # noqa: E402
46+
47+
from benchmarks.kernels.ir.shapes import SHAPE_CONFIGS # noqa: E402 # isort: skip
48+
from vllm.ir.op import IrOp # noqa: E402
49+
from vllm.platforms import current_platform # noqa: E402
50+
from vllm.triton_utils import triton # noqa: E402
51+
52+
53+
@dataclasses.dataclass(frozen=True)
class BenchConfig:
    """Immutable timing settings shared by every benchmarked op."""

    # True: time with triton's CUDA-graph benchmark helper;
    # False: time with the eager `do_bench` helper.
    use_cuda_graph: bool = True
    # Forwarded as `warmup` to `do_bench`; unused on the CUDA-graph path.
    warmup: int = 25
    # Forwarded as `rep` to whichever triton benchmark helper is selected.
    rep: int = 100
58+
59+
60+
def _pkg_version(name: str) -> str:
    """Return the installed version string of the distribution *name*.

    The literal string "not installed" is returned when importlib.metadata
    cannot find a distribution with that name.
    """
    from importlib import metadata

    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return "not installed"
66+
67+
68+
# Human-readable labels for the metadata keys, used when the metadata table is
# printed. Keys without an entry here are shown under their raw key name.
_METADATA_LABELS = dict(
    timestamp="Timestamp",
    git_commit="Git commit",
    vllm="vLLM",
    pytorch="PyTorch",
    cuda_runtime="CUDA runtime",
    triton="Triton",
    cutlass="CUTLASS",
    helion="Helion",
    device="Device",
    bench_mode="Bench mode",
    warmup="Warmup",
    rep="Repetitions",
)
82+
83+
84+
def collect_env_metadata(cfg: BenchConfig) -> dict[str, str]:
    """Collect environment and run-setting details for the benchmark report.

    Args:
        cfg: Benchmark settings; its mode, warmup and rep values are echoed
            into the returned metadata.

    Returns:
        A dict of display strings keyed by short identifiers: timestamp, git
        commit, library versions, device name and the benchmark settings.
    """
    from vllm.collect_env import get_env_info

    env = get_env_info()

    # Best effort: keep "unknown" when git is unavailable or the lookup fails.
    # The command runs inside the repository that contains this script, so
    # the recorded commit does not depend on the caller's working directory.
    git_sha = "unknown"
    with contextlib.suppress(subprocess.CalledProcessError, OSError):
        git_sha = (
            subprocess.check_output(
                ["git", "rev-parse", "--short", "HEAD"],
                cwd=_REPO_ROOT,
                stderr=subprocess.DEVNULL,
            )
            .decode()
            .strip()
        )

    device_name = current_platform.get_device_name()

    # The warmup value is only forwarded to the eager benchmark helper.
    warmup_note = " ms" if not cfg.use_cuda_graph else " ms (ignored)"
    # NOTE(review): Triton documents `rep` of do_bench_cudagraph as a time
    # budget in ms rather than a replay count -- confirm the "replays" label.
    rep_note = " replays" if cfg.use_cuda_graph else " ms"

    return {
        "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "git_commit": git_sha,
        "vllm": str(env.vllm_version),
        "pytorch": str(env.torch_version),
        "cuda_runtime": str(env.cuda_runtime_version),
        "triton": triton.__version__,
        "cutlass": _pkg_version("nvidia-cutlass-dsl"),
        "helion": _pkg_version("helion"),
        "device": device_name,
        "bench_mode": "cuda_graph" if cfg.use_cuda_graph else "eager",
        "warmup": f"{cfg.warmup}{warmup_note}",
        "rep": f"{cfg.rep}{rep_note}",
    }
118+
119+
120+
def print_metadata(metadata: dict[str, str]):
    """Print *metadata* as a two-column table framed by rule lines.

    Each key is shown under its label from _METADATA_LABELS (or the raw key
    when no label exists), left-aligned in a 16-character column.
    """
    rule = "=" * 60
    print(rule)
    for key in metadata:
        label = _METADATA_LABELS.get(key, key) + ":"
        print(f"{label:<16}{metadata[key]}")
    print(rule)
125+
126+
127+
def _clone_args(args: tuple) -> tuple:
    """Return a copy of *args* in which every torch.Tensor has been cloned.

    Entries that are not tensors are carried over as the same objects.
    """
    cloned = []
    for arg in args:
        cloned.append(arg.clone() if isinstance(arg, torch.Tensor) else arg)
    return tuple(cloned)
129+
130+
131+
# TODO(gmagogsfm): When the `maybe_inplace` PR lands, ops marked as
132+
# inplace=True will mutate bench_args across iterations. Both CUDA graph
133+
# and eager modes will accumulate drift from repeated in-place mutation.
134+
# We need to re-clone inputs per iteration for inplace ops.
135+
def _bench_one(fn, args, cfg: BenchConfig) -> float:
    """Benchmark ``fn(*args)`` and return the 0.5-quantile time times 1000.

    The call is timed on a cloned copy of *args*. Depending on
    ``cfg.use_cuda_graph`` the measurement is done by triton's
    ``do_bench_cudagraph`` or ``do_bench``; the result is scaled by 1000
    (callers label it as microseconds).
    """
    bench_args = _clone_args(args)

    def bench_fn():
        return fn(*bench_args)

    if not cfg.use_cuda_graph:
        median_ms = triton.testing.do_bench(
            bench_fn, warmup=cfg.warmup, rep=cfg.rep, quantiles=[0.5]
        )
    else:
        # The CUDA-graph helper is not given a warmup value.
        median_ms = triton.testing.do_bench_cudagraph(
            bench_fn, rep=cfg.rep, quantiles=[0.5]
        )
    return median_ms * 1000
146+
147+
148+
# TODO(gmagogsfm): Once compiled native implementation lands (#38775),
149+
# the benchmark baseline should be the compiled native (what vLLM runs by
150+
# default) rather than the uncompiled native implementation.
151+
def collect_timings(
    op: IrOp, shape_configs: list[dict], cfg: BenchConfig
) -> tuple[list[str], list[str], dict[str, dict[str, float]]]:
    """Benchmark every supported implementation of *op* on every shape config.

    Args:
        op: The IR op whose implementations are timed.
        shape_configs: Keyword-argument dicts passed to ``op.generate_inputs``.
        cfg: Benchmark settings forwarded to the timing helper.

    Returns:
        ``(case_names, providers, results)`` where ``results[case][provider]``
        is the measured time, or NaN when the implementation reports that it
        does not support the generated arguments.
    """

    def fmt(v) -> str:
        # Show dtypes without their module prefix (text after the last dot).
        text = str(v)
        if isinstance(v, torch.dtype):
            return text.split(".")[-1]
        return text

    # One case name per shape config, built from its key=value pairs.
    case_names = []
    for kwargs in shape_configs:
        case_names.append("_".join(f"{k}={fmt(v)}" for k, v in kwargs.items()))

    providers = [name for name, impl in op.impls.items() if impl.supported]

    results: dict[str, dict[str, float]] = {c: {} for c in case_names}
    for provider in providers:
        impl = op.impls[provider]
        progress = tqdm(
            zip(case_names, shape_configs),
            desc=f"{op.name} / {provider}",
            total=len(case_names),
            unit=" cases",
        )
        for case_name, kwargs in progress:
            args = op.generate_inputs(**kwargs)
            if not impl.supports_args(*args):
                results[case_name][provider] = float("nan")
                continue
            results[case_name][provider] = _bench_one(impl.impl_fn, args, cfg)

    return case_names, providers, results
179+
180+
181+
def analyze_results(
    op_name: str,
    case_names: list[str],
    providers: list[str],
    results: dict[str, dict[str, float]],
) -> tuple[list[dict[str, str]], list[dict[str, str]], list[str]]:
    """Turn raw timings into printable/CSV rows and per-provider summaries.

    Speedups are computed relative to the provider named "native". A speedup
    is only reported when both timings are finite and strictly positive;
    otherwise the cell is "n/a" and the case is excluded from the summary
    (a zero or infinite timing would otherwise make the geometric mean
    undefined).

    Args:
        op_name: Name of the op, copied into each summary row.
        case_names: Case names, in output order.
        providers: Provider names that were benchmarked.
        results: ``results[case][provider]`` timing; NaN or a missing entry
            means "not measured".

    Returns:
        ``(detail_rows, summary_rows, header_cols)``: one detail row per case,
        one summary row per non-native provider that has at least one valid
        speedup, and the column order for the detail rows.

    Side effects:
        Prints a short summary for each provider that gets a summary row.
    """
    native_col = "native"
    non_native = [p for p in providers if p != native_col]

    header_cols = ["case"]
    header_cols.extend(f"{p} (us)" for p in providers)
    header_cols.extend(f"{p} speedup" for p in non_native)

    def usable(timing: float) -> bool:
        # NaN, infinities, zero and negatives cannot yield a meaningful ratio.
        return math.isfinite(timing) and timing > 0

    missing = float("nan")
    detail_rows: list[dict[str, str]] = []
    speedup_data: dict[str, list[tuple[float, str]]] = {p: [] for p in non_native}

    for case_name in case_names:
        timings = results[case_name]
        row: dict[str, str] = {"case": case_name}

        for p in providers:
            val = timings.get(p, missing)
            row[f"{p} (us)"] = f"{val:.2f}" if not math.isnan(val) else "n/a"

        native_us = timings.get(native_col, missing)
        for p in non_native:
            p_us = timings.get(p, missing)
            if usable(native_us) and usable(p_us):
                speedup = native_us / p_us
                row[f"{p} speedup"] = f"{speedup:.2f}x"
                speedup_data[p].append((speedup, case_name))
            else:
                row[f"{p} speedup"] = "n/a"

        detail_rows.append(row)

    summary_rows: list[dict[str, str]] = []
    for p in non_native:
        entries = speedup_data[p]
        if not entries:
            continue
        speedups = [s for s, _ in entries]
        # Every recorded speedup is > 0, so the logarithm is always defined.
        geomean = math.exp(sum(math.log(s) for s in speedups) / len(speedups))
        best_val, best_case = max(entries)
        worst_val, worst_case = min(entries)
        wins = sum(1 for s in speedups if s > 1.0)
        losses = sum(1 for s in speedups if s < 1.0)
        total = len(speedups)

        print(f"\n{p} vs native ({wins}/{total} faster, {losses}/{total} slower):")
        print(f" geomean speedup: {geomean:.2f}x")
        print(f" best: {best_val:.2f}x ({best_case})")
        print(f" worst: {worst_val:.2f}x ({worst_case})")

        summary_rows.append(
            {
                "op": op_name,
                "provider": p,
                "geomean_speedup": f"{geomean:.2f}",
                "best_speedup": f"{best_val:.2f}",
                "best_case": best_case,
                "worst_speedup": f"{worst_val:.2f}",
                "worst_case": worst_case,
                "wins": str(wins),
                "losses": str(losses),
                "total": str(total),
            }
        )

    return detail_rows, summary_rows, header_cols
253+
254+
255+
def write_csv(path: str, rows: list[dict[str, str]], fieldnames: list[str]):
    """Write *rows* to *path* as CSV, preceded by a header of *fieldnames*.

    An existing file at *path* is overwritten.
    """
    with open(path, "w", newline="") as out:
        table = csv.DictWriter(out, fieldnames=fieldnames)
        table.writeheader()
        for record in rows:
            table.writerow(record)
260+
261+
262+
def save_results(
    save_dir: str,
    op_name: str,
    detail_rows: list[dict[str, str]],
    header_cols: list[str],
    all_summary_rows: list[dict[str, str]],
    metadata: dict[str, str],
):
    """Write the CSV outputs for one op into *save_dir*.

    Writes ``<op_name>_detail.csv`` from *detail_rows*, ``summary.csv`` from
    *all_summary_rows* (only when that list is non-empty, with columns taken
    from its first row) and a one-row ``metadata.csv``.
    """
    detail_path = os.path.join(save_dir, f"{op_name}_detail.csv")
    write_csv(detail_path, detail_rows, header_cols)

    if all_summary_rows:
        summary_fields = list(all_summary_rows[0].keys())
        summary_path = os.path.join(save_dir, "summary.csv")
        write_csv(summary_path, all_summary_rows, summary_fields)

    metadata_path = os.path.join(save_dir, "metadata.csv")
    write_csv(metadata_path, [metadata], list(metadata.keys()))
286+
287+
288+
def parse_args():
    """Parse the benchmark script's command-line options from sys.argv.

    Returns:
        The argparse namespace with attributes ``ops``, ``no_cuda_graph``,
        ``warmup``, ``rep`` and ``save_path``.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    options = (
        (
            "--ops",
            dict(
                type=str,
                default=None,
                help="Comma-separated list of op names to benchmark (substring match)",
            ),
        ),
        (
            "--no-cuda-graph",
            dict(
                action="store_true",
                help="Disable CUDA graph; use do_bench with L2 cache flushing instead",
            ),
        ),
        (
            "--warmup",
            dict(
                type=int,
                default=25,
                help="Warmup time in ms (do_bench) or ignored with CUDA graph (default: 25)",
            ),
        ),
        (
            "--rep",
            dict(
                type=int,
                default=100,
                help="Repetition time in ms (do_bench) or number of graph replays "
                "(do_bench_cudagraph) (default: 100)",
            ),
        ),
        (
            "--save-path",
            dict(
                type=str,
                default=None,
                help="Directory to save results (default: auto-created temp dir)",
            ),
        ),
    )

    parser = argparse.ArgumentParser(description="Benchmark vLLM IR ops")
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
321+
322+
323+
def main():
    """Benchmark the registered IR ops and save the results as CSV files.

    Parses the command line, prints the environment metadata, then for each
    registered op that passes the ``--ops`` filter and has an input generator:
    times its implementations, prints a summary and writes the CSV files into
    the save directory (``--save-path`` or a timestamped temp directory).

    Raises:
        RuntimeError: If a selected op with an input generator has no entry
            in SHAPE_CONFIGS.
    """
    args = parse_args()
    cfg = BenchConfig(
        use_cuda_graph=not args.no_cuda_graph,
        warmup=args.warmup,
        rep=args.rep,
    )

    torch.set_default_device(current_platform.device_type)

    metadata = collect_env_metadata(cfg)
    print_metadata(metadata)

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = args.save_path or os.path.join(
        tempfile.gettempdir(), f"vllm_ir_bench_{timestamp}"
    )
    os.makedirs(save_dir, exist_ok=True)

    # Drop empty tokens (e.g. from a trailing comma): the empty string is a
    # substring of every op name and would silently select all ops. An empty
    # list means "no filter", the same as omitting --ops.
    raw_filters = args.ops.split(",") if args.ops else []
    op_filters = [token for token in (f.strip() for f in raw_filters) if token]
    all_summary_rows: list[dict[str, str]] = []
    benchmarked = 0

    for op in IrOp.registry.values():
        if op_filters and not any(f in op.name for f in op_filters):
            continue
        if not op.has_input_generator:
            print(f"Skipping op '{op.name}': no input generator registered")
            continue
        if op.name not in SHAPE_CONFIGS:
            raise RuntimeError(
                f"No benchmark shape config for op '{op.name}'. "
                f"Add it to benchmarks/kernels/ir/shapes.py"
            )

        case_names, providers, results = collect_timings(
            op, SHAPE_CONFIGS[op.name], cfg
        )
        detail_rows, summary_rows, header_cols = analyze_results(
            op.name, case_names, providers, results
        )
        all_summary_rows.extend(summary_rows)

        # Saved after every op, so the summary written so far stays on disk
        # even if a later op fails.
        save_results(
            save_dir,
            op.name,
            detail_rows,
            header_cols,
            all_summary_rows,
            metadata,
        )
        benchmarked += 1

    if benchmarked:
        print(f"\nResults saved to: {save_dir}")
    else:
        # Nothing was written, so do not claim that results were saved.
        print("\nNo ops were benchmarked; nothing was saved.")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)