[llm][WIP] Support TensorRT-LLM in ray data llm #61409
[llm][WIP] Support TensorRT-LLM in ray data llm #61409 — owenowenisme wants to merge 1 commit into ray-project:master from owenowenisme's branch
Conversation
Signed-off-by: You-Cheng Lin <mses010108@gmail.com>
There was a problem hiding this comment.
Code Review
This pull request introduces support for TensorRT-LLM in Ray Data LLM. However, the implementation appears to be a work-in-progress with significant copy-paste errors from the existing SGLang integration. Both new files, trtllm_engine_proc.py and trtllm_engine_stage.py, contain numerous incorrect references to SGLang in class names, comments, logic, and documentation links. The core logic for interacting with the TensorRT-LLM engine is also flawed. I've provided detailed comments to highlight these critical issues that need to be addressed.
| """The SGLang engine processor.""" | ||
|
|
||
| import logging | ||
| from typing import Any, Dict, Optional | ||
|
|
||
| import transformers | ||
| from pydantic import Field, root_validator | ||
|
|
||
| import ray | ||
| from ray.data.block import UserDefinedFunction | ||
| from ray.llm._internal.batch.constants import SGLangTaskType, TypeSGLangTaskType | ||
| from ray.llm._internal.batch.observability.usage_telemetry.usage import ( | ||
| BatchModelTelemetry, | ||
| TelemetryAgent, | ||
| get_or_create_telemetry_agent, | ||
| ) | ||
| from ray.llm._internal.batch.processor.base import ( | ||
| DEFAULT_MAX_TASKS_IN_FLIGHT, | ||
| OfflineProcessorConfig, | ||
| Processor, | ||
| ProcessorBuilder, | ||
| ) | ||
| from ray.llm._internal.batch.processor.utils import ( | ||
| build_cpu_stage_map_kwargs, | ||
| get_value_or_fallback, | ||
| ) | ||
| from ray.llm._internal.batch.stages import ( | ||
| ChatTemplateStage, | ||
| DetokenizeStage, | ||
| SGLangEngineStage, | ||
| TokenizeStage, | ||
| ) | ||
| from ray.llm._internal.batch.stages.configs import ( | ||
| ChatTemplateStageConfig, | ||
| DetokenizeStageConfig, | ||
| TokenizerStageConfig, | ||
| resolve_stage_config, | ||
| ) | ||
| from ray.llm._internal.common.observability.telemetry_utils import DEFAULT_GPU_TYPE | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| DEFAULT_MODEL_ARCHITECTURE = "UNKNOWN_MODEL_ARCHITECTURE" | ||
|
|
||
|
|
||
| class SGLangEngineProcessorConfig(OfflineProcessorConfig): | ||
| """The configuration for the SGLang engine processor.""" | ||
|
|
||
| # SGLang stage configurations. | ||
| engine_kwargs: Dict[str, Any] = Field( | ||
| default_factory=dict, | ||
| description="The kwargs to pass to the SGLang engine. See " | ||
| "https://docs.sglang.ai/backend/server_arguments.html " | ||
| "for more details.", | ||
| ) | ||
| task_type: TypeSGLangTaskType = Field( | ||
| default=SGLangTaskType.GENERATE, | ||
| description="The task type to use. If not specified, will use " | ||
| "'generate' by default.", | ||
| ) | ||
|
|
||
| @root_validator(pre=True) | ||
| def validate_task_type(cls, values): | ||
| task_type = values.get("task_type", SGLangTaskType.GENERATE) | ||
| if task_type not in SGLangTaskType.values(): | ||
| raise ValueError(f"Invalid task type: {task_type}") | ||
|
|
||
| engine_kwargs = values.get("engine_kwargs", {}) | ||
| engine_kwargs_task = engine_kwargs.get("task", "") | ||
| if engine_kwargs_task != task_type: | ||
| logger.warning( | ||
| "The task set in engine kwargs (%s) is different from the " | ||
| "stage (%s). Overriding the task in engine kwargs to %s.", | ||
| engine_kwargs_task, | ||
| task_type, | ||
| task_type, | ||
| ) | ||
| engine_kwargs["task"] = task_type | ||
| values["engine_kwargs"] = engine_kwargs | ||
| return values | ||
|
|
||
|
|
||
| def build_sglang_engine_processor( | ||
| config: SGLangEngineProcessorConfig, | ||
| chat_template_kwargs: Optional[Dict[str, Any]] = None, | ||
| preprocess: Optional[UserDefinedFunction] = None, | ||
| postprocess: Optional[UserDefinedFunction] = None, | ||
| preprocess_map_kwargs: Optional[Dict[str, Any]] = None, | ||
| postprocess_map_kwargs: Optional[Dict[str, Any]] = None, | ||
| telemetry_agent: Optional[TelemetryAgent] = None, | ||
| ) -> Processor: | ||
| """Construct a Processor and configure stages. | ||
|
|
||
| Args: | ||
| config: The configuration for the processor. | ||
| chat_template_kwargs: The optional kwargs to pass to apply_chat_template. | ||
| preprocess: An optional lambda function that takes a row (dict) as input | ||
| and returns a preprocessed row (dict). The output row must contain the | ||
| required fields for the following processing stages. | ||
| postprocess: An optional lambda function that takes a row (dict) as input | ||
| and returns a postprocessed row (dict). | ||
| preprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the | ||
| preprocess stage (e.g., num_cpus, memory, concurrency). | ||
| postprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the | ||
| postprocess stage (e.g., num_cpus, memory, concurrency). | ||
| telemetry_agent: An optional telemetry agent for collecting usage telemetry. | ||
|
|
||
| Returns: | ||
| The constructed processor. | ||
| """ | ||
| ray.init(runtime_env=config.runtime_env, ignore_reinit_error=True) | ||
|
|
||
| stages = [] | ||
|
|
||
| # Prepare processor defaults for merging into stage configs | ||
| processor_defaults = { | ||
| "batch_size": config.batch_size, | ||
| "concurrency": config.concurrency, | ||
| "runtime_env": config.runtime_env, | ||
| "model_source": config.model_source, | ||
| } | ||
|
|
||
| # Resolve and build ChatTemplateStage if enabled | ||
| chat_template_stage_cfg = resolve_stage_config( | ||
| config.chat_template_stage, | ||
| ChatTemplateStageConfig, | ||
| processor_defaults, | ||
| ) | ||
| if chat_template_stage_cfg.enabled: | ||
| stages.append( | ||
| ChatTemplateStage( | ||
| fn_constructor_kwargs=dict( | ||
| model=chat_template_stage_cfg.model_source, | ||
| chat_template=get_value_or_fallback( | ||
| chat_template_stage_cfg.chat_template, config.chat_template | ||
| ), | ||
| chat_template_kwargs=get_value_or_fallback( | ||
| chat_template_stage_cfg.chat_template_kwargs, | ||
| chat_template_kwargs, | ||
| ), | ||
| ), | ||
| map_batches_kwargs=build_cpu_stage_map_kwargs(chat_template_stage_cfg), | ||
| ) | ||
| ) | ||
|
|
||
| # Resolve and build TokenizeStage if enabled | ||
| tokenize_stage_cfg = resolve_stage_config( | ||
| getattr(config, "tokenize_stage", config.tokenize), | ||
| TokenizerStageConfig, | ||
| processor_defaults, | ||
| ) | ||
| if tokenize_stage_cfg.enabled: | ||
| stages.append( | ||
| TokenizeStage( | ||
| fn_constructor_kwargs=dict( | ||
| model=tokenize_stage_cfg.model_source, | ||
| ), | ||
| map_batches_kwargs=build_cpu_stage_map_kwargs(tokenize_stage_cfg), | ||
| ) | ||
| ) | ||
|
|
||
| # Core stage -- the SGLang engine. | ||
| stages.append( | ||
| SGLangEngineStage( | ||
| fn_constructor_kwargs=dict( | ||
| model=config.model_source, | ||
| engine_kwargs=config.engine_kwargs, | ||
| task_type=config.task_type, | ||
| max_pending_requests=config.max_pending_requests, | ||
| ), | ||
| map_batches_kwargs=dict( | ||
| zero_copy_batch=True, | ||
| # The number of running replicas. This is a deprecated field, but | ||
| # we need to set `max_tasks_in_flight_per_actor` through `compute`, | ||
| # which initiates enough many overlapping UDF calls per actor, to | ||
| # saturate `max_concurrency`. | ||
| compute=ray.data.ActorPoolStrategy( | ||
| **config.get_concurrency(autoscaling_enabled=True), | ||
| max_tasks_in_flight_per_actor=config.experimental.get( | ||
| "max_tasks_in_flight_per_actor", DEFAULT_MAX_TASKS_IN_FLIGHT | ||
| ), | ||
| ), | ||
| # The number of running batches "per actor" in Ray Core level. | ||
| # This is used to make sure we overlap batches to avoid the tail | ||
| # latency of each batch. | ||
| max_concurrency=config.max_concurrent_batches, | ||
| accelerator_type=config.accelerator_type, | ||
| runtime_env=config.runtime_env, | ||
| ), | ||
| ) | ||
| ) | ||
|
|
||
| # Resolve and build DetokenizeStage if enabled | ||
| detokenize_stage_cfg = resolve_stage_config( | ||
| getattr(config, "detokenize_stage", config.detokenize), | ||
| DetokenizeStageConfig, | ||
| processor_defaults, | ||
| ) | ||
| if detokenize_stage_cfg.enabled: | ||
| stages.append( | ||
| DetokenizeStage( | ||
| fn_constructor_kwargs=dict( | ||
| model=detokenize_stage_cfg.model_source, | ||
| ), | ||
| map_batches_kwargs=build_cpu_stage_map_kwargs(detokenize_stage_cfg), | ||
| ) | ||
| ) | ||
|
|
||
| hf_config = transformers.AutoConfig.from_pretrained(config.model_source) | ||
| architecture = getattr(hf_config, "architectures", [DEFAULT_MODEL_ARCHITECTURE])[0] | ||
|
|
||
| telemetry_agent = get_or_create_telemetry_agent() | ||
| telemetry_agent.push_telemetry_report( | ||
| BatchModelTelemetry( | ||
| processor_config_name=type(config).__name__, | ||
| model_architecture=architecture, | ||
| batch_size=config.batch_size, | ||
| accelerator_type=config.accelerator_type or DEFAULT_GPU_TYPE, | ||
| concurrency=config.concurrency, | ||
| task_type=config.task_type, | ||
| tensor_parallel_size=config.engine_kwargs.get("tp_size", 1), | ||
| data_parallel_size=config.engine_kwargs.get("dp_size", 1), | ||
| ) | ||
| ) | ||
|
|
||
| processor = Processor( | ||
| config, | ||
| stages, | ||
| preprocess=preprocess, | ||
| postprocess=postprocess, | ||
| preprocess_map_kwargs=preprocess_map_kwargs, | ||
| postprocess_map_kwargs=postprocess_map_kwargs, | ||
| ) | ||
| return processor | ||
|
|
||
|
|
||
| ProcessorBuilder.register(SGLangEngineProcessorConfig, build_sglang_engine_processor) |
There was a problem hiding this comment.
This file appears to be a direct copy of sglang_engine_proc.py and is not adapted for TensorRT-LLM. All references, from module docstrings and imports to class names (SGLangEngineProcessorConfig), function names (build_sglang_engine_processor), and internal logic, are for SGLang. The entire file needs to be refactored to implement the processor for the TensorRT-LLM engine.
| # We need to rename the `model` to `model_path` for SGLang. | ||
| kwargs["model_path"] = self.model |
There was a problem hiding this comment.
The comment refers to SGLang. More importantly, tensorrt_llm.LLM expects the model path to be passed via the model_dir argument, not model_path.
| # We need to rename the `model` to `model_path` for SGLang. | |
| kwargs["model_path"] = self.model | |
| # We need to rename the `model` to `model_dir` for TensorRT-LLM. | |
| kwargs["model_dir"] = self.model |
| tokenized_prompt = None | ||
|
|
||
| # Prepare sampling parameters. | ||
| if self.task_type == SGLangTaskType.GENERATE: |
| async def _generate_async(self, request: TensorRTLLMEngineRequest) -> Any: | ||
| """Process a single request. | ||
|
|
||
| Args: | ||
| request: The request. | ||
|
|
||
| Returns: | ||
| The output of the request. | ||
| """ | ||
| # TRTLLM api | ||
| # generate_async( | ||
| # inputs: str | List[int] | TextPrompt | TokensPrompt, | ||
| # sampling_params: SamplingParams | None = None, | ||
| # lora_request: LoRARequest | None = None, | ||
| # prompt_adapter_request: PromptAdapterRequest | None = None, | ||
| # streaming: bool = False, | ||
| # kv_cache_retention_config: KvCacheRetentionConfig | None = None, | ||
| # disaggregated_params: DisaggregatedParams | None = None, | ||
| # _postproc_params: PostprocParams | None = None, | ||
| # ) → RequestOutput | ||
| # Send the request to the LLM engine. | ||
| stream = await self.engine.generate_async( | ||
| inputs=request.prompt, | ||
| input_ids=request.prompt_token_ids, | ||
| sampling_params=request.params, | ||
| stream=True, | ||
| ) | ||
|
|
||
| # Consume the stream until the request is finished. | ||
| async for output in stream: | ||
| if output["meta_info"]["finish_reason"] is not None: | ||
| output["prompt"] = request.prompt | ||
| output["prompt_token_ids"] = request.prompt_token_ids | ||
| return output | ||
|
|
||
| raise RuntimeError( | ||
| "[SGLang] The request is not finished. This should not happen. Please report this issue to the Ray team." | ||
| ) |
There was a problem hiding this comment.
The implementation of _generate_async and its interaction with from_trtllm_engine_output seems incorrect for the tensorrt_llm backend.
- The call to `self.engine.generate_async` on lines 214-219 uses `input_ids` and `stream` parameters, which are not valid for `tensorrt_llm.LLM.generate_async`. It should use `inputs` and `streaming`.
- The stream consumption logic on lines 222-226 is from the SGLang implementation and will fail, as `tensorrt_llm` yields `RequestOutput` objects, not dictionaries with a `meta_info` key. You should check the `.finished` attribute of the `RequestOutput` object.
- The `from_trtllm_engine_output` method (lines 57-78) incorrectly tries to access `output.prompt` and `output.prompt_token_ids`, which don't exist on the `RequestOutput` object.
A significant refactoring is needed here. Consider passing the request object into from_trtllm_engine_output to provide the necessary prompt information.
| self.engine.shutdown() | ||
|
|
||
|
|
||
| class SGLangEngineStageUDF(StatefulStageUDF): |
There was a problem hiding this comment.
This class and other related definitions seem to be incorrectly named from a copy-paste. SGLangEngineStageUDF should be TensorRTLLMEngineStageUDF. Similar changes are needed for SGLangEngineStage (line 349) and the usage of SGLangTaskType and TypeSGLangTaskType throughout the file (e.g., lines 151, 246, 388).
| class SGLangEngineStageUDF(StatefulStageUDF): | |
| class TensorRTLLMEngineStageUDF(StatefulStageUDF): |
| class SGLangEngineStage(StatefulStage): | ||
| """ | ||
| A stage that runs SGLang engine. | ||
| """ | ||
|
|
||
| fn: Type[StatefulStageUDF] = SGLangEngineStageUDF |
There was a problem hiding this comment.
The class SGLangEngineStage should be renamed to TensorRTLLMEngineStage, and its fn attribute should point to the correctly named UDF class (TensorRTLLMEngineStageUDF).
| class SGLangEngineStage(StatefulStage): | |
| """ | |
| A stage that runs SGLang engine. | |
| """ | |
| fn: Type[StatefulStageUDF] = SGLangEngineStageUDF | |
| class TensorRTLLMEngineStage(StatefulStage): | |
| """ | |
| A stage that runs TensorRT-LLM engine. | |
| """ | |
| fn: Type[StatefulStageUDF] = TensorRTLLMEngineStageUDF |
| task_type = self.fn_constructor_kwargs.get("task_type", SGLangTaskType.GENERATE) | ||
| if task_type == SGLangTaskType.GENERATE: | ||
| ret[ | ||
| "sampling_params" | ||
| ] = "The sampling parameters. See https://docs.sglang.ai/backend/sampling_params.html for details." | ||
| return ret |
There was a problem hiding this comment.
The task type check and the documentation URL are incorrect. They should use TensorRTLLMTaskType and point to TensorRT-LLM documentation.
| task_type = self.fn_constructor_kwargs.get("task_type", SGLangTaskType.GENERATE) | |
| if task_type == SGLangTaskType.GENERATE: | |
| ret[ | |
| "sampling_params" | |
| ] = "The sampling parameters. See https://docs.sglang.ai/backend/sampling_params.html for details." | |
| return ret | |
| task_type = self.fn_constructor_kwargs.get("task_type", TensorRTLLMTaskType.GENERATE) | |
| if task_type == TensorRTLLMTaskType.GENERATE: | |
| ret[ | |
| "sampling_params" | |
| ] = "The sampling parameters. See TensorRT-LLM documentation for details." | |
| return ret |
| @@ -0,0 +1,397 @@ | |||
| """The stage that runs SGLang engine.""" | |||
There was a problem hiding this comment.
This file contains many references to SGLang in docstrings, comments, and log messages (e.g., lines 1, 23, 86, 108, 122, 158, 229, 233, 235, 250, 255, 264, 311, 317, 338, 351). These should be updated to refer to TensorRT-LLM for clarity and correctness.
| """The stage that runs SGLang engine.""" | |
| """The stage that runs the TensorRT-LLM engine.""" |
| raise ImportError( | ||
| "TensorRT LLM is not installed or failed to import. Please run " | ||
| "`pip install sglang[all]` to install required dependencies." | ||
| ) from e |
There was a problem hiding this comment.
The installation instruction in the error message is for sglang. This should be updated to guide the user on how to install tensorrt_llm.
| raise ImportError( | |
| "TensorRT LLM is not installed or failed to import. Please run " | |
| "`pip install sglang[all]` to install required dependencies." | |
| ) from e | |
| raise ImportError( | |
| "TensorRT-LLM is not installed or failed to import. Please follow the " | |
| "instructions at https://github.com/NVIDIA/TensorRT-LLM#installation " | |
| "to install it." | |
| ) from e |
| # The sampling parameters (more details can be seen in https://docs.sglang.ai/backend/sampling_params.html). | ||
| params: Optional[Dict[str, Any]] |
Description
Related issues
Additional information