22import inspect
33import os
44import time
5+ from concurrent .futures import ThreadPoolExecutor
56from pathlib import Path
67from typing import (
78 Any ,
@@ -345,6 +346,53 @@ def list_subfolders(folder_uri: str) -> List[str]:
345346 logger .info (f"Error listing subfolders in { folder_uri } : { e } " )
346347 return []
347348
@staticmethod
def _filter_files(
    fs: pa_fs.FileSystem,
    source_path: str,
    destination_path: str,
    substrings_to_include: Optional[List[str]] = None,
    suffixes_to_exclude: Optional[List[str]] = None,
) -> List[Tuple[str, str]]:
    """Pair each selected remote file with its local destination path.

    Recursively lists everything under ``source_path``, keeps only regular
    files, and applies both filters to each file's path relative to
    ``source_path``.

    Args:
        fs: PyArrow filesystem used to list ``source_path``.
        source_path: Prefix in cloud storage to list recursively.
        destination_path: Local directory the relative paths are joined onto.
        substrings_to_include: When non-empty, a file is kept only if its
            relative path contains at least one of these substrings.
        suffixes_to_exclude: When non-empty, a file is dropped if its
            relative path ends with any of these suffixes.

    Returns:
        A list of ``(source_file_path, destination_file_path)`` tuples, in
        the order the filesystem listing returned them.
    """

    def passes_filters(relative: str) -> bool:
        # Inclusion filter: at least one requested substring must occur.
        if substrings_to_include and not any(
            needle in relative for needle in substrings_to_include
        ):
            return False
        # Exclusion filter: str.endswith accepts a tuple of candidates.
        if suffixes_to_exclude and relative.endswith(tuple(suffixes_to_exclude)):
            return False
        return True

    prefix_length = len(source_path)
    listing = fs.get_file_info(pa_fs.FileSelector(source_path, recursive=True))

    selected = []
    for entry in listing:
        # Directories (and anything else that is not a plain file) are skipped.
        if entry.type != pa_fs.FileType.File:
            continue

        # Drop the source prefix and the separator that follows it.
        relative = entry.path[prefix_length:].lstrip("/")
        if passes_filters(relative):
            selected.append((entry.path, os.path.join(destination_path, relative)))

    return selected
395+
348396 @staticmethod
349397 def download_files (
350398 path : str ,
@@ -366,40 +414,104 @@ def download_files(
366414 # Ensure the destination directory exists
367415 os .makedirs (path , exist_ok = True )
368416
369- # List all files in the bucket
370- file_selector = pa_fs .FileSelector (source_path , recursive = True )
371- file_infos = fs .get_file_info (file_selector )
417+ # Get filtered files to download
418+ files_to_download = CloudFileSystem ._filter_files (
419+ fs , source_path , path , substrings_to_include , suffixes_to_exclude
420+ )
372421
373422 # Download each file
374- for file_info in file_infos :
375- if file_info .type != pa_fs .FileType .File :
376- continue
423+ for source_file_path , dest_file_path in files_to_download :
424+ # Create destination directory if needed
425+ dest_dir = os .path .dirname (dest_file_path )
426+ if dest_dir :
427+ os .makedirs (dest_dir , exist_ok = True )
428+
429+ # Download the file
430+ with fs .open_input_file (source_file_path ) as source_file :
431+ with open (dest_file_path , "wb" ) as dest_file :
432+ dest_file .write (source_file .read ())
377433
378- # Get relative path from source prefix
379- rel_path = file_info .path [len (source_path ) :].lstrip ("/" )
434+ except Exception as e :
435+ logger .exception (f"Error downloading files from { bucket_uri } : { e } " )
436+ raise
380437
381- # Check if file matches substring filters
382- if substrings_to_include :
383- if not any (
384- substring in rel_path for substring in substrings_to_include
385- ):
386- continue
438+ @staticmethod
439+ def download_files_parallel (
440+ path : str ,
441+ bucket_uri : str ,
442+ substrings_to_include : Optional [List [str ]] = None ,
443+ suffixes_to_exclude : Optional [List [str ]] = None ,
444+ max_concurrency : int = 10 ,
445+ chunk_size : int = 64 * 1024 * 1024 ,
446+ ) -> None :
447+ """Multi-threaded download of files from cloud storage.
387448
388- # Check if file matches suffixes to exclude filter
389- if suffixes_to_exclude :
390- if any (rel_path .endswith (suffix ) for suffix in suffixes_to_exclude ):
391- continue
449+ Args:
450+ path: Local directory where files will be downloaded
451+ bucket_uri: URI of cloud directory
452+ substrings_to_include: Only include files containing these substrings
453+ suffixes_to_exclude: Exclude certain files from download
454+ max_concurrency: Maximum number of concurrent files to download (default: 10)
455+ chunk_size: Size of transfer chunks (default: 64MB)
456+ """
457+ try :
458+ fs , source_path = CloudFileSystem .get_fs_and_path (bucket_uri )
459+
460+ # Ensure destination exists
461+ os .makedirs (path , exist_ok = True )
462+
463+ # If no filters, use direct copy_files
464+ if not substrings_to_include and not suffixes_to_exclude :
465+ pa_fs .copy_files (
466+ source = source_path ,
467+ destination = path ,
468+ source_filesystem = fs ,
469+ destination_filesystem = pa_fs .LocalFileSystem (),
470+ use_threads = True ,
471+ chunk_size = chunk_size ,
472+ )
473+ return
392474
475+ # List and filter files
476+ files_to_download = CloudFileSystem ._filter_files (
477+ fs , source_path , path , substrings_to_include , suffixes_to_exclude
478+ )
479+
480+ if not files_to_download :
481+ logger .info ("Filters do not match any of the files, skipping download" )
482+ return
483+
484+ def download_single_file (file_paths ):
485+ source_file_path , dest_file_path = file_paths
393486 # Create destination directory if needed
394- if "/" in rel_path :
395- dest_dir = os . path . join ( path , os . path . dirname ( rel_path ))
487+ dest_dir = os . path . dirname ( dest_file_path )
488+ if dest_dir :
396489 os .makedirs (dest_dir , exist_ok = True )
397490
398- # Download the file
399- dest_path = os .path .join (path , rel_path )
400- with fs .open_input_file (file_info .path ) as source_file :
401- with open (dest_path , "wb" ) as dest_file :
402- dest_file .write (source_file .read ())
491+ # Use PyArrow's copy_files for individual files,
492+ pa_fs .copy_files (
493+ source = source_file_path ,
494+ destination = dest_file_path ,
495+ source_filesystem = fs ,
496+ destination_filesystem = pa_fs .LocalFileSystem (),
497+ use_threads = True ,
498+ chunk_size = chunk_size ,
499+ )
500+ return dest_file_path
501+
502+ max_workers = min (max_concurrency , len (files_to_download ))
503+ with ThreadPoolExecutor (max_workers = max_workers ) as executor :
504+ futures = [
505+ executor .submit (download_single_file , file_paths )
506+ for file_paths in files_to_download
507+ ]
508+
509+ for future in futures :
510+ try :
511+ future .result ()
512+ except Exception as e :
513+ logger .error (f"Failed to download file: { e } " )
514+ raise
403515
404516 except Exception as e :
405517 logger .exception (f"Error downloading files from { bucket_uri } : { e } " )
@@ -464,11 +576,12 @@ def download_model(
464576
465577 safetensors_to_exclude = [".safetensors" ] if exclude_safetensors else None
466578
467- CloudFileSystem .download_files (
579+ CloudFileSystem .download_files_parallel (
468580 path = destination_dir ,
469581 bucket_uri = bucket_uri ,
470582 substrings_to_include = tokenizer_file_substrings ,
471583 suffixes_to_exclude = safetensors_to_exclude ,
584+ chunk_size = 64 * 1024 * 1024 , # 64MB chunks for large model files
472585 )
473586
474587 except Exception as e :
0 commit comments