|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 4 | +""" |
| 5 | +Generic benchmark harness for vLLM IR ops. |
| 6 | +
|
| 7 | +Usage: |
| 8 | + python benchmarks/kernels/ir/bench_ir_ops.py |
| 9 | + python benchmarks/kernels/ir/bench_ir_ops.py --ops rms_norm |
| 10 | + python benchmarks/kernels/ir/bench_ir_ops.py --ops rms_norm,silu_mul |
| 11 | + python benchmarks/kernels/ir/bench_ir_ops.py --no-cuda-graph |
| 12 | + python benchmarks/kernels/ir/bench_ir_ops.py --ops rms_norm --save-path ./results/ |
| 13 | +""" |
| 14 | + |
| 15 | +import argparse |
| 16 | +import contextlib |
| 17 | +import csv |
| 18 | +import dataclasses |
| 19 | +import datetime |
| 20 | +import math |
| 21 | +import os |
| 22 | +import subprocess |
| 23 | +import sys |
| 24 | +import tempfile |
| 25 | + |
# Put the repository root (three directories above this file) at the front of
# sys.path so that `benchmarks` can be imported as a package.
_REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir)
)
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)
| 30 | + |
| 31 | +# Suppress noisy C++ warnings from vllm kernel registration. They are written to |
| 32 | +# fd 2 directly by the dynamic linker, so a Python-level sys.stderr redirect won't |
| 33 | +# catch them; fd 2 is duplicated, pointed at os.devnull, and restored in `finally`. |
| 34 | +_saved_fd = os.dup(2) |
| 35 | +try: |
| 36 | + with open(os.devnull, "w") as _devnull: |
| 37 | + os.dup2(_devnull.fileno(), 2) |
| 38 | + import torch |
| 39 | + |
| 40 | + import vllm.kernels # noqa: E402, F401 |
| 41 | +finally: |
| 42 | + os.dup2(_saved_fd, 2) |
| 43 | + os.close(_saved_fd) |
| 44 | + |
| 45 | +from tqdm import tqdm # noqa: E402 |
| 46 | + |
| 47 | +from benchmarks.kernels.ir.shapes import SHAPE_CONFIGS # noqa: E402 # isort: skip |
| 48 | +from vllm.ir.op import IrOp # noqa: E402 |
| 49 | +from vllm.platforms import current_platform # noqa: E402 |
| 50 | +from vllm.triton_utils import triton # noqa: E402 |
| 51 | + |
| 52 | + |
@dataclasses.dataclass(frozen=True)
class BenchConfig:
    """Immutable timing settings applied to every benchmarked implementation."""

    # True: time with triton's do_bench_cudagraph; False: use do_bench.
    use_cuda_graph: bool = True
    # Forwarded as `warmup` to do_bench only; the CUDA graph path ignores it.
    warmup: int = 25
    # Forwarded as `rep` to whichever of the two bench helpers is used.
    rep: int = 100
| 58 | + |
| 59 | + |
def _pkg_version(name: str) -> str:
    """Return the installed version string of the distribution ``name``.

    When importlib.metadata cannot find the distribution, the literal string
    "not installed" is returned instead of raising.
    """
    from importlib import metadata

    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return "not installed"
