|
| 1 | +""" |
| 2 | +Test that the CLI commands work as expected. |
| 3 | +
|
| 4 | +In general there can be two types of tests for each command: |
| 5 | +- Real tests (with suffix `_real`) that invoke the commands without any mocking |
| 6 | +- Mocked tests (with suffix `_mocked`) that use monkeypatch to mock logic that would |
| 7 | + otherwise be slow (e.g. downloading cubins, filesystem calls, etc), and also to |
| 8 | + create deterministic state so we can check for expected output (e.g. number of cubins) |
| 9 | +
|
| 10 | +These tests don't require a GPU. CLI tests that require a GPU are in test_cli_cmds_gpu.py. |
| 11 | +
|
| 12 | +Note: The `replay` command is tested in tests/utils/test_logging_replay.py alongside |
| 13 | +the other logging/replay functionality tests, since it's tightly coupled with that feature. |
| 14 | +""" |
| 15 | + |
| 16 | +from .cli_cmd_helpers import ( |
| 17 | + _test_cmd_helper, |
| 18 | + _assert_output_contains_all, |
| 19 | + _assert_output_contains_any, |
| 20 | +) |
| 21 | +from flashinfer.artifacts import ArtifactPath |
| 22 | + |
| 23 | + |
| 24 | +def test_show_config_cmd_real(): |
| 25 | + """ |
| 26 | + Test that show-config command works as expected |
| 27 | + """ |
| 28 | + out = _test_cmd_helper(["show-config"]) |
| 29 | + |
| 30 | + # Basic sections present |
| 31 | + _assert_output_contains_all( |
| 32 | + out, |
| 33 | + "=== Torch Version Info ===", |
| 34 | + "=== Environment Variables ===", |
| 35 | + "=== Artifact Path ===", |
| 36 | + "=== Downloaded Cubins ===", |
| 37 | + ) |
| 38 | + |
| 39 | + |
| 40 | +def test_show_config_cmd_mocked(monkeypatch): |
| 41 | + """ |
| 42 | + Test that show-config command works as but with mocked cubin status |
| 43 | + """ |
| 44 | + # Don't check filesystem for cubins |
| 45 | + monkeypatch.setattr( |
| 46 | + "flashinfer.__main__.get_artifacts_status", |
| 47 | + lambda: (("foo.cubin", True), ("bar.cubin", False)), |
| 48 | + ) |
| 49 | + # Avoid module registration/inspection |
| 50 | + monkeypatch.setattr( |
| 51 | + "flashinfer.__main__._ensure_modules_registered", |
| 52 | + lambda: [], |
| 53 | + ) |
| 54 | + |
| 55 | + out = _test_cmd_helper(["show-config"]) |
| 56 | + |
| 57 | + # Uses our monkeypatched data |
| 58 | + assert "Downloaded 1/2 cubins" in out |
| 59 | + |
| 60 | + |
| 61 | +def test_cli_group_help_real(): |
| 62 | + """ |
| 63 | + Test that the CLI group runs without error and sanity checks the output |
| 64 | + """ |
| 65 | + out = _test_cmd_helper([]) |
| 66 | + _assert_output_contains_any(out, "FlashInfer CLI", "Usage") |
| 67 | + |
| 68 | + |
| 69 | +def test_download_cubin_flag_mocked(monkeypatch): |
| 70 | + # This just tests that the flag is parsed correctly, so we can monkeypatch |
| 71 | + # download_artifacts to avoid the latency of downloading cubins |
| 72 | + monkeypatch.setattr("flashinfer.__main__.download_artifacts", lambda: None) |
| 73 | + |
| 74 | + out = _test_cmd_helper(["--download-cubin"]) |
| 75 | + assert "All cubin download tasks completed successfully" in out |
| 76 | + |
| 77 | + |
| 78 | +def test_download_cubin_cmd_mocked(monkeypatch): |
| 79 | + """ |
| 80 | + Test that download-cubin can download a single cubin using a mocked cubin path |
| 81 | + """ |
| 82 | + # Return a real cubin path relative to the repository so it can be downloaded |
| 83 | + fmha_cubin = "fmhaSm100aKernel_QE4m3KvE2m1OE4m3H128PagedKvCausalP32VarSeqQ128Kv128PersistentContext.cubin" |
| 84 | + |
| 85 | + # Mock get_subdir_file_list to return a list with (filename, checksum) tuples |
| 86 | + def mock_get_subdir_file_list(): |
| 87 | + return [(f"{ArtifactPath.TRTLLM_GEN_FMHA}/{fmha_cubin}", "fake_checksum_12345")] |
| 88 | + |
| 89 | + monkeypatch.setattr( |
| 90 | + "flashinfer.artifacts.get_subdir_file_list", mock_get_subdir_file_list |
| 91 | + ) |
| 92 | + |
| 93 | + # Mock download_file to avoid actual network calls |
| 94 | + monkeypatch.setattr( |
| 95 | + "flashinfer.artifacts.download_file", lambda *_args, **_kwargs: True |
| 96 | + ) |
| 97 | + |
| 98 | + # Mock verify_cubin to always return True |
| 99 | + monkeypatch.setattr("flashinfer.artifacts.verify_cubin", lambda *_args: True) |
| 100 | + |
| 101 | + out = _test_cmd_helper(["--download-cubin"]) |
| 102 | + assert "All cubin download tasks completed successfully" in out |
| 103 | + |
| 104 | + |
| 105 | +def test_list_cubins_cmd_real(): |
| 106 | + out = _test_cmd_helper(["list-cubins"]) |
| 107 | + |
| 108 | + _assert_output_contains_all(out, "Cubin", "Status") |
| 109 | + |
| 110 | + |
| 111 | +def test_list_cubins_cmd_mocked(monkeypatch): |
| 112 | + monkeypatch.setattr( |
| 113 | + "flashinfer.__main__.get_artifacts_status", |
| 114 | + lambda: (("foo.cubin", True), ("bar.cubin", False)), |
| 115 | + ) |
| 116 | + |
| 117 | + out = _test_cmd_helper(["list-cubins"]) |
| 118 | + _assert_output_contains_all(out, "foo.cubin", "bar.cubin") |
| 119 | + |
| 120 | + |
| 121 | +def test_clear_cache_cmd_mocked(monkeypatch): |
| 122 | + """ |
| 123 | + Test that clear-cache command works without actually clearing the cache. |
| 124 | +
|
| 125 | + This doesn't test much, just a basic sanity check. |
| 126 | + """ |
| 127 | + monkeypatch.setattr("flashinfer.__main__.clear_cache_dir", lambda: None) |
| 128 | + |
| 129 | + out = _test_cmd_helper(["clear-cache"]) |
| 130 | + assert "Cache cleared successfully" in out |
| 131 | + |
| 132 | + |
| 133 | +def test_clear_cache_cmd_real(monkeypatch, tmp_path): |
| 134 | + """ |
| 135 | + Test that clear-cache command actually clears the cache directory. |
| 136 | +
|
| 137 | + Uses a temporary directory to avoid side effects on the real cache. |
| 138 | + """ |
| 139 | + # Create a temporary JIT directory with some dummy cache files |
| 140 | + temp_jit_dir = tmp_path / "cached_ops" |
| 141 | + temp_jit_dir.mkdir(parents=True, exist_ok=True) |
| 142 | + |
| 143 | + # Create some dummy cached files to simulate a real cache |
| 144 | + dummy_module_dir = temp_jit_dir / "test_module_abc123" |
| 145 | + dummy_module_dir.mkdir(parents=True, exist_ok=True) |
| 146 | + (dummy_module_dir / "test_module.so").write_text("dummy shared library") |
| 147 | + (dummy_module_dir / "build.ninja").write_text("dummy build file") |
| 148 | + |
| 149 | + # Monkeypatch the FLASHINFER_JIT_DIR to point to our temp directory |
| 150 | + monkeypatch.setattr("flashinfer.jit.core.jit_env.FLASHINFER_JIT_DIR", temp_jit_dir) |
| 151 | + |
| 152 | + # Verify the cache directory exists before clearing |
| 153 | + assert temp_jit_dir.exists() |
| 154 | + assert (dummy_module_dir / "test_module.so").exists() |
| 155 | + |
| 156 | + # Run the clear-cache command |
| 157 | + out = _test_cmd_helper(["clear-cache"]) |
| 158 | + assert "Cache cleared successfully" in out |
| 159 | + |
| 160 | + # Verify the cache directory has been removed |
| 161 | + assert not temp_jit_dir.exists() |
| 162 | + |
| 163 | + |
| 164 | +def test_clear_cubin_cmd_mocked(monkeypatch): |
| 165 | + """ |
| 166 | + Test that clear-cubin command works without actually clearing the cubin. |
| 167 | +
|
| 168 | + This doesn't test much, just a basic sanity check. |
| 169 | + """ |
| 170 | + monkeypatch.setattr("flashinfer.__main__.clear_cubin", lambda: None) |
| 171 | + |
| 172 | + out = _test_cmd_helper(["clear-cubin"]) |
| 173 | + assert "Cubin cleared successfully" in out |
| 174 | + |
| 175 | + |
| 176 | +def test_clear_cubin_cmd_real(monkeypatch, tmp_path): |
| 177 | + """ |
| 178 | + Test that clear-cubin command actually clears the cubin directory. |
| 179 | +
|
| 180 | + Uses a temporary directory to avoid side effects on the real cubins. |
| 181 | + """ |
| 182 | + # Create a temporary cubin directory with some dummy cubin files |
| 183 | + temp_cubin_dir = tmp_path / "cubins" |
| 184 | + temp_cubin_dir.mkdir(parents=True, exist_ok=True) |
| 185 | + |
| 186 | + # Create some dummy cubin files to simulate real cubins |
| 187 | + dummy_cubin_subdir = temp_cubin_dir / "trtllm_gen_fmha" |
| 188 | + dummy_cubin_subdir.mkdir(parents=True, exist_ok=True) |
| 189 | + (dummy_cubin_subdir / "test_kernel.cubin").write_text("dummy cubin data") |
| 190 | + (dummy_cubin_subdir / "checksums.txt").write_text("abc123 test_kernel.cubin") |
| 191 | + |
| 192 | + # Monkeypatch FLASHINFER_CUBIN_DIR to point to our temp directory |
| 193 | + # Need to patch it in multiple places where it's imported |
| 194 | + monkeypatch.setattr("flashinfer.artifacts.FLASHINFER_CUBIN_DIR", temp_cubin_dir) |
| 195 | + monkeypatch.setattr( |
| 196 | + "flashinfer.jit.cubin_loader.FLASHINFER_CUBIN_DIR", temp_cubin_dir |
| 197 | + ) |
| 198 | + |
| 199 | + # Verify the cubin directory exists before clearing |
| 200 | + assert temp_cubin_dir.exists() |
| 201 | + assert (dummy_cubin_subdir / "test_kernel.cubin").exists() |
| 202 | + |
| 203 | + # Run the clear-cubin command |
| 204 | + out = _test_cmd_helper(["clear-cubin"]) |
| 205 | + assert "Cubin cleared successfully" in out |
| 206 | + |
| 207 | + # Verify the cubin directory has been removed |
| 208 | + assert not temp_cubin_dir.exists() |
| 209 | + |
| 210 | + |
| 211 | +class MockJitSpec: |
| 212 | + """Mock JitSpec for testing export-compile-commands.""" |
| 213 | + |
| 214 | + def __init__(self, name, compile_commands): |
| 215 | + self.name = name |
| 216 | + self._compile_commands = compile_commands |
| 217 | + |
| 218 | + def get_compile_commands(self): |
| 219 | + return self._compile_commands |
| 220 | + |
| 221 | + |
| 222 | +def test_export_compile_commands_mocked(monkeypatch, tmp_path): |
| 223 | + """ |
| 224 | + Test that export-compile-commands writes correct JSON output. |
| 225 | + """ |
| 226 | + # Create mock specs with compile commands |
| 227 | + mock_specs = { |
| 228 | + "module_a": MockJitSpec( |
| 229 | + "module_a", |
| 230 | + [ |
| 231 | + { |
| 232 | + "directory": "/path/to/build", |
| 233 | + "command": "nvcc -c kernel_a.cu", |
| 234 | + "file": "kernel_a.cu", |
| 235 | + } |
| 236 | + ], |
| 237 | + ), |
| 238 | + "module_b": MockJitSpec( |
| 239 | + "module_b", |
| 240 | + [ |
| 241 | + { |
| 242 | + "directory": "/path/to/build", |
| 243 | + "command": "nvcc -c kernel_b.cu", |
| 244 | + "file": "kernel_b.cu", |
| 245 | + } |
| 246 | + ], |
| 247 | + ), |
| 248 | + } |
| 249 | + |
| 250 | + monkeypatch.setattr("flashinfer.__main__._ensure_modules_registered", lambda: []) |
| 251 | + monkeypatch.setattr( |
| 252 | + "flashinfer.__main__.jit_spec_registry.get_all_specs", lambda: mock_specs |
| 253 | + ) |
| 254 | + |
| 255 | + # Use tmp_path to write output file |
| 256 | + output_file = tmp_path / "compile_commands.json" |
| 257 | + out = _test_cmd_helper(["export-compile-commands", str(output_file)]) |
| 258 | + |
| 259 | + assert "Successfully exported 2 compile commands" in out |
| 260 | + assert output_file.exists() |
| 261 | + |
| 262 | + # Verify JSON content |
| 263 | + import json |
| 264 | + |
| 265 | + with open(output_file) as f: |
| 266 | + commands = json.load(f) |
| 267 | + |
| 268 | + assert len(commands) == 2 |
| 269 | + assert commands[0]["file"] == "kernel_a.cu" |
| 270 | + assert commands[1]["file"] == "kernel_b.cu" |
| 271 | + |
| 272 | + |
| 273 | +def test_export_compile_commands_output_option(monkeypatch, tmp_path): |
| 274 | + """ |
| 275 | + Test that --output option overrides PATH argument. |
| 276 | + """ |
| 277 | + mock_specs = { |
| 278 | + "module_a": MockJitSpec( |
| 279 | + "module_a", |
| 280 | + [{"directory": "/build", "command": "nvcc -c a.cu", "file": "a.cu"}], |
| 281 | + ), |
| 282 | + } |
| 283 | + |
| 284 | + monkeypatch.setattr("flashinfer.__main__._ensure_modules_registered", lambda: []) |
| 285 | + monkeypatch.setattr( |
| 286 | + "flashinfer.__main__.jit_spec_registry.get_all_specs", lambda: mock_specs |
| 287 | + ) |
| 288 | + |
| 289 | + # PATH argument should be ignored when --output is specified |
| 290 | + output_file = tmp_path / "custom_output.json" |
| 291 | + ignored_file = tmp_path / "ignored.json" |
| 292 | + out = _test_cmd_helper( |
| 293 | + ["export-compile-commands", str(ignored_file), "--output", str(output_file)] |
| 294 | + ) |
| 295 | + |
| 296 | + assert "Successfully exported 1 compile commands" in out |
| 297 | + assert output_file.exists() |
| 298 | + assert not ignored_file.exists() |
| 299 | + |
| 300 | + |
| 301 | +def test_export_compile_commands_no_modules(monkeypatch, tmp_path): |
| 302 | + """ |
| 303 | + Test that export-compile-commands handles empty module registry. |
| 304 | + """ |
| 305 | + monkeypatch.setattr("flashinfer.__main__._ensure_modules_registered", lambda: []) |
| 306 | + monkeypatch.setattr( |
| 307 | + "flashinfer.__main__.jit_spec_registry.get_all_specs", lambda: {} |
| 308 | + ) |
| 309 | + |
| 310 | + output_file = tmp_path / "compile_commands.json" |
| 311 | + out = _test_cmd_helper(["export-compile-commands", str(output_file)]) |
| 312 | + |
| 313 | + assert "No modules found" in out |
| 314 | + # File should not be created when no modules exist |
| 315 | + assert not output_file.exists() |
0 commit comments