Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Bug Fixes

- avoid repeatedly scanning sharded model families during directory scans

## [0.2.45](https://github.com/promptfoo/modelaudit/compare/v0.2.44...v0.2.45) (2026-05-03)

### Bug Fixes
Expand Down
167 changes: 130 additions & 37 deletions modelaudit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
validate_file_type_with_formats,
)
from modelaudit.utils.file.handlers import (
ShardedModelDetector,
scan_advanced_large_file,
should_use_advanced_handler,
)
Expand Down Expand Up @@ -103,6 +104,9 @@ def _count_files_up_to(path: Path, limit: int) -> int | None:
_XGBOOST_PICKLE_SPOOF_REASON = "xgboost_binary_pickle_spoof"
_RECOGNIZED_FORMAT_SCANNER_UNAVAILABLE_REASON = "recognized_format_scanner_unavailable"
_XML_MODEL_ROUTING_INCOMPLETE_REASON = "xml_model_routing_incomplete"
_ShardFamilyKey = tuple[str, str, int | None]
_ScanEntry = tuple[str, list[str], _ShardFamilyKey | None]
_SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY = "shard_family_cache_fingerprint"


def _start_phase_timing(phase_timings: dict[str, float] | None) -> float | None:
Expand All @@ -124,6 +128,44 @@ def _attach_phase_timings(results: ModelAuditResultModel, phase_timings: dict[st
results.phase_timings = phase_timings # type: ignore[attr-defined]


def _shard_family_key_for_path(path: str) -> _ShardFamilyKey | None:
    """Map *path* to a stable key identifying its local shard family.

    Returns None when the filename does not match any known shard pattern,
    or when the detector result lacks a usable pattern string.
    """
    file_path = Path(path)
    match_info = ShardedModelDetector.match_shard_filename(file_path.name)
    if match_info is None:
        return None

    pattern_value = match_info.get("pattern")
    if not isinstance(pattern_value, str):
        return None

    # A non-int (or missing) shard count collapses to None so keys stay comparable.
    raw_total = match_info.get("expected_total_shards")
    total_shards = raw_total if isinstance(raw_total, int) else None

    return (str(file_path.parent), pattern_value, total_shards)


def _build_shard_family_cache_fingerprint(
    shard_family_key: _ShardFamilyKey,
    scanned_file_paths: list[str],
    content_hashes: dict[str, str],
) -> dict[str, Any]:
    """Build a fingerprint covering every present shard of a family.

    The fingerprint records the shard pattern, the expected shard count, and a
    sorted list of (resolved path, content hash) member entries so cached
    results for the representative shard are invalidated when any sibling
    shard changes.
    """
    _parent, shard_pattern, total_shards = shard_family_key

    member_entries: list[dict[str, Any]] = []
    for member_path in sorted(scanned_file_paths):
        member_entries.append(
            {
                "path": str(Path(member_path).resolve()),
                # May be None when the member was never hashed.
                "content_hash": content_hashes.get(member_path),
            }
        )

    return {
        "pattern": shard_pattern,
        "expected_total_shards": total_shards,
        "members": member_entries,
    }


def _select_preferred_scanner_id(path: str, header_format: str, ext: str) -> str | None:
"""Select a scanner by trusted file structure, not just suffix."""
if header_format == "zip":
Expand Down Expand Up @@ -536,6 +578,8 @@ def scan_model_directory_or_file(

# First pass: collect all file paths that need scanning
files_to_scan: list[str] = []
shard_family_representatives: dict[_ShardFamilyKey, str] = {}
shard_family_paths: dict[_ShardFamilyKey, set[str]] = {}
directory_discovery_started_at = _start_phase_timing(phase_timings)
for root, _, files in os.walk(path, followlinks=False):
for file in files:
Expand Down Expand Up @@ -609,29 +653,49 @@ def scan_model_directory_or_file(
continue

# Add to files to scan list instead of scanning immediately
shard_family_key = _shard_family_key_for_path(target_str)
if shard_family_key is not None:
family_paths = shard_family_paths.setdefault(shard_family_key, set())
family_paths.add(target_str)
if shard_family_key not in shard_family_representatives:
shard_family_representatives[shard_family_key] = target_str
shard_info = ShardedModelDetector.detect_shards(target_str)
if shard_info is not None:
for shard_path in shard_info.get("shards", []):
if isinstance(shard_path, str):
family_paths.add(str(Path(shard_path).resolve()))
continue

files_to_scan.append(target_str)
_finish_phase_timing(phase_timings, "directory_discovery", directory_discovery_started_at)

# Second pass: scan every path independently. Some scanners depend on
# parent paths or sibling files, so content equality alone is not a
# safe proxy for scan-result equality.
if files_to_scan:
scan_entries: list[_ScanEntry] = [(file_path, [file_path], None) for file_path in files_to_scan]
for shard_family_key, representative_file in shard_family_representatives.items():
ordered_family_paths = sorted(shard_family_paths.get(shard_family_key, {representative_file}))
scan_entries.append((representative_file, ordered_family_paths, shard_family_key))

# Second pass: scan every non-shard path independently and every shard
# family once. Shard scans already expand to sibling shards in the
# advanced handler, so scanning each shard path would duplicate work.
if scan_entries:
hash_paths: list[str] = []
seen_hash_paths: set[str] = set()
for _representative_file, scanned_file_paths, _shard_family_key in scan_entries:
for scanned_file_path in scanned_file_paths:
if scanned_file_path not in seen_hash_paths:
hash_paths.append(scanned_file_path)
seen_hash_paths.add(scanned_file_path)

top_level_hashing_started_at = _start_phase_timing(phase_timings)
content_hashes = _hash_files_by_path(files_to_scan)
content_hashes = _hash_files_by_path(hash_paths)
_finish_phase_timing(phase_timings, "top_level_hashing", top_level_hashing_started_at)
duplicate_paths_by_hash: dict[str, list[str]] = {}
for file_path, content_hash in content_hashes.items():
if not content_hash.startswith("unhashable_"):
duplicate_paths_by_hash.setdefault(content_hash, []).append(file_path)
recorded_content_hashes: set[str] = set()

for representative_file, content_hash in content_hashes.items():
# Collect valid content hashes for aggregate hash computation
# Skip "unhashable_" prefix entries (those are placeholder hashes for files that failed to hash)
if not content_hash.startswith("unhashable_") and content_hash not in recorded_content_hashes:
file_hashes.append(content_hash)
recorded_content_hashes.add(content_hash)

for representative_file, scanned_file_paths, shard_family_key in scan_entries:
# Check for interrupts
check_interrupted()

Expand All @@ -641,14 +705,17 @@ def scan_model_directory_or_file(

# Update progress
if progress_callback:
scan_label = Path(representative_file).name
if len(scanned_file_paths) > 1:
scan_label = f"{scan_label} ({len(scanned_file_paths)} shards)"
if total_files is not None and total_files > 0:
progress_callback(
f"Scanning file {processed_files + 1}/{total_files}: {Path(representative_file).name}",
f"Scanning file {processed_files + 1}/{total_files}: {scan_label}",
processed_files / total_files * 100,
)
else:
progress_callback(
f"Scanning file {processed_files + 1}: {Path(representative_file).name}",
f"Scanning file {processed_files + 1}: {scan_label}",
0.0,
)

Expand All @@ -658,7 +725,17 @@ def scan_model_directory_or_file(

file_scan_started_at = _start_phase_timing(phase_timings)
try:
file_result = scan_file(representative_file, config)
file_config = config
if shard_family_key is not None:
file_config = dict(config)
file_config[_SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY] = (
_build_shard_family_cache_fingerprint(
shard_family_key,
scanned_file_paths,
content_hashes,
)
)
file_result = scan_file(representative_file, file_config)
finally:
_finish_phase_timing(phase_timings, "file_scan_dispatch", file_scan_started_at)

Expand All @@ -667,8 +744,17 @@ def scan_model_directory_or_file(
if _scan_result_has_operational_error(file_result):
scan_metadata["has_operational_errors"] = True
results.bytes_scanned += file_result.bytes_scanned
results.files_scanned += 1
processed_files += 1
results.files_scanned += len(scanned_file_paths)
processed_files += len(scanned_file_paths)
for scanned_file_path in scanned_file_paths:
path_content_hash = content_hashes.get(scanned_file_path)
if (
path_content_hash is not None
and not path_content_hash.startswith("unhashable_")
and path_content_hash not in recorded_content_hashes
):
file_hashes.append(path_content_hash)
recorded_content_hashes.add(path_content_hash)

# Add scanner to tracking list (different from scanner_names)
scanner_name = file_result.scanner_name
Expand Down Expand Up @@ -734,32 +820,38 @@ def scan_model_directory_or_file(

results.checks.append(Check(**check_dict))

_add_asset_to_results(results, representative_file, file_result)
for scanned_file_path in scanned_file_paths:
_add_asset_to_results(results, scanned_file_path, file_result)
_finish_phase_timing(phase_timings, "result_merge", result_merge_started_at)

# Add metadata for this path using Pydantic models
license_metadata_started_at = _start_phase_timing(phase_timings)
try:
license_metadata = collect_license_metadata(
representative_file,
nearby_license_cache=nearby_license_cache,
)
finally:
_finish_phase_timing(phase_timings, "license_metadata", license_metadata_started_at)
combined_metadata = {**file_result.metadata, **license_metadata}
combined_metadata["content_hash"] = content_hash
duplicate_files = duplicate_paths_by_hash.get(content_hash, [])
combined_metadata["duplicate_files"] = duplicate_files if len(duplicate_files) > 1 else None
from .models import FileMetadataModel

# Convert ml_context if present
if "ml_context" in combined_metadata and isinstance(combined_metadata["ml_context"], dict):
from .models import MLContextModel
for scanned_file_path in scanned_file_paths:
license_metadata_started_at = _start_phase_timing(phase_timings)
try:
license_metadata = collect_license_metadata(
scanned_file_path,
nearby_license_cache=nearby_license_cache,
)
finally:
_finish_phase_timing(phase_timings, "license_metadata", license_metadata_started_at)
combined_metadata = {**file_result.metadata, **license_metadata}
path_content_hash = content_hashes.get(scanned_file_path)
if path_content_hash is not None:
combined_metadata["content_hash"] = path_content_hash
duplicate_files = duplicate_paths_by_hash.get(path_content_hash, [])
combined_metadata["duplicate_files"] = (
duplicate_files if len(duplicate_files) > 1 else None
)

combined_metadata["ml_context"] = MLContextModel(**combined_metadata["ml_context"])
# Convert ml_context if present
if "ml_context" in combined_metadata and isinstance(combined_metadata["ml_context"], dict):
from .models import MLContextModel

from .models import FileMetadataModel
combined_metadata["ml_context"] = MLContextModel(**combined_metadata["ml_context"])

results.file_metadata[representative_file] = FileMetadataModel(**combined_metadata)
results.file_metadata[scanned_file_path] = FileMetadataModel(**combined_metadata)

if max_total_size > 0 and results.bytes_scanned > max_total_size:
_add_issue_to_model(
Expand All @@ -783,7 +875,8 @@ def scan_model_directory_or_file(
location=representative_file,
details={"exception_type": type(e).__name__},
)
_add_error_asset_to_results(results, representative_file)
for scanned_file_path in scanned_file_paths:
_add_error_asset_to_results(results, scanned_file_path)

# Final progress update for directory scan
if progress_callback and not limit_reached and total_files is not None and total_files > 0:
Expand Down
25 changes: 25 additions & 0 deletions modelaudit/utils/file/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,31 @@ class ShardedModelDetector:
r"params_shard_(\d+)\.bin", # Custom parameter sharding
]

@classmethod
def match_shard_filename(cls, file_name: str) -> dict[str, int | str | None] | None:
    """Return shard metadata for *file_name*, or None when no pattern matches."""

    def _group_as_int(matched: re.Match, index: int) -> int | None:
        # A group may be absent or non-numeric for some patterns;
        # treat either condition as "unknown".
        with suppress(IndexError, ValueError):
            return int(matched.group(index))
        return None

    for shard_pattern in cls.SHARD_PATTERNS:
        matched = re.fullmatch(shard_pattern, file_name)
        if matched is None:
            continue

        group_count = matched.lastindex or 0
        shard_index = _group_as_int(matched, 1) if group_count >= 1 else None
        total_shards = _group_as_int(matched, 2) if group_count >= 2 else None

        return {
            "pattern": shard_pattern,
            "current_shard_index": shard_index,
            "expected_total_shards": total_shards,
        }

    return None

@classmethod
def detect_shards(cls, file_path: str) -> dict[str, Any] | None:
"""
Expand Down
Loading
Loading