Skip to content

Commit ce976b1

Browse files
committed
fix: avoid duplicate sharded directory scans
1 parent 5a03091 commit ce976b1

4 files changed

Lines changed: 274 additions & 33 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [Unreleased]
9+
10+
### Bug Fixes
11+
12+
- avoid repeatedly scanning sharded model families during directory scans
13+
814
## [0.2.45](https://github.com/promptfoo/modelaudit/compare/v0.2.44...v0.2.45) (2026-05-03)
915

1016
### Bug Fixes

modelaudit/core.py

Lines changed: 122 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
validate_file_type_with_formats,
4444
)
4545
from modelaudit.utils.file.handlers import (
46+
ShardedModelDetector,
4647
scan_advanced_large_file,
4748
should_use_advanced_handler,
4849
)
@@ -103,6 +104,9 @@ def _count_files_up_to(path: Path, limit: int) -> int | None:
103104
_XGBOOST_PICKLE_SPOOF_REASON = "xgboost_binary_pickle_spoof"
104105
_RECOGNIZED_FORMAT_SCANNER_UNAVAILABLE_REASON = "recognized_format_scanner_unavailable"
105106
_XML_MODEL_ROUTING_INCOMPLETE_REASON = "xml_model_routing_incomplete"
107+
_ShardFamilyKey = tuple[str, str, int | None]
108+
_ScanEntry = tuple[str, list[str], _ShardFamilyKey | None]
109+
_SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY = "shard_family_cache_fingerprint"
106110

107111

