Skip to content

Commit 7e48795

Browse files
hao-aaron and YoussefEssDS
authored and committed
[serve][llm] Disable model downloading for RunAI streamer, introduce optimized download function (ray-project#57854)
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
1 parent a8a3795 commit 7e48795

File tree

7 files changed

+357
-67
lines changed

7 files changed

+357
-67
lines changed

python/ray/llm/_internal/common/callbacks/cloud_downloader.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import time
23
from typing import Any, List, Tuple
34

45
from pydantic import BaseModel, field_validator
@@ -76,20 +77,12 @@ def on_before_download_model_files_distributed(self) -> None:
7677
from ray.llm._internal.common.utils.cloud_utils import CloudFileSystem
7778

7879
paths = self.kwargs["paths"]
80+
start_time = time.monotonic()
81+
for cloud_uri, local_path in paths:
82+
CloudFileSystem.download_files_parallel(
83+
path=local_path, bucket_uri=cloud_uri
84+
)
85+
end_time = time.monotonic()
7986
logger.info(
80-
f"CloudDownloader: Starting download of {len(paths)} files from cloud storage"
87+
f"CloudDownloader: Files downloaded in {end_time - start_time} seconds"
8188
)
82-
83-
for cloud_uri, local_path in paths:
84-
try:
85-
logger.info(f"CloudDownloader: Downloading {cloud_uri} to {local_path}")
86-
CloudFileSystem.download_files(path=local_path, bucket_uri=cloud_uri)
87-
logger.info(
88-
f"CloudDownloader: Successfully downloaded {cloud_uri} to {local_path}"
89-
)
90-
except Exception as e:
91-
logger.error(
92-
f"CloudDownloader: Failed to download {cloud_uri} to {local_path}: {e}"
93-
)
94-
if self.raise_error_on_callback:
95-
raise

python/ray/llm/_internal/common/utils/cloud_utils.py

Lines changed: 139 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
import os
44
import time
5+
from concurrent.futures import ThreadPoolExecutor
56
from pathlib import Path
67
from typing import (
78
Any,
@@ -345,6 +346,53 @@ def list_subfolders(folder_uri: str) -> List[str]:
345346
logger.info(f"Error listing subfolders in {folder_uri}: {e}")
346347
return []
347348

349+
@staticmethod
350+
def _filter_files(
351+
fs: pa_fs.FileSystem,
352+
source_path: str,
353+
destination_path: str,
354+
substrings_to_include: Optional[List[str]] = None,
355+
suffixes_to_exclude: Optional[List[str]] = None,
356+
) -> List[Tuple[str, str]]:
357+
"""Filter files from cloud storage based on inclusion and exclusion criteria.
358+
359+
Args:
360+
fs: PyArrow filesystem instance
361+
source_path: Source path in cloud storage
362+
destination_path: Local destination path
363+
substrings_to_include: Only include files containing these substrings
364+
suffixes_to_exclude: Exclude files ending with these suffixes
365+
366+
Returns:
367+
List of tuples containing (source_file_path, destination_file_path)
368+
"""
369+
file_selector = pa_fs.FileSelector(source_path, recursive=True)
370+
file_infos = fs.get_file_info(file_selector)
371+
372+
path_pairs = []
373+
for file_info in file_infos:
374+
if file_info.type != pa_fs.FileType.File:
375+
continue
376+
377+
rel_path = file_info.path[len(source_path) :].lstrip("/")
378+
379+
# Apply filters
380+
if substrings_to_include:
381+
if not any(
382+
substring in rel_path for substring in substrings_to_include
383+
):
384+
continue
385+
386+
if suffixes_to_exclude:
387+
if any(rel_path.endswith(suffix) for suffix in suffixes_to_exclude):
388+
continue
389+
390+
path_pairs.append(
391+
(file_info.path, os.path.join(destination_path, rel_path))
392+
)
393+
394+
return path_pairs
395+
348396
@staticmethod
349397
def download_files(
350398
path: str,
@@ -366,40 +414,104 @@ def download_files(
366414
# Ensure the destination directory exists
367415
os.makedirs(path, exist_ok=True)
368416

369-
# List all files in the bucket
370-
file_selector = pa_fs.FileSelector(source_path, recursive=True)
371-
file_infos = fs.get_file_info(file_selector)
417+
# Get filtered files to download
418+
files_to_download = CloudFileSystem._filter_files(
419+
fs, source_path, path, substrings_to_include, suffixes_to_exclude
420+
)
372421

373422
# Download each file
374-
for file_info in file_infos:
375-
if file_info.type != pa_fs.FileType.File:
376-
continue
423+
for source_file_path, dest_file_path in files_to_download:
424+
# Create destination directory if needed
425+
dest_dir = os.path.dirname(dest_file_path)
426+
if dest_dir:
427+
os.makedirs(dest_dir, exist_ok=True)
428+
429+
# Download the file
430+
with fs.open_input_file(source_file_path) as source_file:
431+
with open(dest_file_path, "wb") as dest_file:
432+
dest_file.write(source_file.read())
377433

378-
# Get relative path from source prefix
379-
rel_path = file_info.path[len(source_path) :].lstrip("/")
434+
except Exception as e:
435+
logger.exception(f"Error downloading files from {bucket_uri}: {e}")
436+
raise
380437

381-
# Check if file matches substring filters
382-
if substrings_to_include:
383-
if not any(
384-
substring in rel_path for substring in substrings_to_include
385-
):
386-
continue
438+
@staticmethod
439+
def download_files_parallel(
440+
path: str,
441+
bucket_uri: str,
442+
substrings_to_include: Optional[List[str]] = None,
443+
suffixes_to_exclude: Optional[List[str]] = None,
444+
max_concurrency: int = 10,
445+
chunk_size: int = 64 * 1024 * 1024,
446+
) -> None:
447+
"""Multi-threaded download of files from cloud storage.
387448
388-
# Check if file matches suffixes to exclude filter
389-
if suffixes_to_exclude:
390-
if any(rel_path.endswith(suffix) for suffix in suffixes_to_exclude):
391-
continue
449+
Args:
450+
path: Local directory where files will be downloaded
451+
bucket_uri: URI of cloud directory
452+
substrings_to_include: Only include files containing these substrings
453+
suffixes_to_exclude: Exclude certain files from download
454+
max_concurrency: Maximum number of concurrent files to download (default: 10)
455+
chunk_size: Size of transfer chunks (default: 64MB)
456+
"""
457+
try:
458+
fs, source_path = CloudFileSystem.get_fs_and_path(bucket_uri)
459+
460+
# Ensure destination exists
461+
os.makedirs(path, exist_ok=True)
462+
463+
# If no filters, use direct copy_files
464+
if not substrings_to_include and not suffixes_to_exclude:
465+
pa_fs.copy_files(
466+
source=source_path,
467+
destination=path,
468+
source_filesystem=fs,
469+
destination_filesystem=pa_fs.LocalFileSystem(),
470+
use_threads=True,
471+
chunk_size=chunk_size,
472+
)
473+
return
392474

475+
# List and filter files
476+
files_to_download = CloudFileSystem._filter_files(
477+
fs, source_path, path, substrings_to_include, suffixes_to_exclude
478+
)
479+
480+
if not files_to_download:
481+
logger.info("Filters do not match any of the files, skipping download")
482+
return
483+
484+
def download_single_file(file_paths):
485+
source_file_path, dest_file_path = file_paths
393486
# Create destination directory if needed
394-
if "/" in rel_path:
395-
dest_dir = os.path.join(path, os.path.dirname(rel_path))
487+
dest_dir = os.path.dirname(dest_file_path)
488+
if dest_dir:
396489
os.makedirs(dest_dir, exist_ok=True)
397490

398-
# Download the file
399-
dest_path = os.path.join(path, rel_path)
400-
with fs.open_input_file(file_info.path) as source_file:
401-
with open(dest_path, "wb") as dest_file:
402-
dest_file.write(source_file.read())
491+
# Use PyArrow's copy_files for individual files,
492+
pa_fs.copy_files(
493+
source=source_file_path,
494+
destination=dest_file_path,
495+
source_filesystem=fs,
496+
destination_filesystem=pa_fs.LocalFileSystem(),
497+
use_threads=True,
498+
chunk_size=chunk_size,
499+
)
500+
return dest_file_path
501+
502+
max_workers = min(max_concurrency, len(files_to_download))
503+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
504+
futures = [
505+
executor.submit(download_single_file, file_paths)
506+
for file_paths in files_to_download
507+
]
508+
509+
for future in futures:
510+
try:
511+
future.result()
512+
except Exception as e:
513+
logger.error(f"Failed to download file: {e}")
514+
raise
403515

404516
except Exception as e:
405517
logger.exception(f"Error downloading files from {bucket_uri}: {e}")
@@ -464,11 +576,12 @@ def download_model(
464576

465577
safetensors_to_exclude = [".safetensors"] if exclude_safetensors else None
466578

467-
CloudFileSystem.download_files(
579+
CloudFileSystem.download_files_parallel(
468580
path=destination_dir,
469581
bucket_uri=bucket_uri,
470582
substrings_to_include=tokenizer_file_substrings,
471583
suffixes_to_exclude=safetensors_to_exclude,
584+
chunk_size=64 * 1024 * 1024, # 64MB chunks for large model files
472585
)
473586

474587
except Exception as e:

python/ray/llm/_internal/common/utils/download_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
logger = get_logger(__name__)
2121

22-
STREAMING_LOAD_FORMATS = ["runai_streamer", "tensorizer"]
22+
STREAMING_LOAD_FORMATS = ["runai_streamer", "runai_streamer_sharded", "tensorizer"]
2323

2424

2525
class NodeModelDownloadable(enum.Enum):
@@ -267,7 +267,7 @@ def download_model_files(
267267
# cannot be created by torch if the parent directory doesn't exist.
268268
torch_cache_home = torch.hub._get_torch_home()
269269
os.makedirs(os.path.join(torch_cache_home, "kernels"), exist_ok=True)
270-
model_path_or_id = None
270+
model_path_or_id = model_id
271271

272272
if callback is not None:
273273
callback.run_callback_sync("on_before_download_model_files_distributed")

python/ray/llm/_internal/serve/core/configs/llm_config.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626
CloudMirrorConfig,
2727
is_remote_path,
2828
)
29-
from ray.llm._internal.common.utils.download_utils import NodeModelDownloadable
29+
from ray.llm._internal.common.utils.download_utils import (
30+
STREAMING_LOAD_FORMATS,
31+
NodeModelDownloadable,
32+
)
3033
from ray.llm._internal.common.utils.import_utils import load_class, try_import
3134
from ray.llm._internal.serve.constants import (
3235
DEFAULT_MULTIPLEX_DOWNLOAD_TIMEOUT_S,
@@ -297,6 +300,10 @@ def get_or_create_callback(self) -> Optional[CallbackBase]:
297300
assert engine_config is not None
298301
pg = engine_config.get_or_create_pg()
299302
runtime_env = engine_config.get_runtime_env_with_local_env_vars()
303+
if self.engine_kwargs.get("load_format", None) in STREAMING_LOAD_FORMATS:
304+
worker_node_download_model = NodeModelDownloadable.NONE
305+
else:
306+
worker_node_download_model = NodeModelDownloadable.MODEL_AND_TOKENIZER
300307

301308
# Create new instance
302309
if isinstance(self.callback_config.callback_class, str):
@@ -308,7 +315,7 @@ def get_or_create_callback(self) -> Optional[CallbackBase]:
308315
raise_error_on_callback=self.callback_config.raise_error_on_callback,
309316
llm_config=self,
310317
ctx_kwargs={
311-
"worker_node_download_model": NodeModelDownloadable.MODEL_AND_TOKENIZER,
318+
"worker_node_download_model": worker_node_download_model,
312319
"placement_group": pg,
313320
"runtime_env": runtime_env,
314321
},

python/ray/llm/_internal/serve/utils/node_initialization_utils.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import asyncio
2-
import os
32
from typing import Optional
43

54
import ray
65
from ray.llm._internal.common.utils.download_utils import (
76
download_model_files,
87
)
98
from ray.llm._internal.common.utils.import_utils import try_import
10-
from ray.llm._internal.serve.core.configs.llm_config import LLMConfig, LLMEngine
9+
from ray.llm._internal.serve.core.configs.llm_config import LLMConfig
1110
from ray.llm._internal.serve.observability.logging import get_logger
1211

1312
torch = try_import("torch")
@@ -33,24 +32,6 @@ def initialize_remote_node(llm_config: LLMConfig) -> Optional[str]:
3332
if local_path and local_path != engine_config.actual_hf_model_id:
3433
engine_config.hf_model_id = local_path
3534

36-
# Download the tokenizer if it isn't a local file path
37-
if not isinstance(local_path, str) or not os.path.exists(local_path):
38-
logger.info(f"Downloading the tokenizer for {engine_config.actual_hf_model_id}")
39-
40-
if llm_config.llm_engine == LLMEngine.vLLM:
41-
from vllm.transformers_utils.tokenizer import get_tokenizer
42-
43-
_ = get_tokenizer(
44-
engine_config.actual_hf_model_id,
45-
tokenizer_mode=engine_config.engine_kwargs.get("tokenizer_mode", None),
46-
trust_remote_code=engine_config.trust_remote_code,
47-
)
48-
else:
49-
_ = transformers.AutoTokenizer.from_pretrained(
50-
engine_config.actual_hf_model_id,
51-
trust_remote_code=engine_config.trust_remote_code,
52-
)
53-
5435
return local_path
5536

5637

0 commit comments

Comments (0)