
Commit ba4a2e1

hmellor and liuzijing2014 authored and committed
Use Transformers helper get_text_config() instead of checking for text_config (vllm-project#17105)
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Zijing Liu <[email protected]>
1 parent 715c07f commit ba4a2e1

File tree (7 files changed: +30 -46 lines changed)

benchmarks/kernels/benchmark_moe.py
tests/models/test_initialization.py
vllm/config.py
vllm/transformers_utils/config.py
vllm/worker/cpu_model_runner.py
vllm/worker/hpu_model_runner.py
vllm/worker/model_runner.py
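
The change is the same in every file: ad-hoc hasattr(config, "text_config") probes are replaced with the PretrainedConfig.get_text_config() helper from Transformers, which returns the nested text config for composite multimodal configs and the config itself for plain text models. Below is a minimal before/after sketch of the pattern; the two helper names are mine, for illustration only, and it assumes a Transformers release that ships get_text_config().

from transformers import PretrainedConfig


def old_get_text_config(config: PretrainedConfig) -> PretrainedConfig:
    # Pattern removed by this commit: manually probe for a nested text config.
    if hasattr(config, "text_config"):
        return config.text_config
    return config


def new_get_text_config(config: PretrainedConfig) -> PretrainedConfig:
    # Pattern introduced by this commit: Transformers resolves the nested
    # text config for multimodal models and returns the config unchanged
    # for text-only models.
    return config.get_text_config()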

benchmarks/kernels/benchmark_moe.py

Lines changed: 2 additions & 3 deletions
@@ -553,9 +553,8 @@ def main(args: argparse.Namespace):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
-        if not hasattr(config, "hidden_size"):
-            # Support for llama4
-            config = config.text_config
+        # Support for llama4
+        config = config.get_text_config()
         # Default: Mixtral.
         E = config.num_local_experts
         topk = config.num_experts_per_tok
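
With the helper, the benchmark no longer has to guess whether the loaded config is the multimodal root or the language-model sub-config before reading the MoE dimensions. A rough sketch of the resolution, assuming the Mixtral-style attribute names visible in the surrounding code; the moe_shapes wrapper and the model id in the comment are illustrative, not part of the benchmark:

from transformers import AutoConfig


def moe_shapes(model: str, tp_size: int):
    config = AutoConfig.from_pretrained(model)
    # For multimodal checkpoints such as Llama 4 the MoE fields live on the
    # nested text config; for Mixtral-style text models this is a no-op.
    config = config.get_text_config()
    E = config.num_local_experts
    topk = config.num_experts_per_tok
    shard_intermediate_size = 2 * config.intermediate_size // tp_size
    return E, topk, shard_intermediate_size


# e.g. moe_shapes("mistralai/Mixtral-8x7B-Instruct-v0.1", tp_size=2)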

tests/models/test_initialization.py

Lines changed: 1 addition & 4 deletions
@@ -24,10 +24,7 @@ def test_can_initialize(model_arch):
     def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
         hf_config.update(model_info.hf_overrides)
 
-        if hasattr(hf_config, "text_config"):
-            text_config: PretrainedConfig = hf_config.text_config
-        else:
-            text_config = hf_config
+        text_config = hf_config.get_text_config()
 
         text_config.update({
             "num_layers": 1,

vllm/config.py

Lines changed: 2 additions & 4 deletions
@@ -2841,12 +2841,10 @@ def _get_and_verify_dtype(
 ) -> torch.dtype:
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, "torch_dtype", None)
+    config_dtype = getattr(config.get_text_config(), "torch_dtype", None)
 
-    # Fallbacks for multi-modal models if the root config
+    # Fallback for multi-modal models if the root config
     # does not define torch_dtype
-    if config_dtype is None and hasattr(config, "text_config"):
-        config_dtype = getattr(config.text_config, "torch_dtype", None)
     if config_dtype is None and hasattr(config, "vision_config"):
         config_dtype = getattr(config.vision_config, "torch_dtype", None)
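
After the change, _get_and_verify_dtype reads torch_dtype from the text config first, which covers both plain text models and multimodal models whose dtype lives on the nested language-model config, and only then falls back to the vision config. A standalone sketch of that lookup order; resolve_config_dtype is a name of my choosing, not vLLM's:

from typing import Optional

import torch
from transformers import PretrainedConfig


def resolve_config_dtype(config: PretrainedConfig) -> Optional[torch.dtype]:
    # getattr(config, "torch_dtype", torch.float32) would be wrong here
    # because torch_dtype can legitimately be None, so default to None.
    config_dtype = getattr(config.get_text_config(), "torch_dtype", None)

    # Fallback for multi-modal models whose text config does not define
    # torch_dtype but whose vision config does.
    if config_dtype is None and hasattr(config, "vision_config"):
        config_dtype = getattr(config.vision_config, "torch_dtype", None)

    return config_dtype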

vllm/transformers_utils/config.py

Lines changed: 12 additions & 9 deletions
@@ -760,19 +760,22 @@ def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
     No op for pure text models.
     """
-    if hasattr(config, "text_config"):
-        # The code operates under the assumption that text_config should have
-        # `num_attention_heads` (among others). Assert here to fail early
-        # if transformers config doesn't align with this assumption.
-        assert hasattr(config.text_config, "num_attention_heads")
-        return config.text_config
-    elif hasattr(config, "thinker_config"):
+    # This block should be unnecessary after https://github.com/huggingface/transformers/pull/37517
+    if hasattr(config, "thinker_config"):
         # TODO(suyang.fy): Refactor code.
         # For Qwen2.5-Omni, change hf_text_config to
         # thinker_config.text_config.
         return config.thinker_config.text_config
-    else:
-        return config
+
+    text_config = config.get_text_config()
+
+    if text_config is not config:
+        # The code operates under the assumption that text_config should have
+        # `num_attention_heads` (among others). Assert here to fail early
+        # if transformers config doesn't align with this assumption.
+        assert hasattr(text_config, "num_attention_heads")
+
+    return text_config
 
 
 def try_get_generation_config(
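
The rewritten helper keeps its Qwen2.5-Omni special case (thinker_config) but otherwise defers to get_text_config(), and the num_attention_heads assertion now fires only when a genuinely nested text config was returned. A small illustration of the expected behaviour on the two common config shapes; LlamaConfig and LlavaConfig are arbitrary stand-ins, and running this assumes vLLM and Transformers are installed:

from transformers import LlamaConfig, LlavaConfig

from vllm.transformers_utils.config import get_hf_text_config

# Text-only model: get_text_config() is a no-op, so the root config comes
# back and the num_attention_heads assert is skipped.
llama = LlamaConfig()
assert get_hf_text_config(llama) is llama

# Multimodal model: the nested language-model config comes back, and the
# helper asserts that it carries num_attention_heads.
llava = LlavaConfig()
assert get_hf_text_config(llava) is llava.text_config
assert hasattr(get_hf_text_config(llava), "num_attention_heads")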

vllm/worker/cpu_model_runner.py

Lines changed: 3 additions & 8 deletions
@@ -508,13 +508,8 @@ def load_model(self) -> None:
                 logger.warning("Regarding multimodal models, vLLM currently "
                                "only supports adding LoRA to language model.")
 
-            # It's necessary to distinguish between the max_position_embeddings
-            # of VLMs and LLMs.
-            if hasattr(self.model.config, "max_position_embeddings"):
-                max_pos_embeddings = self.model.config.max_position_embeddings
-            else:
-                max_pos_embeddings = (
-                    self.model.config.text_config.max_position_embeddings)
+            # Use get_text_config() in case of multimodal models
+            text_config = self.model_config.hf_config.get_text_config()
 
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
@@ -524,7 +519,7 @@ def load_model(self) -> None:
                 self.device,
                 self.model.embedding_modules,
                 self.model.embedding_padding_modules,
-                max_position_embeddings=max_pos_embeddings,
+                max_position_embeddings=text_config.max_position_embeddings,
             )
             self.model = self.lora_manager.create_lora_manager(self.model)
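
The same replacement appears in hpu_model_runner.py and model_runner.py below: instead of branching on hasattr(self.model.config, "max_position_embeddings") to tell VLM configs from LLM configs, the runners size the LoRA manager from the text config directly. A minimal sketch of that lookup outside of vLLM, with LlavaConfig as an arbitrary multimodal example:

from transformers import LlavaConfig

# For a multimodal config, max_position_embeddings lives on the nested
# language-model config rather than on the root config, which is why the
# old code needed the hasattr() branch.
hf_config = LlavaConfig()
assert not hasattr(hf_config, "max_position_embeddings")

# Use get_text_config() in case of multimodal models.
text_config = hf_config.get_text_config()
max_position_embeddings = text_config.max_position_embeddings
print(max_position_embeddings)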

vllm/worker/hpu_model_runner.py

Lines changed: 5 additions & 9 deletions
@@ -724,14 +724,9 @@ def load_model(self) -> None:
                 "Bias support in LoRA is not enabled in HPU yet."
             assert not self.lora_config.fully_sharded_loras, \
                 "Fully sharded LoRAs is not enabled in HPU yet."
-            # It's necessary to distinguish between the
-            # max_position_embeddings of VLMs and LLMs.
-            if hasattr(self.model.config, "max_position_embeddings"):
-                max_pos_embeddings = (
-                    self.model.config.max_position_embeddings)
-            else:
-                max_pos_embeddings = (
-                    self.model.config.text_config.max_position_embeddings)
+
+            # Use get_text_config() in case of multimodal models
+            text_config = self.model_config.hf_config.get_text_config()
 
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
@@ -741,7 +736,8 @@ def load_model(self) -> None:
                 self.device,
                 self.model.embedding_modules,
                 self.model.embedding_padding_modules,
-                max_position_embeddings=max_pos_embeddings,
+                max_position_embeddings=text_config.
+                max_position_embeddings,
             )
             self.model = self.lora_manager.create_lora_manager(self.model)

vllm/worker/model_runner.py

Lines changed: 5 additions & 9 deletions
@@ -1130,14 +1130,9 @@ def load_model(self) -> None:
                 logger.warning(
                     "Regarding multimodal models, vLLM currently "
                     "only supports adding LoRA to language model.")
-            # It's necessary to distinguish between the
-            # max_position_embeddings of VLMs and LLMs.
-            if hasattr(self.model.config, "max_position_embeddings"):
-                max_pos_embeddings = (
-                    self.model.config.max_position_embeddings)
-            else:
-                max_pos_embeddings = (
-                    self.model.config.text_config.max_position_embeddings)
+
+            # Use get_text_config() in case of multimodal models
+            text_config = self.model_config.hf_config.get_text_config()
 
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
@@ -1147,7 +1142,8 @@ def load_model(self) -> None:
                 self.device,
                 self.model.embedding_modules,
                 self.model.embedding_padding_modules,
-                max_position_embeddings=max_pos_embeddings,
+                max_position_embeddings=text_config.
+                max_position_embeddings,
             )
             self.model = self.lora_manager.create_lora_manager(self.model)
             time_after_load = time.perf_counter()