108112
def _start_phase_timing(phase_timings: dict[str, float] | None) -> float | None:
@@ -124,6 +128,44 @@ def _attach_phase_timings(results: ModelAuditResultModel, phase_timings: dict[st
124128
results.phase_timings = phase_timings # type: ignore[attr-defined]
125129

126130

131+
def _shard_family_key_for_path(path: str) -> _ShardFamilyKey | None:
    """Return a stable key for files that belong to the same local shard family.

    The key is ``(parent directory, matched shard pattern, expected total
    shard count)``. The shard count is ``None`` when the filename match does
    not report an integer total.

    Args:
        path: Path of the candidate file; only its name is matched against
            the known shard patterns.

    Returns:
        The family key, or ``None`` when the filename is not recognized as a
        shard or the match does not carry a string pattern.
    """
    candidate = Path(path)

    match_info = ShardedModelDetector.match_shard_filename(candidate.name)
    if match_info is None:
        return None

    matched_pattern = match_info.get("pattern")
    if not isinstance(matched_pattern, str):
        return None

    total_shards = match_info.get("expected_total_shards")
    # Anything other than an integer total is treated as "unknown".
    normalized_total = total_shards if isinstance(total_shards, int) else None

    return (str(candidate.parent), matched_pattern, normalized_total)
147+
148+
149+
def _build_shard_family_cache_fingerprint(
150+
shard_family_key: _ShardFamilyKey,
151+
scanned_file_paths: list[str],
152+
content_hashes: dict[str, str],
153+
) -> dict[str, Any]:
154+
"""Fingerprint every present shard so representative cache entries stay valid."""
155+
_parent_dir, pattern, expected_total_shards = shard_family_key
156+
return {
157+
"pattern": pattern,
158+
"expected_total_shards": expected_total_shards,
159+
"members": [
160+
{
161+
"path": str(Path(scanned_file_path).resolve()),
162+
"content_hash": content_hashes.get(scanned_file_path),
163+
}
164+
for scanned_file_path in sorted(scanned_file_paths)
165+
],
166+
}
167+
168+
127169
def _select_preferred_scanner_id(path: str, header_format: str, ext: str) -> str | None:
128170
"""Select a scanner by trusted file structure, not just suffix."""
129171
if header_format == "zip":
@@ -536,6 +578,8 @@ def scan_model_directory_or_file(
536578

537579
# First pass: collect all file paths that need scanning
538580
files_to_scan: list[str] = []
581+
shard_family_representatives: dict[_ShardFamilyKey, str] = {}
582+
shard_family_paths: dict[_ShardFamilyKey, set[str]] = {}
539583
directory_discovery_started_at = _start_phase_timing(phase_timings)
540584
for root, _, files in os.walk(path, followlinks=False):
541585
for file in files:
@@ -609,29 +653,54 @@ def scan_model_directory_or_file(
609653
continue
610654

611655
# Add to files to scan list instead of scanning immediately
656+
shard_family_key = _shard_family_key_for_path(target_str)
657+
if shard_family_key is not None:
658+
family_paths = shard_family_paths.setdefault(shard_family_key, set())
659+
family_paths.add(target_str)
660+
if shard_family_key not in shard_family_representatives:
661+
shard_family_representatives[shard_family_key] = target_str
662+
shard_info = ShardedModelDetector.detect_shards(target_str)
663+
if shard_info is not None:
664+
for shard_path in shard_info.get("shards", []):
665+
if isinstance(shard_path, str):
666+
family_paths.add(str(Path(shard_path).resolve()))
667+
continue
668+
612669
files_to_scan.append(target_str)
613670
_finish_phase_timing(phase_timings, "directory_discovery", directory_discovery_started_at)
614671

615-
# Second pass: scan every path independently. Some scanners depend on
616-
# parent paths or sibling files, so content equality alone is not a
617-
# safe proxy for scan-result equality.
618-
if files_to_scan:
672+
scan_entries: list[_ScanEntry] = [(file_path, [file_path], None) for file_path in files_to_scan]
673+
for shard_family_key, representative_file in shard_family_representatives.items():
674+
ordered_family_paths = sorted(shard_family_paths.get(shard_family_key, {representative_file}))
675+
scan_entries.append((representative_file, ordered_family_paths, shard_family_key))
676+
677+
# Second pass: scan every non-shard path independently and every shard
678+
# family once. Shard scans already expand to sibling shards in the
679+
# advanced handler, so scanning each shard path would duplicate work.
680+
if scan_entries:
681+
hash_paths: list[str] = []
682+
seen_hash_paths: set[str] = set()
683+
for _representative_file, scanned_file_paths, _shard_family_key in scan_entries:
684+
for scanned_file_path in scanned_file_paths:
685+
if scanned_file_path not in seen_hash_paths:
686+
hash_paths.append(scanned_file_path)
687+
seen_hash_paths.add(scanned_file_path)
688+
619689
top_level_hashing_started_at = _start_phase_timing(phase_timings)
620-
content_hashes = _hash_files_by_path(files_to_scan)
690+
content_hashes = _hash_files_by_path(hash_paths)
621691
_finish_phase_timing(phase_timings, "top_level_hashing", top_level_hashing_started_at)
622692
duplicate_paths_by_hash: dict[str, list[str]] = {}
623693
for file_path, content_hash in content_hashes.items():
624694
if not content_hash.startswith("unhashable_"):
625695
duplicate_paths_by_hash.setdefault(content_hash, []).append(file_path)
626696
recorded_content_hashes: set[str] = set()
627697

628-
for representative_file, content_hash in content_hashes.items():
629-
# Collect valid content hashes for aggregate hash computation
630-
# Skip "unhashable_" prefix entries (those are placeholder hashes for files that failed to hash)
698+
for content_hash in content_hashes.values():
631699
if not content_hash.startswith("unhashable_") and content_hash not in recorded_content_hashes:
632700
file_hashes.append(content_hash)
633701
recorded_content_hashes.add(content_hash)
634702

703+
for representative_file, scanned_file_paths, shard_family_key in scan_entries:
635704
# Check for interrupts
636705
check_interrupted()
637706

@@ -641,14 +710,17 @@ def scan_model_directory_or_file(
641710

642711
# Update progress
643712
if progress_callback:
713+
scan_label = Path(representative_file).name
714+
if len(scanned_file_paths) > 1:
715+
scan_label = f"{scan_label} ({len(scanned_file_paths)} shards)"
644716
if total_files is not None and total_files > 0:
645717
progress_callback(
646-
f"Scanning file {processed_files + 1}/{total_files}: {Path(representative_file).name}",
718+
f"Scanning file {processed_files + 1}/{total_files}: {scan_label}",
647719
processed_files / total_files * 100,
648720
)
649721
else:
650722
progress_callback(
651-
f"Scanning file {processed_files + 1}: {Path(representative_file).name}",
723+
f"Scanning file {processed_files + 1}: {scan_label}",
652724
0.0,
653725
)
654726

@@ -658,7 +730,17 @@ def scan_model_directory_or_file(
658730

659731
file_scan_started_at = _start_phase_timing(phase_timings)
660732
try:
661-
file_result = scan_file(representative_file, config)
733+
file_config = config
734+
if shard_family_key is not None:
735+
file_config = dict(config)
736+
file_config[_SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY] = (
737+
_build_shard_family_cache_fingerprint(
738+
shard_family_key,
739+
scanned_file_paths,
740+
content_hashes,
741+
)
742+
)
743+
file_result = scan_file(representative_file, file_config)
662744
finally:
663745
_finish_phase_timing(phase_timings, "file_scan_dispatch", file_scan_started_at)
664746

@@ -667,8 +749,8 @@ def scan_model_directory_or_file(
667749
if _scan_result_has_operational_error(file_result):
668750
scan_metadata["has_operational_errors"] = True
669751
results.bytes_scanned += file_result.bytes_scanned
670-
results.files_scanned += 1
671-
processed_files += 1
752+
results.files_scanned += len(scanned_file_paths)
753+
processed_files += len(scanned_file_paths)
672754

673755
# Add scanner to tracking list (different from scanner_names)
674756
scanner_name = file_result.scanner_name
@@ -734,32 +816,38 @@ def scan_model_directory_or_file(
734816

735817
results.checks.append(Check(**check_dict))
736818

737-
_add_asset_to_results(results, representative_file, file_result)
819+
for scanned_file_path in scanned_file_paths:
820+
_add_asset_to_results(results, scanned_file_path, file_result)
738821
_finish_phase_timing(phase_timings, "result_merge", result_merge_started_at)
739822

740823
# Add metadata for this path using Pydantic models
741-
license_metadata_started_at = _start_phase_timing(phase_timings)
742-
try:
743-
license_metadata = collect_license_metadata(
744-
representative_file,
745-
nearby_license_cache=nearby_license_cache,
746-
)
747-
finally:
748-
_finish_phase_timing(phase_timings, "license_metadata", license_metadata_started_at)
749-
combined_metadata = {**file_result.metadata, **license_metadata}
750-
combined_metadata["content_hash"] = content_hash
751-
duplicate_files = duplicate_paths_by_hash.get(content_hash, [])
752-
combined_metadata["duplicate_files"] = duplicate_files if len(duplicate_files) > 1 else None
824+
from .models import FileMetadataModel
753825

754-
# Convert ml_context if present
755-
if "ml_context" in combined_metadata and isinstance(combined_metadata["ml_context"], dict):
756-
from .models import MLContextModel
826+
for scanned_file_path in scanned_file_paths:
827+
license_metadata_started_at = _start_phase_timing(phase_timings)
828+
try:
829+
license_metadata = collect_license_metadata(
830+
scanned_file_path,
831+
nearby_license_cache=nearby_license_cache,
832+
)
833+
finally:
834+
_finish_phase_timing(phase_timings, "license_metadata", license_metadata_started_at)
835+
combined_metadata = {**file_result.metadata, **license_metadata}
836+
path_content_hash = content_hashes.get(scanned_file_path)
837+
if path_content_hash is not None:
838+
combined_metadata["content_hash"] = path_content_hash
839+
duplicate_files = duplicate_paths_by_hash.get(path_content_hash, [])
840+
combined_metadata["duplicate_files"] = (
841+
duplicate_files if len(duplicate_files) > 1 else None
842+
)
757843

758-
combined_metadata["ml_context"] = MLContextModel(**combined_metadata["ml_context"])
844+
# Convert ml_context if present
845+
if "ml_context" in combined_metadata and isinstance(combined_metadata["ml_context"], dict):
846+
from .models import MLContextModel
759847

760-
from .models import FileMetadataModel
848+
combined_metadata["ml_context"] = MLContextModel(**combined_metadata["ml_context"])
761849

762-
results.file_metadata[representative_file] = FileMetadataModel(**combined_metadata)
850+
results.file_metadata[scanned_file_path] = FileMetadataModel(**combined_metadata)
763851

764852
if max_total_size > 0 and results.bytes_scanned > max_total_size:
765853
_add_issue_to_model(
@@ -783,7 +871,8 @@ def scan_model_directory_or_file(
783871
location=representative_file,
784872
details={"exception_type": type(e).__name__},
785873
)
786-
_add_error_asset_to_results(results, representative_file)
874+
for scanned_file_path in scanned_file_paths:
875+
_add_error_asset_to_results(results, scanned_file_path)
787876

788877
# Final progress update for directory scan
789878
if progress_callback and not limit_reached and total_files is not None and total_files > 0:

modelaudit/utils/file/handlers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,31 @@ class ShardedModelDetector:
8989
r"params_shard_(\d+)\.bin", # Custom parameter sharding
9090
]
9191

92+
@classmethod
def match_shard_filename(cls, file_name: str) -> dict[str, int | str | None] | None:
    """Return shard metadata for a filename when it matches a known shard pattern.

    Each entry of ``cls.SHARD_PATTERNS`` is tried in order with a full match
    against ``file_name``; the first pattern that matches wins.

    Args:
        file_name: Bare filename (no directory component) to classify.

    Returns:
        ``None`` when no pattern matches. Otherwise a dict with the matching
        ``pattern``, the ``current_shard_index`` parsed from capture group 1
        and the ``expected_total_shards`` parsed from capture group 2. Either
        number is ``None`` when its group is absent, did not participate in
        the match, or is not a valid integer.
    """
    for pattern in cls.SHARD_PATTERNS:
        match = re.fullmatch(pattern, file_name)
        if not match:
            continue

        matched_groups = match.lastindex or 0
        current_shard_index: int | None = None
        expected_total_shards: int | None = None

        # TypeError is suppressed as well: an optional group that did not
        # participate in the match yields None, and int(None) raises TypeError.
        if matched_groups >= 1:
            with suppress(IndexError, ValueError, TypeError):
                current_shard_index = int(match.group(1))
        if matched_groups >= 2:
            with suppress(IndexError, ValueError, TypeError):
                expected_total_shards = int(match.group(2))

        return {
            "pattern": pattern,
            "current_shard_index": current_shard_index,
            "expected_total_shards": expected_total_shards,
        }

    return None
116+
92117
@classmethod
93118
def detect_shards(cls, file_path: str) -> dict[str, Any] | None:
94119
"""

0 commit comments

Comments
 (0)