Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,72 @@ def fake_scan_file(path: str, config: dict[str, Any] | None = None) -> ScanResul
)


def test_scan_file_passes_shard_allowlist_to_advanced_handler(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
shard = tmp_path / "model-00001-of-00002.safetensors"
shard.write_bytes(b"inside-shard")
allowed_path = str(shard.resolve())
captured_allowed_paths: list[list[str] | None] = []

class DummyScanner:
name = "dummy"

def __init__(self, config: dict[str, Any] | None = None) -> None:
self.config = config or {}

def fake_select_preferred_scanner_id(path: str, header_format: str, ext: str) -> str | None:
assert path == str(shard)
assert isinstance(header_format, str)
assert ext == ".safetensors"
return None

def fake_get_scanner_for_path(path: str, **kwargs: Any) -> type[DummyScanner]:
assert path == str(shard)
assert kwargs == {"scanner_selection": None}
return DummyScanner

def fake_scan_advanced_large_file(
path: str,
scanner: DummyScanner,
progress_callback: Any,
timeout: int,
*,
allowed_shard_paths: list[str] | None = None,
) -> ScanResult:
assert path == str(shard)
assert progress_callback is None
assert timeout == 7200
captured_allowed_paths.append(allowed_shard_paths)
result = ScanResult(scanner_name=scanner.name)
result.bytes_scanned = shard.stat().st_size
result.finish(success=True)
return result

monkeypatch.setattr(core_module, "should_use_advanced_handler", lambda path: path == str(shard))
monkeypatch.setattr(core_module, "_select_preferred_scanner_id", fake_select_preferred_scanner_id)
monkeypatch.setattr(core_module._registry, "get_scanner_for_path", fake_get_scanner_for_path)
monkeypatch.setattr(core_module, "scan_advanced_large_file", fake_scan_advanced_large_file)

result = scan_file(
str(shard),
config={
"cache_scan_results": False,
core_module._SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY: {
"members": [
{"path": allowed_path, "content_hash": "sha256:inside"},
{"path": 123, "content_hash": "invalid"},
"not-a-member",
],
},
},
)

assert result.scanner_name == "dummy"
assert captured_allowed_paths == [[allowed_path]]


def test_directory_scan_reports_incomplete_sharded_model_family_once(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
Expand Down
Loading