| 66 | + |
| 67 | + |
# Human-readable labels for the metadata keys, used when printing the run
# header. Keys without an entry are printed under their raw key name.
_METADATA_LABELS = dict(
    timestamp="Timestamp",
    git_commit="Git commit",
    vllm="vLLM",
    pytorch="PyTorch",
    cuda_runtime="CUDA runtime",
    triton="Triton",
    cutlass="CUTLASS",
    helion="Helion",
    device="Device",
    bench_mode="Bench mode",
    warmup="Warmup",
    rep="Repetitions",
)
| 82 | + |
| 83 | + |
def collect_env_metadata(cfg: BenchConfig) -> dict[str, str]:
    """Gather the environment and run settings that describe this benchmark run.

    Args:
        cfg: Benchmark settings. They determine the reported bench mode and
            the unit suffix attached to the warmup and rep values.

    Returns:
        A dict of strings with the keys timestamp, git_commit, vllm, pytorch,
        cuda_runtime, triton, cutlass, helion, device, bench_mode, warmup and
        rep (in that order).
    """
    from vllm.collect_env import get_env_info

    env = get_env_info()

    # Query git from the directory that holds this script instead of the
    # caller's working directory, so starting the benchmark from outside the
    # checkout still records the commit of the code being run. Any failure to
    # run git (binary missing, not a repository, unusable directory) keeps the
    # "unknown" placeholder rather than aborting the benchmark.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    git_sha = "unknown"
    with contextlib.suppress(subprocess.CalledProcessError, OSError):
        git_sha = (
            subprocess.check_output(
                ["git", "rev-parse", "--short", "HEAD"],
                cwd=script_dir,
                stderr=subprocess.DEVNULL,
            )
            .decode()
            .strip()
        )

    device_name = current_platform.get_device_name()

    # The unit suffix depends on the mode: warmup is only used by the eager
    # path, and rep is labelled differently for the two bench helpers.
    warmup_note = " ms" if not cfg.use_cuda_graph else " ms (ignored)"
    rep_note = " replays" if cfg.use_cuda_graph else " ms"

    return {
        "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "git_commit": git_sha,
        "vllm": str(env.vllm_version),
        "pytorch": str(env.torch_version),
        "cuda_runtime": str(env.cuda_runtime_version),
        "triton": triton.__version__,
        "cutlass": _pkg_version("nvidia-cutlass-dsl"),
        "helion": _pkg_version("helion"),
        "device": device_name,
        "bench_mode": "cuda_graph" if cfg.use_cuda_graph else "eager",
        "warmup": f"{cfg.warmup}{warmup_note}",
        "rep": f"{cfg.rep}{rep_note}",
    }
| 118 | + |
| 119 | + |
def print_metadata(metadata: dict[str, str]):
    """Print one "label: value" line per metadata entry between two rules.

    Labels come from ``_METADATA_LABELS`` (falling back to the raw key) and
    are left-aligned in a 16-character column.
    """
    rule = "=" * 60
    print(rule)
    for field, value in metadata.items():
        label = _METADATA_LABELS.get(field, field) + ":"
        print(f"{label:<16}{value}")
    print(rule)
| 125 | + |
| 126 | + |
def _clone_args(args: tuple) -> tuple:
    """Return a tuple like ``args`` with every torch.Tensor replaced by a clone.

    Entries that are not tensors are carried over as the same objects.
    """

    def _copy(value):
        if isinstance(value, torch.Tensor):
            return value.clone()
        return value

    return tuple(map(_copy, args))
| 129 | + |
| 130 | + |
# TODO(gmagogsfm): When the `maybe_inplace` PR lands, ops marked as
# inplace=True will mutate bench_args across iterations. Both CUDA graph
# and eager modes will accumulate drift from repeated in-place mutation.
# We need to re-clone inputs per iteration for inplace ops.
def _bench_one(fn, args, cfg: BenchConfig) -> float:
    """Time ``fn(*args)`` on cloned inputs and return the 0.5-quantile time.

    The triton bench helper is asked for the 0.5 quantile only; its result is
    treated as milliseconds and scaled by 1000 before being returned.
    """
    bench_args = _clone_args(args)

    def run():
        return fn(*bench_args)

    median_only = [0.5]
    if not cfg.use_cuda_graph:
        elapsed_ms = triton.testing.do_bench(
            run, warmup=cfg.warmup, rep=cfg.rep, quantiles=median_only
        )
    else:
        elapsed_ms = triton.testing.do_bench_cudagraph(
            run, rep=cfg.rep, quantiles=median_only
        )
    return elapsed_ms * 1000
| 146 | + |
| 147 | + |
# TODO(gmagogsfm): Once compiled native implementation lands (#38775),
# the benchmark baseline should be the compiled native (what vLLM runs by
# default) rather than the uncompiled native implementation.
def collect_timings(
    op: IrOp, shape_configs: list[dict], cfg: BenchConfig
) -> tuple[list[str], list[str], dict[str, dict[str, float]]]:
    """Benchmark every supported implementation of ``op`` on every shape config.

    Returns:
        ``(case_names, providers, results)``: one case name per shape config,
        the names of the implementations whose ``supported`` flag is set, and
        ``results[case][provider]`` holding the timing from ``_bench_one`` or
        NaN when the implementation rejects that case's arguments.
    """

    def fmt(v) -> str:
        # Render torch dtypes by their last dotted component only.
        if isinstance(v, torch.dtype):
            return str(v).split(".")[-1]
        return str(v)

    case_names = [
        "_".join(f"{key}={fmt(value)}" for key, value in config.items())
        for config in shape_configs
    ]
    providers = [name for name, candidate in op.impls.items() if candidate.supported]

    results: dict[str, dict[str, float]] = {}
    for case in case_names:
        results[case] = {}

    for provider in providers:
        impl = op.impls[provider]
        progress = tqdm(
            zip(case_names, shape_configs),
            desc=f"{op.name} / {provider}",
            total=len(case_names),
            unit=" cases",
        )
        for case_name, kwargs in progress:
            inputs = op.generate_inputs(**kwargs)
            if not impl.supports_args(*inputs):
                results[case_name][provider] = float("nan")
                continue
            results[case_name][provider] = _bench_one(impl.impl_fn, inputs, cfg)

    return case_names, providers, results
