Skip to content

Commit 86101f6

Browse files
srickettsclaude
andauthored
test: add coverage for all cli commands (#1848)
<!-- .github/pull_request_template.md --> ## πŸ“Œ Description For #1833 - adds coverage for all CLI commands in `flashinfer/__main__.py`. ## πŸ” Related Issues <!-- Link any related issues here --> ## πŸš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### βœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## πŸ§ͺ Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes Do we want all tests in a single file or split by command? <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Expanded CLI test coverage with new real and mocked scenarios for config display, download/listing, export behavior, and cache/cubin management. * Added a GPU-focused CLI test suite for module-status and module listing, gated for GPU environments. * Introduced shared CLI test helpers and removed a redundant show-config test. * **Chores** * Test orchestration updated to run the new CLI test suites. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent bd0b27b commit 86101f6

6 files changed

Lines changed: 528 additions & 16 deletions

File tree

β€Žscripts/task_jit_run_tests_part5.shβ€Ž

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@ fi
1515

1616
# Run each test file separately to isolate CUDA memory issues
1717
pytest -s tests/utils/test_logits_processor.py
18+
pytest -s tests/cli/test_cli_cmds.py
19+
pytest -s tests/cli/test_cli_cmds_gpu.py

β€Žtests/cli/__init__.pyβ€Ž

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# CLI test package

β€Žtests/cli/cli_cmd_helpers.pyβ€Ž

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from click.testing import CliRunner
2+
3+
from flashinfer.__main__ import cli
4+
5+
6+
def _test_cmd_helper(cmd: list[str]):
7+
"""
8+
Helper for command tests
9+
"""
10+
runner = CliRunner()
11+
result = runner.invoke(cli, cmd)
12+
assert result.exit_code == 0, result.output
13+
return result.output
14+
15+
16+
def _assert_output_contains_all(output, *expected_strings):
17+
"""Assert that output contains all expected strings."""
18+
missing = [s for s in expected_strings if s not in output]
19+
assert not missing, (
20+
f"Missing strings in output: {missing}\n\nActual output:\n{output}"
21+
)
22+
23+
24+
def _assert_output_contains_any(output, *expected_strings):
25+
"""Assert that output contains at least one of the expected strings."""
26+
found = any(s in output for s in expected_strings)
27+
assert found, (
28+
f"None of the expected strings were found in output: {expected_strings}\n\nActual output:\n{output}"
29+
)

β€Žtests/cli/test_cli_cmds.pyβ€Ž

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
"""
2+
Test that the CLI commands work as expected.
3+
4+
In general there can be two types of tests for each command:
5+
- Real tests (with suffix `_real`) that invoke the commands without any mocking
6+
- Mocked tests (with suffix `_mocked`) that use monkeypatch to mock logic that would
7+
otherwise be slow (e.g. downloading cubins, filesystem calls, etc), and also to
8+
create deterministic state so we can check for expected output (e.g. number of cubins)
9+
10+
These tests don't require a GPU. CLI tests that require a GPU are in test_cli_cmds_gpu.py.
11+
12+
Note: The `replay` command is tested in tests/utils/test_logging_replay.py alongside
13+
the other logging/replay functionality tests, since it's tightly coupled with that feature.
14+
"""
15+
16+
from .cli_cmd_helpers import (
17+
_test_cmd_helper,
18+
_assert_output_contains_all,
19+
_assert_output_contains_any,
20+
)
21+
from flashinfer.artifacts import ArtifactPath
22+
23+
24+
def test_show_config_cmd_real():
25+
"""
26+
Test that show-config command works as expected
27+
"""
28+
out = _test_cmd_helper(["show-config"])
29+
30+
# Basic sections present
31+
_assert_output_contains_all(
32+
out,
33+
"=== Torch Version Info ===",
34+
"=== Environment Variables ===",
35+
"=== Artifact Path ===",
36+
"=== Downloaded Cubins ===",
37+
)
38+
39+
40+
def test_show_config_cmd_mocked(monkeypatch):
41+
"""
42+
Test that show-config command works as but with mocked cubin status
43+
"""
44+
# Don't check filesystem for cubins
45+
monkeypatch.setattr(
46+
"flashinfer.__main__.get_artifacts_status",
47+
lambda: (("foo.cubin", True), ("bar.cubin", False)),
48+
)
49+
# Avoid module registration/inspection
50+
monkeypatch.setattr(
51+
"flashinfer.__main__._ensure_modules_registered",
52+
lambda: [],
53+
)
54+
55+
out = _test_cmd_helper(["show-config"])
56+
57+
# Uses our monkeypatched data
58+
assert "Downloaded 1/2 cubins" in out
59+
60+
61+
def test_cli_group_help_real():
62+
"""
63+
Test that the CLI group runs without error and sanity checks the output
64+
"""
65+
out = _test_cmd_helper([])
66+
_assert_output_contains_any(out, "FlashInfer CLI", "Usage")
67+
68+
69+
def test_download_cubin_flag_mocked(monkeypatch):
70+
# This just tests that the flag is parsed correctly, so we can monkeypatch
71+
# download_artifacts to avoid the latency of downloading cubins
72+
monkeypatch.setattr("flashinfer.__main__.download_artifacts", lambda: None)
73+
74+
out = _test_cmd_helper(["--download-cubin"])
75+
assert "All cubin download tasks completed successfully" in out
76+
77+
78+
def test_download_cubin_cmd_mocked(monkeypatch):
79+
"""
80+
Test that download-cubin can download a single cubin using a mocked cubin path
81+
"""
82+
# Return a real cubin path relative to the repository so it can be downloaded
83+
fmha_cubin = "fmhaSm100aKernel_QE4m3KvE2m1OE4m3H128PagedKvCausalP32VarSeqQ128Kv128PersistentContext.cubin"
84+
85+
# Mock get_subdir_file_list to return a list with (filename, checksum) tuples
86+
def mock_get_subdir_file_list():
87+
return [(f"{ArtifactPath.TRTLLM_GEN_FMHA}/{fmha_cubin}", "fake_checksum_12345")]
88+
89+
monkeypatch.setattr(
90+
"flashinfer.artifacts.get_subdir_file_list", mock_get_subdir_file_list
91+
)
92+
93+
# Mock download_file to avoid actual network calls
94+
monkeypatch.setattr(
95+
"flashinfer.artifacts.download_file", lambda *_args, **_kwargs: True
96+
)
97+
98+
# Mock verify_cubin to always return True
99+
monkeypatch.setattr("flashinfer.artifacts.verify_cubin", lambda *_args: True)
100+
101+
out = _test_cmd_helper(["--download-cubin"])
102+
assert "All cubin download tasks completed successfully" in out
103+
104+
105+
def test_list_cubins_cmd_real():
106+
out = _test_cmd_helper(["list-cubins"])
107+
108+
_assert_output_contains_all(out, "Cubin", "Status")
109+
110+
111+
def test_list_cubins_cmd_mocked(monkeypatch):
112+
monkeypatch.setattr(
113+
"flashinfer.__main__.get_artifacts_status",
114+
lambda: (("foo.cubin", True), ("bar.cubin", False)),
115+
)
116+
117+
out = _test_cmd_helper(["list-cubins"])
118+
_assert_output_contains_all(out, "foo.cubin", "bar.cubin")
119+
120+
121+
def test_clear_cache_cmd_mocked(monkeypatch):
122+
"""
123+
Test that clear-cache command works without actually clearing the cache.
124+
125+
This doesn't test much, just a basic sanity check.
126+
"""
127+
monkeypatch.setattr("flashinfer.__main__.clear_cache_dir", lambda: None)
128+
129+
out = _test_cmd_helper(["clear-cache"])
130+
assert "Cache cleared successfully" in out
131+
132+
133+
def test_clear_cache_cmd_real(monkeypatch, tmp_path):
134+
"""
135+
Test that clear-cache command actually clears the cache directory.
136+
137+
Uses a temporary directory to avoid side effects on the real cache.
138+
"""
139+
# Create a temporary JIT directory with some dummy cache files
140+
temp_jit_dir = tmp_path / "cached_ops"
141+
temp_jit_dir.mkdir(parents=True, exist_ok=True)
142+
143+
# Create some dummy cached files to simulate a real cache
144+
dummy_module_dir = temp_jit_dir / "test_module_abc123"
145+
dummy_module_dir.mkdir(parents=True, exist_ok=True)
146+
(dummy_module_dir / "test_module.so").write_text("dummy shared library")
147+
(dummy_module_dir / "build.ninja").write_text("dummy build file")
148+
149+
# Monkeypatch the FLASHINFER_JIT_DIR to point to our temp directory
150+
monkeypatch.setattr("flashinfer.jit.core.jit_env.FLASHINFER_JIT_DIR", temp_jit_dir)
151+
152+
# Verify the cache directory exists before clearing
153+
assert temp_jit_dir.exists()
154+
assert (dummy_module_dir / "test_module.so").exists()
155+
156+
# Run the clear-cache command
157+
out = _test_cmd_helper(["clear-cache"])
158+
assert "Cache cleared successfully" in out
159+
160+
# Verify the cache directory has been removed
161+
assert not temp_jit_dir.exists()
162+
163+
164+
def test_clear_cubin_cmd_mocked(monkeypatch):
165+
"""
166+
Test that clear-cubin command works without actually clearing the cubin.
167+
168+
This doesn't test much, just a basic sanity check.
169+
"""
170+
monkeypatch.setattr("flashinfer.__main__.clear_cubin", lambda: None)
171+
172+
out = _test_cmd_helper(["clear-cubin"])
173+
assert "Cubin cleared successfully" in out
174+
175+
176+
def test_clear_cubin_cmd_real(monkeypatch, tmp_path):
177+
"""
178+
Test that clear-cubin command actually clears the cubin directory.
179+
180+
Uses a temporary directory to avoid side effects on the real cubins.
181+
"""
182+
# Create a temporary cubin directory with some dummy cubin files
183+
temp_cubin_dir = tmp_path / "cubins"
184+
temp_cubin_dir.mkdir(parents=True, exist_ok=True)
185+
186+
# Create some dummy cubin files to simulate real cubins
187+
dummy_cubin_subdir = temp_cubin_dir / "trtllm_gen_fmha"
188+
dummy_cubin_subdir.mkdir(parents=True, exist_ok=True)
189+
(dummy_cubin_subdir / "test_kernel.cubin").write_text("dummy cubin data")
190+
(dummy_cubin_subdir / "checksums.txt").write_text("abc123 test_kernel.cubin")
191+
192+
# Monkeypatch FLASHINFER_CUBIN_DIR to point to our temp directory
193+
# Need to patch it in multiple places where it's imported
194+
monkeypatch.setattr("flashinfer.artifacts.FLASHINFER_CUBIN_DIR", temp_cubin_dir)
195+
monkeypatch.setattr(
196+
"flashinfer.jit.cubin_loader.FLASHINFER_CUBIN_DIR", temp_cubin_dir
197+
)
198+
199+
# Verify the cubin directory exists before clearing
200+
assert temp_cubin_dir.exists()
201+
assert (dummy_cubin_subdir / "test_kernel.cubin").exists()
202+
203+
# Run the clear-cubin command
204+
out = _test_cmd_helper(["clear-cubin"])
205+
assert "Cubin cleared successfully" in out
206+
207+
# Verify the cubin directory has been removed
208+
assert not temp_cubin_dir.exists()
209+
210+
211+
class MockJitSpec:
212+
"""Mock JitSpec for testing export-compile-commands."""
213+
214+
def __init__(self, name, compile_commands):
215+
self.name = name
216+
self._compile_commands = compile_commands
217+
218+
def get_compile_commands(self):
219+
return self._compile_commands
220+
221+
222+
def test_export_compile_commands_mocked(monkeypatch, tmp_path):
223+
"""
224+
Test that export-compile-commands writes correct JSON output.
225+
"""
226+
# Create mock specs with compile commands
227+
mock_specs = {
228+
"module_a": MockJitSpec(
229+
"module_a",
230+
[
231+
{
232+
"directory": "/path/to/build",
233+
"command": "nvcc -c kernel_a.cu",
234+
"file": "kernel_a.cu",
235+
}
236+
],
237+
),
238+
"module_b": MockJitSpec(
239+
"module_b",
240+
[
241+
{
242+
"directory": "/path/to/build",
243+
"command": "nvcc -c kernel_b.cu",
244+
"file": "kernel_b.cu",
245+
}
246+
],
247+
),
248+
}
249+
250+
monkeypatch.setattr("flashinfer.__main__._ensure_modules_registered", lambda: [])
251+
monkeypatch.setattr(
252+
"flashinfer.__main__.jit_spec_registry.get_all_specs", lambda: mock_specs
253+
)
254+
255+
# Use tmp_path to write output file
256+
output_file = tmp_path / "compile_commands.json"
257+
out = _test_cmd_helper(["export-compile-commands", str(output_file)])
258+
259+
assert "Successfully exported 2 compile commands" in out
260+
assert output_file.exists()
261+
262+
# Verify JSON content
263+
import json
264+
265+
with open(output_file) as f:
266+
commands = json.load(f)
267+
268+
assert len(commands) == 2
269+
assert commands[0]["file"] == "kernel_a.cu"
270+
assert commands[1]["file"] == "kernel_b.cu"
271+
272+
273+
def test_export_compile_commands_output_option(monkeypatch, tmp_path):
274+
"""
275+
Test that --output option overrides PATH argument.
276+
"""
277+
mock_specs = {
278+
"module_a": MockJitSpec(
279+
"module_a",
280+
[{"directory": "/build", "command": "nvcc -c a.cu", "file": "a.cu"}],
281+
),
282+
}
283+
284+
monkeypatch.setattr("flashinfer.__main__._ensure_modules_registered", lambda: [])
285+
monkeypatch.setattr(
286+
"flashinfer.__main__.jit_spec_registry.get_all_specs", lambda: mock_specs
287+
)
288+
289+
# PATH argument should be ignored when --output is specified
290+
output_file = tmp_path / "custom_output.json"
291+
ignored_file = tmp_path / "ignored.json"
292+
out = _test_cmd_helper(
293+
["export-compile-commands", str(ignored_file), "--output", str(output_file)]
294+
)
295+
296+
assert "Successfully exported 1 compile commands" in out
297+
assert output_file.exists()
298+
assert not ignored_file.exists()
299+
300+
301+
def test_export_compile_commands_no_modules(monkeypatch, tmp_path):
302+
"""
303+
Test that export-compile-commands handles empty module registry.
304+
"""
305+
monkeypatch.setattr("flashinfer.__main__._ensure_modules_registered", lambda: [])
306+
monkeypatch.setattr(
307+
"flashinfer.__main__.jit_spec_registry.get_all_specs", lambda: {}
308+
)
309+
310+
output_file = tmp_path / "compile_commands.json"
311+
out = _test_cmd_helper(["export-compile-commands", str(output_file)])
312+
313+
assert "No modules found" in out
314+
# File should not be created when no modules exist
315+
assert not output_file.exists()

0 commit comments

Comments
Β (0)