Use Transformers helper get_text_config() instead of checking for text_config #17105

Merged: 3 commits, Apr 25, 2025
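
The Transformers helper PretrainedConfig.get_text_config() returns the nested
text sub-config for multimodal models and the config itself for text-only
models, so callers no longer need to probe for a text_config attribute by
hand. A minimal sketch of that contract, using LlamaConfig and LlavaConfig
purely for illustration:

    from transformers import LlamaConfig, LlavaConfig

    # Text-only model: the helper is a no-op and returns the config itself.
    llama = LlamaConfig()
    assert llama.get_text_config() is llama

    # Multimodal model: the helper unwraps the nested text sub-config.
    llava = LlavaConfig()
    assert llava.get_text_config() is llava.text_config

This is the pattern each diff below adopts.
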
5 changes: 2 additions & 3 deletions benchmarks/kernels/benchmark_moe.py
@@ -553,9 +553,8 @@ def main(args: argparse.Namespace):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
-        if not hasattr(config, "hidden_size"):
-            # Support for llama4
-            config = config.text_config
+        # Support for llama4
+        config = config.get_text_config()
         # Default: Mixtral.
         E = config.num_local_experts
         topk = config.num_experts_per_tok
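
With the helper in place, the Mixtral-style expert lookups above work
unchanged for Llama-4 checkpoints, whose MoE fields live on the nested text
config. A rough sketch of the resolved flow; the repo id is a placeholder,
not a real checkpoint:

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("some-org/llama4-moe-checkpoint")
    text_config = config.get_text_config()  # no-op for Mixtral-style configs
    E = text_config.num_local_experts       # number of routed experts
    topk = text_config.num_experts_per_tok  # experts activated per token
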
5 changes: 1 addition & 4 deletions tests/models/test_initialization.py
@@ -24,10 +24,7 @@ def test_can_initialize(model_arch):
     def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
         hf_config.update(model_info.hf_overrides)

-        if hasattr(hf_config, "text_config"):
-            text_config: PretrainedConfig = hf_config.text_config
-        else:
-            text_config = hf_config
+        text_config = hf_config.get_text_config()

         text_config.update({
             "num_layers": 1,
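
Note that get_text_config() returns the live sub-config object rather than a
copy, so the in-place update() above also shrinks the model described by the
parent hf_config. A quick sketch, assuming a LLaVA-style config purely for
illustration:

    from transformers import LlavaConfig

    cfg = LlavaConfig()
    cfg.get_text_config().update({"num_hidden_layers": 1})
    # The mutation is visible through the parent config.
    assert cfg.text_config.num_hidden_layers == 1
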
6 changes: 2 additions & 4 deletions vllm/config.py
@@ -2841,12 +2841,10 @@ def _get_and_verify_dtype(
 ) -> torch.dtype:
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, "torch_dtype", None)
+    config_dtype = getattr(config.get_text_config(), "torch_dtype", None)

-    # Fallbacks for multi-modal models if the root config
+    # Fallback for multi-modal models if the root config
     # does not define torch_dtype
-    if config_dtype is None and hasattr(config, "text_config"):
-        config_dtype = getattr(config.text_config, "torch_dtype", None)
     if config_dtype is None and hasattr(config, "vision_config"):
         config_dtype = getattr(config.vision_config, "torch_dtype", None)

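
The single getattr() on get_text_config() now covers both the root config and
a nested text_config, so only the vision_config fallback remains. The
effective resolution order, sketched as a standalone function (the name
resolve_config_dtype is illustrative, not vLLM API):

    from typing import Optional

    import torch
    from transformers import PretrainedConfig

    def resolve_config_dtype(config: PretrainedConfig) -> Optional[torch.dtype]:
        # 1. torch_dtype on the text config; for text-only models
        #    get_text_config() returns the root config itself.
        dtype = getattr(config.get_text_config(), "torch_dtype", None)
        # 2. Some multimodal configs declare a dtype only on the vision
        #    tower, so fall back to vision_config.
        if dtype is None and hasattr(config, "vision_config"):
            dtype = getattr(config.vision_config, "torch_dtype", None)
        return dtype
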
21 changes: 12 additions & 9 deletions vllm/transformers_utils/config.py
@@ -760,19 +760,22 @@ def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
     No op for pure text models.
     """
-    if hasattr(config, "text_config"):
-        # The code operates under the assumption that text_config should have
-        # `num_attention_heads` (among others). Assert here to fail early
-        # if transformers config doesn't align with this assumption.
-        assert hasattr(config.text_config, "num_attention_heads")
-        return config.text_config
-    elif hasattr(config, "thinker_config"):
+    # This block should be unnecessary after https://github.com/huggingface/transformers/pull/37517
+    if hasattr(config, "thinker_config"):
         # TODO(suyang.fy): Refactor code.
         # For Qwen2.5-Omni, change hf_text_config to
         # thinker_config.text_config.
         return config.thinker_config.text_config
-    else:
-        return config
+
+    text_config = config.get_text_config()
+
+    if text_config is not config:
+        # The code operates under the assumption that text_config should have
+        # `num_attention_heads` (among others). Assert here to fail early
+        # if transformers config doesn't align with this assumption.
+        assert hasattr(text_config, "num_attention_heads")
+
+    return text_config


 def try_get_generation_config(
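
The identity check works because get_text_config() returns the config itself
when no nested sub-config exists, so the num_attention_heads assertion only
fires for genuinely multimodal configs. A hypothetical malformed config shows
what the assert guards against (BrokenMultimodalConfig is invented for this
sketch):

    from transformers import PretrainedConfig

    class BrokenMultimodalConfig(PretrainedConfig):
        """Declares a text_config that lacks attention geometry."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.text_config = PretrainedConfig()

    cfg = BrokenMultimodalConfig()
    text_config = cfg.get_text_config()
    assert text_config is not cfg  # a nested sub-config was found
    # hasattr(text_config, "num_attention_heads") is False, so
    # get_hf_text_config() fails fast here rather than deep in model init.
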
11 changes: 3 additions & 8 deletions vllm/worker/cpu_model_runner.py
@@ -508,13 +508,8 @@ def load_model(self) -> None:
                 logger.warning("Regarding multimodal models, vLLM currently "
                                "only supports adding LoRA to language model.")

-            # It's necessary to distinguish between the max_position_embeddings
-            # of VLMs and LLMs.
-            if hasattr(self.model.config, "max_position_embeddings"):
-                max_pos_embeddings = self.model.config.max_position_embeddings
-            else:
-                max_pos_embeddings = (
-                    self.model.config.text_config.max_position_embeddings)
+            # Use get_text_config() in case of multimodal models
+            text_config = self.model_config.hf_config.get_text_config()

             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
@@ -524,7 +519,7 @@ def load_model(self) -> None:
                 self.device,
                 self.model.embedding_modules,
                 self.model.embedding_padding_modules,
-                max_position_embeddings=max_pos_embeddings,
+                max_position_embeddings=text_config.max_position_embeddings,
             )
             self.model = self.lora_manager.create_lora_manager(self.model)
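
For VLMs such as LLaVA, max_position_embeddings lives on the nested text
config rather than the root config, which is what the deleted hasattr branch
was compensating for. A sketch of the lookup the new code performs; the repo
id and the resulting value are assumptions about that checkpoint:

    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
    # The root LlavaConfig carries no max_position_embeddings;
    # its Llama backbone does.
    max_pos = cfg.get_text_config().max_position_embeddings  # 4096 here

The two worker files below apply the same change.
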
14 changes: 5 additions & 9 deletions vllm/worker/hpu_model_runner.py
@@ -724,14 +724,9 @@ def load_model(self) -> None:
                 "Bias support in LoRA is not enabled in HPU yet."
             assert not self.lora_config.fully_sharded_loras, \
                 "Fully sharded LoRAs is not enabled in HPU yet."
-            # It's necessary to distinguish between the
-            # max_position_embeddings of VLMs and LLMs.
-            if hasattr(self.model.config, "max_position_embeddings"):
-                max_pos_embeddings = (
-                    self.model.config.max_position_embeddings)
-            else:
-                max_pos_embeddings = (
-                    self.model.config.text_config.max_position_embeddings)
+
+            # Use get_text_config() in case of multimodal models
+            text_config = self.model_config.hf_config.get_text_config()

             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
@@ -741,7 +736,8 @@ def load_model(self) -> None:
                 self.device,
                 self.model.embedding_modules,
                 self.model.embedding_padding_modules,
-                max_position_embeddings=max_pos_embeddings,
+                max_position_embeddings=text_config.
+                max_position_embeddings,
             )
             self.model = self.lora_manager.create_lora_manager(self.model)
14 changes: 5 additions & 9 deletions vllm/worker/model_runner.py
@@ -1130,14 +1130,9 @@ def load_model(self) -> None:
                 logger.warning(
                     "Regarding multimodal models, vLLM currently "
                     "only supports adding LoRA to language model.")
-            # It's necessary to distinguish between the
-            # max_position_embeddings of VLMs and LLMs.
-            if hasattr(self.model.config, "max_position_embeddings"):
-                max_pos_embeddings = (
-                    self.model.config.max_position_embeddings)
-            else:
-                max_pos_embeddings = (
-                    self.model.config.text_config.max_position_embeddings)
+
+            # Use get_text_config() in case of multimodal models
+            text_config = self.model_config.hf_config.get_text_config()

             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
@@ -1147,7 +1142,8 @@ def load_model(self) -> None:
                 self.device,
                 self.model.embedding_modules,
                 self.model.embedding_padding_modules,
-                max_position_embeddings=max_pos_embeddings,
+                max_position_embeddings=text_config.
+                max_position_embeddings,
             )
             self.model = self.lora_manager.create_lora_manager(self.model)
             time_after_load = time.perf_counter()