| 179 | + |
| 180 | + |
def analyze_results(
    op_name: str,
    case_names: list[str],
    providers: list[str],
    results: dict[str, dict[str, float]],
) -> tuple[list[dict[str, str]], list[dict[str, str]], list[str]]:
    """Turn raw timings into per-case detail rows and per-provider summaries.

    Speedups are computed against the provider named "native". A speedup is
    only reported when both timings are finite and strictly positive;
    otherwise the cell is "n/a" and the case is left out of the summary
    statistics. This keeps a zero or infinite timing from producing a zero
    speedup, whose logarithm would abort the geometric mean.

    Args:
        op_name: Name written into the "op" column of the summary rows.
        case_names: Cases to report, in output order; each must be a key of
            ``results``.
        providers: Provider names, in column order.
        results: ``results[case][provider]`` timing in microseconds; NaN or a
            missing provider means "no measurement".

    Returns:
        ``(detail_rows, summary_rows, header_cols)``: one detail row per case,
        one summary row per non-native provider with at least one valid
        speedup, and the detail column order.

    Side effects:
        Prints a short summary for each provider that has a summary row.
    """
    native_col = "native"
    non_native = [p for p in providers if p != native_col]

    header_cols = ["case"]
    header_cols.extend(f"{p} (us)" for p in providers)
    header_cols.extend(f"{p} speedup" for p in non_native)

    detail_rows: list[dict[str, str]] = []
    speedup_data: dict[str, list[tuple[float, str]]] = {p: [] for p in non_native}
    missing = float("nan")

    for case_name in case_names:
        timings = results[case_name]
        row: dict[str, str] = {"case": case_name}

        for p in providers:
            val = timings.get(p, missing)
            row[f"{p} (us)"] = "n/a" if math.isnan(val) else f"{val:.2f}"

        native_us = timings.get(native_col, missing)
        native_ok = math.isfinite(native_us) and native_us > 0
        for p in non_native:
            p_us = timings.get(p, missing)
            if native_ok and math.isfinite(p_us) and p_us > 0:
                speedup = native_us / p_us
                row[f"{p} speedup"] = f"{speedup:.2f}x"
                speedup_data[p].append((speedup, case_name))
            else:
                row[f"{p} speedup"] = "n/a"

        detail_rows.append(row)

    summary_rows: list[dict[str, str]] = []
    for p in non_native:
        entries = speedup_data[p]
        if not entries:
            continue
        speedups = [s for s, _ in entries]
        # Geometric mean via the mean of logs; every speedup here is > 0.
        geomean = math.exp(sum(math.log(s) for s in speedups) / len(speedups))
        best_val, best_case = max(entries)
        worst_val, worst_case = min(entries)
        wins = sum(1 for s in speedups if s > 1.0)
        losses = sum(1 for s in speedups if s < 1.0)
        total = len(speedups)

        print(f"\n{p} vs native ({wins}/{total} faster, {losses}/{total} slower):")
        print(f" geomean speedup: {geomean:.2f}x")
        print(f" best: {best_val:.2f}x ({best_case})")
        print(f" worst: {worst_val:.2f}x ({worst_case})")

        summary_rows.append(
            {
                "op": op_name,
                "provider": p,
                "geomean_speedup": f"{geomean:.2f}",
                "best_speedup": f"{best_val:.2f}",
                "best_case": best_case,
                "worst_speedup": f"{worst_val:.2f}",
                "worst_case": worst_case,
                "wins": str(wins),
                "losses": str(losses),
                "total": str(total),
            }
        )

    return detail_rows, summary_rows, header_cols
| 253 | + |
| 254 | + |
def write_csv(path: str, rows: list[dict[str, str]], fieldnames: list[str]):
    """Write ``rows`` to ``path`` as CSV: a header line, then one line per row.

    Columns follow the order given by ``fieldnames``. An existing file at
    ``path`` is overwritten.
    """
    with open(path, "w", newline="") as out:
        table = csv.DictWriter(out, fieldnames=fieldnames)
        table.writeheader()
        for record in rows:
            table.writerow(record)
