[trainer] fix: model engine vlm multi_modal_inputs to NonTensorStack #4492
wuxibin89 merged 9 commits into verl-project:main
Conversation
Code Review
This pull request addresses an issue with handling multi_modal_inputs for VLM model engines by correctly transposing them into a NonTensorStack. It also includes a good refactoring in verl/workers/engine_workers.py to centralize device placement logic by moving inference results to the CPU and removing redundant .cpu() calls. My review includes one suggestion in verl/protocol.py to improve the memory efficiency of key collection, which is an important consideration for large-scale models.
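The "transpose into a NonTensorStack" idea can be illustrated without tensordict: each sample carries its own `multi_modal_inputs` dict, and collation turns that per-sample column into a single batch-level entry aligned with the batch dimension. A minimal pure-Python stand-in (a plain list stands in for the NonTensorStack; all names here are hypothetical, not verl's actual API):

```python
# Minimal stand-in for "transposing" per-sample multi_modal_inputs into one
# batched entry. In verl this becomes a tensordict NonTensorStack; here a
# plain list plays that role. Names are hypothetical.

def stack_multi_modal_inputs(samples):
    """Collect each sample's multi_modal_inputs dict into one batch-level list.

    Samples without multi-modal data contribute None, keeping the result
    aligned with the batch dimension of the tensor fields.
    """
    return [s.get("multi_modal_inputs") for s in samples]

batch = [
    {"input_ids": [1, 2, 3], "multi_modal_inputs": {"pixel_values": "img0"}},
    {"input_ids": [4, 5]},  # text-only sample: no multi_modal_inputs key
]
stacked = stack_multi_modal_inputs(batch)
print(stacked)  # [{'pixel_values': 'img0'}, None]
```

The point of the stand-in: the stacked entry always has one slot per sample, so downstream chunking and indexing along the batch dimension stay consistent even when some samples are text-only.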
8358145 to 954033c
/gemini review
Code Review
This pull request introduces several important fixes and refactorings to improve support for Vision Language Models (VLMs), particularly around handling multi_modal_inputs and variable-length sequences. The core change is to group multi-modal data under a single multi_modal_inputs key, which is a cleaner data structure. The PR correctly disables pin_memory in dataloaders to prevent crashes with NestedTensors, and adds a necessary workaround for chunking TensorDicts containing 3D jagged tensors. My review identifies one critical issue in the updated data collation logic that could lead to a KeyError when processing mixed batches of VLM and text-only data. A code suggestion is provided to fix this.
```python
if isinstance(batch[0][key], torch.Tensor):
    tensors = [item[key] for item in batch]
    final_batch[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
else:
    tensors = [NonTensorData(item.get(key)) for item in batch]
    final_batch[key] = torch.stack(tensors, dim=0)
```
The logic isinstance(batch[0][key], torch.Tensor) is not robust and can lead to a KeyError. The tensor_keys set is a union of keys from all samples in the batch. If a key (e.g., multi_modal_inputs) is present in some samples but not in batch[0], accessing batch[0][key] will cause a crash. This is likely to happen when a batch mixes vision-language and text-only data.
Checking for the key's existence in batch[0] before checking its type will prevent this crash.
Suggested change:

```diff
-if isinstance(batch[0][key], torch.Tensor):
+if key in batch[0] and isinstance(batch[0][key], torch.Tensor):
     tensors = [item[key] for item in batch]
     final_batch[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
 else:
     tensors = [NonTensorData(item.get(key)) for item in batch]
     final_batch[key] = torch.stack(tensors, dim=0)
```
What does this PR do?
Fix RL model engine for VLM.
Qwen/Qwen3-VL-30B-A3B-Instruct fsdp vs megatron on geo3k:
