4343 validate_file_type_with_formats ,
4444)
4545from modelaudit .utils .file .handlers import (
46+ ShardedModelDetector ,
4647 scan_advanced_large_file ,
4748 should_use_advanced_handler ,
4849)
@@ -103,6 +104,9 @@ def _count_files_up_to(path: Path, limit: int) -> int | None:
103104_XGBOOST_PICKLE_SPOOF_REASON = "xgboost_binary_pickle_spoof"
104105_RECOGNIZED_FORMAT_SCANNER_UNAVAILABLE_REASON = "recognized_format_scanner_unavailable"
105106_XML_MODEL_ROUTING_INCOMPLETE_REASON = "xml_model_routing_incomplete"
107+ _ShardFamilyKey = tuple [str , str , int | None ]
108+ _ScanEntry = tuple [str , list [str ], _ShardFamilyKey | None ]
109+ _SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY = "shard_family_cache_fingerprint"
106110
107111
108112def _start_phase_timing (phase_timings : dict [str , float ] | None ) -> float | None :
@@ -124,6 +128,44 @@ def _attach_phase_timings(results: ModelAuditResultModel, phase_timings: dict[st
124128 results .phase_timings = phase_timings # type: ignore[attr-defined]
125129
126130
131+ def _shard_family_key_for_path (path : str ) -> _ShardFamilyKey | None :
132+ """Return a stable key for files that belong to the same local shard family."""
133+ path_obj = Path (path )
134+ shard_match = ShardedModelDetector .match_shard_filename (path_obj .name )
135+ if shard_match is None :
136+ return None
137+
138+ pattern = shard_match .get ("pattern" )
139+ if not isinstance (pattern , str ):
140+ return None
141+
142+ expected_total = shard_match .get ("expected_total_shards" )
143+ if not isinstance (expected_total , int ):
144+ expected_total = None
145+
146+ return (str (path_obj .parent ), pattern , expected_total )
147+
148+
149+ def _build_shard_family_cache_fingerprint (
150+ shard_family_key : _ShardFamilyKey ,
151+ scanned_file_paths : list [str ],
152+ content_hashes : dict [str , str ],
153+ ) -> dict [str , Any ]:
154+ """Fingerprint every present shard so representative cache entries stay valid."""
155+ _parent_dir , pattern , expected_total_shards = shard_family_key
156+ return {
157+ "pattern" : pattern ,
158+ "expected_total_shards" : expected_total_shards ,
159+ "members" : [
160+ {
161+ "path" : str (Path (scanned_file_path ).resolve ()),
162+ "content_hash" : content_hashes .get (scanned_file_path ),
163+ }
164+ for scanned_file_path in sorted (scanned_file_paths )
165+ ],
166+ }
167+
168+
127169def _select_preferred_scanner_id (path : str , header_format : str , ext : str ) -> str | None :
128170 """Select a scanner by trusted file structure, not just suffix."""
129171 if header_format == "zip" :
@@ -536,6 +578,8 @@ def scan_model_directory_or_file(
536578
537579 # First pass: collect all file paths that need scanning
538580 files_to_scan : list [str ] = []
581+ shard_family_representatives : dict [_ShardFamilyKey , str ] = {}
582+ shard_family_paths : dict [_ShardFamilyKey , set [str ]] = {}
539583 directory_discovery_started_at = _start_phase_timing (phase_timings )
540584 for root , _ , files in os .walk (path , followlinks = False ):
541585 for file in files :
@@ -609,29 +653,54 @@ def scan_model_directory_or_file(
609653 continue
610654
611655 # Add to files to scan list instead of scanning immediately
656+ shard_family_key = _shard_family_key_for_path (target_str )
657+ if shard_family_key is not None :
658+ family_paths = shard_family_paths .setdefault (shard_family_key , set ())
659+ family_paths .add (target_str )
660+ if shard_family_key not in shard_family_representatives :
661+ shard_family_representatives [shard_family_key ] = target_str
662+ shard_info = ShardedModelDetector .detect_shards (target_str )
663+ if shard_info is not None :
664+ for shard_path in shard_info .get ("shards" , []):
665+ if isinstance (shard_path , str ):
666+ family_paths .add (str (Path (shard_path ).resolve ()))
667+ continue
668+
612669 files_to_scan .append (target_str )
613670 _finish_phase_timing (phase_timings , "directory_discovery" , directory_discovery_started_at )
614671
615- # Second pass: scan every path independently. Some scanners depend on
616- # parent paths or sibling files, so content equality alone is not a
617- # safe proxy for scan-result equality.
618- if files_to_scan :
672+ scan_entries : list [_ScanEntry ] = [(file_path , [file_path ], None ) for file_path in files_to_scan ]
673+ for shard_family_key , representative_file in shard_family_representatives .items ():
674+ ordered_family_paths = sorted (shard_family_paths .get (shard_family_key , {representative_file }))
675+ scan_entries .append ((representative_file , ordered_family_paths , shard_family_key ))
676+
677+ # Second pass: scan every non-shard path independently and every shard
678+ # family once. Shard scans already expand to sibling shards in the
679+ # advanced handler, so scanning each shard path would duplicate work.
680+ if scan_entries :
681+ hash_paths : list [str ] = []
682+ seen_hash_paths : set [str ] = set ()
683+ for _representative_file , scanned_file_paths , _shard_family_key in scan_entries :
684+ for scanned_file_path in scanned_file_paths :
685+ if scanned_file_path not in seen_hash_paths :
686+ hash_paths .append (scanned_file_path )
687+ seen_hash_paths .add (scanned_file_path )
688+
619689 top_level_hashing_started_at = _start_phase_timing (phase_timings )
620- content_hashes = _hash_files_by_path (files_to_scan )
690+ content_hashes = _hash_files_by_path (hash_paths )
621691 _finish_phase_timing (phase_timings , "top_level_hashing" , top_level_hashing_started_at )
622692 duplicate_paths_by_hash : dict [str , list [str ]] = {}
623693 for file_path , content_hash in content_hashes .items ():
624694 if not content_hash .startswith ("unhashable_" ):
625695 duplicate_paths_by_hash .setdefault (content_hash , []).append (file_path )
626696 recorded_content_hashes : set [str ] = set ()
627697
628- for representative_file , content_hash in content_hashes .items ():
629- # Collect valid content hashes for aggregate hash computation
630- # Skip "unhashable_" prefix entries (those are placeholder hashes for files that failed to hash)
698+ for content_hash in content_hashes .values ():
631699 if not content_hash .startswith ("unhashable_" ) and content_hash not in recorded_content_hashes :
632700 file_hashes .append (content_hash )
633701 recorded_content_hashes .add (content_hash )
634702
703+ for representative_file , scanned_file_paths , shard_family_key in scan_entries :
635704 # Check for interrupts
636705 check_interrupted ()
637706
@@ -641,14 +710,17 @@ def scan_model_directory_or_file(
641710
642711 # Update progress
643712 if progress_callback :
713+ scan_label = Path (representative_file ).name
714+ if len (scanned_file_paths ) > 1 :
715+ scan_label = f"{ scan_label } ({ len (scanned_file_paths )} shards)"
644716 if total_files is not None and total_files > 0 :
645717 progress_callback (
646- f"Scanning file { processed_files + 1 } /{ total_files } : { Path ( representative_file ). name } " ,
718+ f"Scanning file { processed_files + 1 } /{ total_files } : { scan_label } " ,
647719 processed_files / total_files * 100 ,
648720 )
649721 else :
650722 progress_callback (
651- f"Scanning file { processed_files + 1 } : { Path ( representative_file ). name } " ,
723+ f"Scanning file { processed_files + 1 } : { scan_label } " ,
652724 0.0 ,
653725 )
654726
@@ -658,7 +730,17 @@ def scan_model_directory_or_file(
658730
659731 file_scan_started_at = _start_phase_timing (phase_timings )
660732 try :
661- file_result = scan_file (representative_file , config )
733+ file_config = config
734+ if shard_family_key is not None :
735+ file_config = dict (config )
736+ file_config [_SHARD_FAMILY_CACHE_FINGERPRINT_CONFIG_KEY ] = (
737+ _build_shard_family_cache_fingerprint (
738+ shard_family_key ,
739+ scanned_file_paths ,
740+ content_hashes ,
741+ )
742+ )
743+ file_result = scan_file (representative_file , file_config )
662744 finally :
663745 _finish_phase_timing (phase_timings , "file_scan_dispatch" , file_scan_started_at )
664746
@@ -667,8 +749,8 @@ def scan_model_directory_or_file(
667749 if _scan_result_has_operational_error (file_result ):
668750 scan_metadata ["has_operational_errors" ] = True
669751 results .bytes_scanned += file_result .bytes_scanned
670- results .files_scanned += 1
671- processed_files += 1
752+ results .files_scanned += len ( scanned_file_paths )
753+ processed_files += len ( scanned_file_paths )
672754
673755 # Add scanner to tracking list (different from scanner_names)
674756 scanner_name = file_result .scanner_name
@@ -734,32 +816,38 @@ def scan_model_directory_or_file(
734816
735817 results .checks .append (Check (** check_dict ))
736818
737- _add_asset_to_results (results , representative_file , file_result )
819+ for scanned_file_path in scanned_file_paths :
820+ _add_asset_to_results (results , scanned_file_path , file_result )
738821 _finish_phase_timing (phase_timings , "result_merge" , result_merge_started_at )
739822
740823 # Add metadata for this path using Pydantic models
741- license_metadata_started_at = _start_phase_timing (phase_timings )
742- try :
743- license_metadata = collect_license_metadata (
744- representative_file ,
745- nearby_license_cache = nearby_license_cache ,
746- )
747- finally :
748- _finish_phase_timing (phase_timings , "license_metadata" , license_metadata_started_at )
749- combined_metadata = {** file_result .metadata , ** license_metadata }
750- combined_metadata ["content_hash" ] = content_hash
751- duplicate_files = duplicate_paths_by_hash .get (content_hash , [])
752- combined_metadata ["duplicate_files" ] = duplicate_files if len (duplicate_files ) > 1 else None
824+ from .models import FileMetadataModel
753825
754- # Convert ml_context if present
755- if "ml_context" in combined_metadata and isinstance (combined_metadata ["ml_context" ], dict ):
756- from .models import MLContextModel
826+ for scanned_file_path in scanned_file_paths :
827+ license_metadata_started_at = _start_phase_timing (phase_timings )
828+ try :
829+ license_metadata = collect_license_metadata (
830+ scanned_file_path ,
831+ nearby_license_cache = nearby_license_cache ,
832+ )
833+ finally :
834+ _finish_phase_timing (phase_timings , "license_metadata" , license_metadata_started_at )
835+ combined_metadata = {** file_result .metadata , ** license_metadata }
836+ path_content_hash = content_hashes .get (scanned_file_path )
837+ if path_content_hash is not None :
838+ combined_metadata ["content_hash" ] = path_content_hash
839+ duplicate_files = duplicate_paths_by_hash .get (path_content_hash , [])
840+ combined_metadata ["duplicate_files" ] = (
841+ duplicate_files if len (duplicate_files ) > 1 else None
842+ )
757843
758- combined_metadata ["ml_context" ] = MLContextModel (** combined_metadata ["ml_context" ])
844+ # Convert ml_context if present
845+ if "ml_context" in combined_metadata and isinstance (combined_metadata ["ml_context" ], dict ):
846+ from .models import MLContextModel
759847
760- from . models import FileMetadataModel
848+ combined_metadata [ "ml_context" ] = MLContextModel ( ** combined_metadata [ "ml_context" ])
761849
762- results .file_metadata [representative_file ] = FileMetadataModel (** combined_metadata )
850+ results .file_metadata [scanned_file_path ] = FileMetadataModel (** combined_metadata )
763851
764852 if max_total_size > 0 and results .bytes_scanned > max_total_size :
765853 _add_issue_to_model (
@@ -783,7 +871,8 @@ def scan_model_directory_or_file(
783871 location = representative_file ,
784872 details = {"exception_type" : type (e ).__name__ },
785873 )
786- _add_error_asset_to_results (results , representative_file )
874+ for scanned_file_path in scanned_file_paths :
875+ _add_error_asset_to_results (results , scanned_file_path )
787876
788877 # Final progress update for directory scan
789878 if progress_callback and not limit_reached and total_files is not None and total_files > 0 :
0 commit comments