| 260 | + |
| 261 | + |
def save_results(
    save_dir: str,
    op_name: str,
    detail_rows: list[dict[str, str]],
    header_cols: list[str],
    all_summary_rows: list[dict[str, str]],
    metadata: dict[str, str],
):
    """Write the CSV files for one op into ``save_dir``.

    Produces ``<op_name>_detail.csv`` from the detail rows, ``summary.csv``
    from all summary rows collected so far (skipped while that list is
    empty), and a single-row ``metadata.csv``. The summary columns are the
    keys of the first summary row.
    """
    detail_path = os.path.join(save_dir, f"{op_name}_detail.csv")
    write_csv(detail_path, detail_rows, header_cols)

    if all_summary_rows:
        summary_path = os.path.join(save_dir, "summary.csv")
        summary_fields = list(all_summary_rows[0])
        write_csv(summary_path, all_summary_rows, summary_fields)

    metadata_path = os.path.join(save_dir, "metadata.csv")
    write_csv(metadata_path, [metadata], list(metadata))
| 286 | + |
| 287 | + |
def parse_args():
    """Parse the benchmark's command-line options from ``sys.argv``.

    Returns:
        The argparse namespace with ``ops``, ``no_cuda_graph``, ``warmup``,
        ``rep`` and ``save_path``.
    """
    cli = argparse.ArgumentParser(description="Benchmark vLLM IR ops")
    add = cli.add_argument

    add(
        "--ops",
        type=str,
        default=None,
        help="Comma-separated list of op names to benchmark (substring match)",
    )
    add(
        "--no-cuda-graph",
        action="store_true",
        help="Disable CUDA graph; use do_bench with L2 cache flushing instead",
    )
    add(
        "--warmup",
        type=int,
        default=25,
        help="Warmup time in ms (do_bench) or ignored with CUDA graph (default: 25)",
    )
    add(
        "--rep",
        type=int,
        default=100,
        help="Repetition time in ms (do_bench) or number of graph replays "
        "(do_bench_cudagraph) (default: 100)",
    )
    add(
        "--save-path",
        type=str,
        default=None,
        help="Directory to save results (default: auto-created temp dir)",
    )

    return cli.parse_args()
| 321 | + |
| 322 | + |
def main():
    """Run the benchmark for every selected IR op and save the results.

    Ops are taken from ``IrOp.registry``. With ``--ops``, an op is selected
    when any of the comma-separated names is a substring of its name. Ops
    without an input generator are skipped with a message.

    Raises:
        RuntimeError: A selected op has an input generator but no entry in
            ``SHAPE_CONFIGS``.
    """
    args = parse_args()
    cfg = BenchConfig(
        use_cuda_graph=not args.no_cuda_graph,
        warmup=args.warmup,
        rep=args.rep,
    )

    torch.set_default_device(current_platform.device_type)

    metadata = collect_env_metadata(cfg)
    print_metadata(metadata)

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = args.save_path or os.path.join(
        tempfile.gettempdir(), f"vllm_ir_bench_{timestamp}"
    )
    os.makedirs(save_dir, exist_ok=True)

    # Discard empty entries (e.g. from a trailing comma in "rms_norm,"): an
    # empty string is a substring of every op name and would otherwise select
    # all ops. If nothing usable remains, no filtering is applied.
    op_filters = None
    if args.ops:
        stripped = (part.strip() for part in args.ops.split(","))
        op_filters = [name for name in stripped if name]
    all_summary_rows: list[dict[str, str]] = []

    for op in IrOp.registry.values():
        if op_filters and not any(f in op.name for f in op_filters):
            continue
        if not op.has_input_generator:
            print(f"Skipping op '{op.name}': no input generator registered")
            continue
        if op.name not in SHAPE_CONFIGS:
            raise RuntimeError(
                f"No benchmark shape config for op '{op.name}'. "
                f"Add it to benchmarks/kernels/ir/shapes.py"
            )

        case_names, providers, results = collect_timings(
            op, SHAPE_CONFIGS[op.name], cfg
        )
        detail_rows, summary_rows, header_cols = analyze_results(
            op.name, case_names, providers, results
        )
        all_summary_rows.extend(summary_rows)

        # Saved after every op, so summary.csv and metadata.csv are rewritten
        # each time with everything collected so far.
        save_results(
            save_dir,
            op.name,
            detail_rows,
            header_cols,
            all_summary_rows,
            metadata,
        )

    print(f"\nResults saved to: {save_dir}")


if __name__ == "__main__":
    main()
0 commit comments