
[llm][WIP] Support TensorRT-LLM in ray data llm #61409

Draft
owenowenisme wants to merge 1 commit into ray-project:master from owenowenisme:data-llm/add-trtllm-support

Conversation

@owenowenisme
Member

Thank you for contributing to Ray! 🚀
Please review the Ray Contribution Guide before opening a pull request.

⚠️ Remove these instructions before submitting your PR.

💡 Tip: Mark as draft if you want early feedback, or ready for review when it's complete.

Description

Briefly describe what this PR accomplishes and why it's needed.

Related issues

Link related issues: "Fixes #1234", "Closes #1234", or "Related to #1234".

Additional information

Optional: Add implementation details, API changes, usage examples, screenshots, etc.

Signed-off-by: You-Cheng Lin <mses010108@gmail.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for TensorRT-LLM in Ray Data LLM. However, the implementation appears to be a work-in-progress with significant copy-paste errors from the existing SGLang integration. Both new files, trtllm_engine_proc.py and trtllm_engine_stage.py, contain numerous incorrect references to SGLang in class names, comments, logic, and documentation links. The core logic for interacting with the TensorRT-LLM engine is also flawed. I've provided detailed comments to highlight these critical issues that need to be addressed.

Comment on lines +1 to +238
"""The SGLang engine processor."""

import logging
from typing import Any, Dict, Optional

import transformers
from pydantic import Field, root_validator

import ray
from ray.data.block import UserDefinedFunction
from ray.llm._internal.batch.constants import SGLangTaskType, TypeSGLangTaskType
from ray.llm._internal.batch.observability.usage_telemetry.usage import (
    BatchModelTelemetry,
    TelemetryAgent,
    get_or_create_telemetry_agent,
)
from ray.llm._internal.batch.processor.base import (
    DEFAULT_MAX_TASKS_IN_FLIGHT,
    OfflineProcessorConfig,
    Processor,
    ProcessorBuilder,
)
from ray.llm._internal.batch.processor.utils import (
    build_cpu_stage_map_kwargs,
    get_value_or_fallback,
)
from ray.llm._internal.batch.stages import (
    ChatTemplateStage,
    DetokenizeStage,
    SGLangEngineStage,
    TokenizeStage,
)
from ray.llm._internal.batch.stages.configs import (
    ChatTemplateStageConfig,
    DetokenizeStageConfig,
    TokenizerStageConfig,
    resolve_stage_config,
)
from ray.llm._internal.common.observability.telemetry_utils import DEFAULT_GPU_TYPE

logger = logging.getLogger(__name__)


DEFAULT_MODEL_ARCHITECTURE = "UNKNOWN_MODEL_ARCHITECTURE"


class SGLangEngineProcessorConfig(OfflineProcessorConfig):
    """The configuration for the SGLang engine processor."""

    # SGLang stage configurations.
    engine_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="The kwargs to pass to the SGLang engine. See "
        "https://docs.sglang.ai/backend/server_arguments.html "
        "for more details.",
    )
    task_type: TypeSGLangTaskType = Field(
        default=SGLangTaskType.GENERATE,
        description="The task type to use. If not specified, will use "
        "'generate' by default.",
    )

    @root_validator(pre=True)
    def validate_task_type(cls, values):
        task_type = values.get("task_type", SGLangTaskType.GENERATE)
        if task_type not in SGLangTaskType.values():
            raise ValueError(f"Invalid task type: {task_type}")

        engine_kwargs = values.get("engine_kwargs", {})
        engine_kwargs_task = engine_kwargs.get("task", "")
        if engine_kwargs_task != task_type:
            logger.warning(
                "The task set in engine kwargs (%s) is different from the "
                "stage (%s). Overriding the task in engine kwargs to %s.",
                engine_kwargs_task,
                task_type,
                task_type,
            )
            engine_kwargs["task"] = task_type
        values["engine_kwargs"] = engine_kwargs
        return values


def build_sglang_engine_processor(
    config: SGLangEngineProcessorConfig,
    chat_template_kwargs: Optional[Dict[str, Any]] = None,
    preprocess: Optional[UserDefinedFunction] = None,
    postprocess: Optional[UserDefinedFunction] = None,
    preprocess_map_kwargs: Optional[Dict[str, Any]] = None,
    postprocess_map_kwargs: Optional[Dict[str, Any]] = None,
    telemetry_agent: Optional[TelemetryAgent] = None,
) -> Processor:
    """Construct a Processor and configure stages.

    Args:
        config: The configuration for the processor.
        chat_template_kwargs: The optional kwargs to pass to apply_chat_template.
        preprocess: An optional lambda function that takes a row (dict) as input
            and returns a preprocessed row (dict). The output row must contain the
            required fields for the following processing stages.
        postprocess: An optional lambda function that takes a row (dict) as input
            and returns a postprocessed row (dict).
        preprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
            preprocess stage (e.g., num_cpus, memory, concurrency).
        postprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
            postprocess stage (e.g., num_cpus, memory, concurrency).
        telemetry_agent: An optional telemetry agent for collecting usage telemetry.

    Returns:
        The constructed processor.
    """
    ray.init(runtime_env=config.runtime_env, ignore_reinit_error=True)

    stages = []

    # Prepare processor defaults for merging into stage configs
    processor_defaults = {
        "batch_size": config.batch_size,
        "concurrency": config.concurrency,
        "runtime_env": config.runtime_env,
        "model_source": config.model_source,
    }

    # Resolve and build ChatTemplateStage if enabled
    chat_template_stage_cfg = resolve_stage_config(
        config.chat_template_stage,
        ChatTemplateStageConfig,
        processor_defaults,
    )
    if chat_template_stage_cfg.enabled:
        stages.append(
            ChatTemplateStage(
                fn_constructor_kwargs=dict(
                    model=chat_template_stage_cfg.model_source,
                    chat_template=get_value_or_fallback(
                        chat_template_stage_cfg.chat_template, config.chat_template
                    ),
                    chat_template_kwargs=get_value_or_fallback(
                        chat_template_stage_cfg.chat_template_kwargs,
                        chat_template_kwargs,
                    ),
                ),
                map_batches_kwargs=build_cpu_stage_map_kwargs(chat_template_stage_cfg),
            )
        )

    # Resolve and build TokenizeStage if enabled
    tokenize_stage_cfg = resolve_stage_config(
        getattr(config, "tokenize_stage", config.tokenize),
        TokenizerStageConfig,
        processor_defaults,
    )
    if tokenize_stage_cfg.enabled:
        stages.append(
            TokenizeStage(
                fn_constructor_kwargs=dict(
                    model=tokenize_stage_cfg.model_source,
                ),
                map_batches_kwargs=build_cpu_stage_map_kwargs(tokenize_stage_cfg),
            )
        )

    # Core stage -- the SGLang engine.
    stages.append(
        SGLangEngineStage(
            fn_constructor_kwargs=dict(
                model=config.model_source,
                engine_kwargs=config.engine_kwargs,
                task_type=config.task_type,
                max_pending_requests=config.max_pending_requests,
            ),
            map_batches_kwargs=dict(
                zero_copy_batch=True,
                # The number of running replicas. This is a deprecated field, but
                # we need to set `max_tasks_in_flight_per_actor` through `compute`,
                # which initiates enough many overlapping UDF calls per actor, to
                # saturate `max_concurrency`.
                compute=ray.data.ActorPoolStrategy(
                    **config.get_concurrency(autoscaling_enabled=True),
                    max_tasks_in_flight_per_actor=config.experimental.get(
                        "max_tasks_in_flight_per_actor", DEFAULT_MAX_TASKS_IN_FLIGHT
                    ),
                ),
                # The number of running batches "per actor" in Ray Core level.
                # This is used to make sure we overlap batches to avoid the tail
                # latency of each batch.
                max_concurrency=config.max_concurrent_batches,
                accelerator_type=config.accelerator_type,
                runtime_env=config.runtime_env,
            ),
        )
    )

    # Resolve and build DetokenizeStage if enabled
    detokenize_stage_cfg = resolve_stage_config(
        getattr(config, "detokenize_stage", config.detokenize),
        DetokenizeStageConfig,
        processor_defaults,
    )
    if detokenize_stage_cfg.enabled:
        stages.append(
            DetokenizeStage(
                fn_constructor_kwargs=dict(
                    model=detokenize_stage_cfg.model_source,
                ),
                map_batches_kwargs=build_cpu_stage_map_kwargs(detokenize_stage_cfg),
            )
        )

    hf_config = transformers.AutoConfig.from_pretrained(config.model_source)
    architecture = getattr(hf_config, "architectures", [DEFAULT_MODEL_ARCHITECTURE])[0]

    telemetry_agent = get_or_create_telemetry_agent()
    telemetry_agent.push_telemetry_report(
        BatchModelTelemetry(
            processor_config_name=type(config).__name__,
            model_architecture=architecture,
            batch_size=config.batch_size,
            accelerator_type=config.accelerator_type or DEFAULT_GPU_TYPE,
            concurrency=config.concurrency,
            task_type=config.task_type,
            tensor_parallel_size=config.engine_kwargs.get("tp_size", 1),
            data_parallel_size=config.engine_kwargs.get("dp_size", 1),
        )
    )

    processor = Processor(
        config,
        stages,
        preprocess=preprocess,
        postprocess=postprocess,
        preprocess_map_kwargs=preprocess_map_kwargs,
        postprocess_map_kwargs=postprocess_map_kwargs,
    )
    return processor


ProcessorBuilder.register(SGLangEngineProcessorConfig, build_sglang_engine_processor)

critical

This file appears to be a direct copy of sglang_engine_proc.py and is not adapted for TensorRT-LLM. All references, from module docstrings and imports to class names (SGLangEngineProcessorConfig), function names (build_sglang_engine_processor), and internal logic, are for SGLang. The entire file needs to be refactored to implement the processor for the TensorRT-LLM engine.
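To make the requested rename concrete, here is a minimal, hedged sketch of what the adapted task-type enum and validator could look like. It uses only the standard library; `TensorRTLLMTaskType` and `validate_task_type` are illustrative names mirroring how the review assumes `SGLangTaskType` works (a string enum with a `.values()` helper), not Ray's actual API.

```python
# Hypothetical sketch: the SGLang names from the copied file, adapted to
# TensorRT-LLM. Names here are illustrative, not Ray's actual definitions.
from enum import Enum


class TensorRTLLMTaskType(str, Enum):
    GENERATE = "generate"

    @classmethod
    def values(cls):
        return [member.value for member in cls]


def validate_task_type(values: dict) -> dict:
    """Mirror of the copied root validator, with TensorRT-LLM names."""
    task_type = values.get("task_type", TensorRTLLMTaskType.GENERATE)
    if task_type not in TensorRTLLMTaskType.values():
        raise ValueError(f"Invalid task type: {task_type}")
    # Keep engine_kwargs consistent with the stage-level task type.
    engine_kwargs = values.get("engine_kwargs", {})
    engine_kwargs["task"] = task_type
    values["engine_kwargs"] = engine_kwargs
    return values
```

The same mechanical substitution (`SGLang*` → `TensorRTLLM*`) would then apply to the config class, builder function, and `ProcessorBuilder.register` call.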

Comment on lines +105 to +106
        # We need to rename the `model` to `model_path` for SGLang.
        kwargs["model_path"] = self.model

critical

The comment refers to SGLang. More importantly, tensorrt_llm.LLM expects the model path to be passed via the model_dir argument, not model_path.

Suggested change
        # We need to rename the `model` to `model_path` for SGLang.
        kwargs["model_path"] = self.model
        # We need to rename the `model` to `model_dir` for TensorRT-LLM.
        kwargs["model_dir"] = self.model

        tokenized_prompt = None

        # Prepare sampling parameters.
        if self.task_type == SGLangTaskType.GENERATE:

critical

This should check against TensorRTLLMTaskType.GENERATE instead of SGLangTaskType.GENERATE.

Suggested change
if self.task_type == SGLangTaskType.GENERATE:
if self.task_type == TensorRTLLMTaskType.GENERATE:

Comment on lines +193 to +230
    async def _generate_async(self, request: TensorRTLLMEngineRequest) -> Any:
        """Process a single request.

        Args:
            request: The request.

        Returns:
            The output of the request.
        """
        # TRTLLM api
        # generate_async(
        #     inputs: str | List[int] | TextPrompt | TokensPrompt,
        #     sampling_params: SamplingParams | None = None,
        #     lora_request: LoRARequest | None = None,
        #     prompt_adapter_request: PromptAdapterRequest | None = None,
        #     streaming: bool = False,
        #     kv_cache_retention_config: KvCacheRetentionConfig | None = None,
        #     disaggregated_params: DisaggregatedParams | None = None,
        #     _postproc_params: PostprocParams | None = None,
        # ) → RequestOutput
        # Send the request to the LLM engine.
        stream = await self.engine.generate_async(
            inputs=request.prompt,
            input_ids=request.prompt_token_ids,
            sampling_params=request.params,
            stream=True,
        )

        # Consume the stream until the request is finished.
        async for output in stream:
            if output["meta_info"]["finish_reason"] is not None:
                output["prompt"] = request.prompt
                output["prompt_token_ids"] = request.prompt_token_ids
                return output

        raise RuntimeError(
            "[SGLang] The request is not finished. This should not happen. Please report this issue to the Ray team."
        )

critical

The implementation of _generate_async and its interaction with from_trtllm_engine_output seems incorrect for the tensorrt_llm backend.

  1. The call to self.engine.generate_async on lines 214-219 uses input_ids and stream parameters, which are not valid for tensorrt_llm.LLM.generate_async. It should use inputs and streaming.
  2. The stream consumption logic on lines 222-226 is from the SGLang implementation and will fail, as tensorrt_llm yields RequestOutput objects, not dictionaries with a meta_info key. You should check the .finished attribute of the RequestOutput object.
  3. The from_trtllm_engine_output method (lines 57-78) incorrectly tries to access output.prompt and output.prompt_token_ids, which don't exist on the RequestOutput object.

A significant refactoring is needed here. Consider passing the request object into from_trtllm_engine_output to provide the necessary prompt information.
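The consumption pattern the review asks for can be sketched as follows. This is a runnable toy, not the real integration: `FakeRequestOutput` and `fake_generate_async` stand in for TensorRT-LLM's `RequestOutput` and `LLM.generate_async(..., streaming=True)`, imitating only the `finished` attribute the review says should be checked.

```python
# Hedged sketch of the corrected stream-consumption logic. The stubs below
# imitate the shape of TensorRT-LLM's streaming output; they are not its API.
import asyncio
from dataclasses import dataclass


@dataclass
class FakeRequestOutput:
    finished: bool  # the real RequestOutput exposes a `finished` flag
    text: str = ""


async def fake_generate_async(prompt, streaming=True):
    """Stand-in for engine.generate_async(inputs=..., streaming=True)."""
    for chunk in ["Hello", " world"]:
        yield FakeRequestOutput(finished=False, text=chunk)
    yield FakeRequestOutput(finished=True, text="Hello world")


async def consume(prompt, prompt_token_ids):
    # Iterate the stream and check the `.finished` attribute, instead of
    # indexing a SGLang-style meta_info dict.
    async for output in fake_generate_async(prompt, streaming=True):
        if output.finished:
            # RequestOutput does not carry the prompt, so attach it from the
            # request here (the refactoring the review suggests).
            return {
                "prompt": prompt,
                "prompt_token_ids": prompt_token_ids,
                "generated_text": output.text,
            }
    raise RuntimeError("[TensorRT-LLM] The request never finished.")


if __name__ == "__main__":
    print(asyncio.run(consume("hi", [1, 2])))
```

The key structural change is that the prompt information flows from the request object into the final row, rather than being read off the engine output.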

        self.engine.shutdown()


class SGLangEngineStageUDF(StatefulStageUDF):

critical

This class and other related definitions seem to be incorrectly named from a copy-paste. SGLangEngineStageUDF should be TensorRTLLMEngineStageUDF. Similar changes are needed for SGLangEngineStage (line 349) and the usage of SGLangTaskType and TypeSGLangTaskType throughout the file (e.g., lines 151, 246, 388).

Suggested change
class SGLangEngineStageUDF(StatefulStageUDF):
class TensorRTLLMEngineStageUDF(StatefulStageUDF):

Comment on lines +349 to +354
class SGLangEngineStage(StatefulStage):
    """
    A stage that runs SGLang engine.
    """

    fn: Type[StatefulStageUDF] = SGLangEngineStageUDF

critical

The class SGLangEngineStage should be renamed to TensorRTLLMEngineStage, and its fn attribute should point to the correctly named UDF class (TensorRTLLMEngineStageUDF).

Suggested change
class SGLangEngineStage(StatefulStage):
    """
    A stage that runs SGLang engine.
    """

    fn: Type[StatefulStageUDF] = SGLangEngineStageUDF
class TensorRTLLMEngineStage(StatefulStage):
    """
    A stage that runs TensorRT-LLM engine.
    """

    fn: Type[StatefulStageUDF] = TensorRTLLMEngineStageUDF

Comment on lines +388 to +393
        task_type = self.fn_constructor_kwargs.get("task_type", SGLangTaskType.GENERATE)
        if task_type == SGLangTaskType.GENERATE:
            ret[
                "sampling_params"
            ] = "The sampling parameters. See https://docs.sglang.ai/backend/sampling_params.html for details."
        return ret

critical

The task type check and the documentation URL are incorrect. They should use TensorRTLLMTaskType and point to TensorRT-LLM documentation.

Suggested change
        task_type = self.fn_constructor_kwargs.get("task_type", SGLangTaskType.GENERATE)
        if task_type == SGLangTaskType.GENERATE:
            ret[
                "sampling_params"
            ] = "The sampling parameters. See https://docs.sglang.ai/backend/sampling_params.html for details."
        return ret
        task_type = self.fn_constructor_kwargs.get("task_type", TensorRTLLMTaskType.GENERATE)
        if task_type == TensorRTLLMTaskType.GENERATE:
            ret[
                "sampling_params"
            ] = "The sampling parameters. See TensorRT-LLM documentation for details."
        return ret

@@ -0,0 +1,397 @@
"""The stage that runs SGLang engine."""

high

This file contains many references to SGLang in docstrings, comments, and log messages (e.g., lines 1, 23, 86, 108, 122, 158, 229, 233, 235, 250, 255, 264, 311, 317, 338, 351). These should be updated to refer to TensorRT-LLM for clarity and correctness.

Suggested change
"""The stage that runs SGLang engine."""
"""The stage that runs the TensorRT-LLM engine."""

Comment on lines +117 to +120
        raise ImportError(
            "TensorRT LLM is not installed or failed to import. Please run "
            "`pip install sglang[all]` to install required dependencies."
        ) from e

high

The installation instruction in the error message is for sglang. This should be updated to guide the user on how to install tensorrt_llm.

Suggested change
        raise ImportError(
            "TensorRT LLM is not installed or failed to import. Please run "
            "`pip install sglang[all]` to install required dependencies."
        ) from e
        raise ImportError(
            "TensorRT-LLM is not installed or failed to import. Please follow the "
            "instructions at https://github.com/NVIDIA/TensorRT-LLM#installation "
            "to install it."
        ) from e
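The import-guard pattern in question can be sketched in isolation. In this toy version, `nonexistent_tensorrt_llm_module` is a deliberately fake name standing in for the real `import tensorrt_llm`, so the guard always fires and the TensorRT-LLM-specific message can be verified.

```python
# Hedged sketch of the import guard with the corrected error message.
# `nonexistent_tensorrt_llm_module` is a stand-in for `tensorrt_llm`.
def load_trtllm_engine():
    try:
        import nonexistent_tensorrt_llm_module  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "TensorRT-LLM is not installed or failed to import. Please follow "
            "the instructions at "
            "https://github.com/NVIDIA/TensorRT-LLM#installation to install it."
        ) from e
```

Chaining with `from e` preserves the original traceback, so the user still sees which underlying import actually failed.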

Comment on lines +33 to +34
    # The sampling parameters (more details can be seen in https://docs.sglang.ai/backend/sampling_params.html).
    params: Optional[Dict[str, Any]]

medium

The documentation link for sampling parameters points to sglang.ai. This should be updated to point to the relevant TensorRT-LLM documentation. This also applies to the link on line 392.



Development

Successfully merging this pull request may close these issues.

Ray fails to serialize self-reference objects

1 participant