Skip to content

Commit fdf0046

Browse files
authored
[trainer] fix: model engine vlm multi_modal_inputs to NonTensorStack (#4492)
### What does this PR do? Fixes the RL model engine for VLMs. Qwen/Qwen3-VL-30B-A3B-Instruct, fsdp vs. megatron on geo3k: <img width="386" height="310" alt="image" src="https://github.com/user-attachments/assets/f04e38b7-514a-4792-9806-3ad7964aa797" />
1 parent a0e8e44 commit fdf0046

File tree

17 files changed

+163
-61
lines changed

17 files changed

+163
-61
lines changed

.github/workflows/e2e_sft.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,13 @@ jobs:
146146
- name: Running GSM8K E2E training tests with multiturn and various configs and compare results
147147
run: |
148148
bash tests/special_e2e/sft/test_sft_engine_all.sh
149-
149+
- name: Prepare pokemon-gpt4o-captions dataset
150+
run: |
151+
ray stop --force
152+
python3 examples/data_preprocess/pokemon.py --local_dataset_path ${HOME}/models/hf_data/pokemon-gpt4o-captions
153+
- name: Running Pokemon E2E training tests with multiturn and various configs and compare results
154+
run: |
155+
MODEL_ID=Qwen/Qwen3-VL-2B-Instruct DATASET_DIR=~/data/pokemon-gpt4o-captions VPP_SIZE=null bash tests/special_e2e/sft/test_sft_engine_all.sh
150156
151157
cleanup:
152158
runs-on: ubuntu-latest

tests/single_controller/test_decorator_on_cpu.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ async def async_dp_compute(self, data: DataProto) -> DataProto:
6666
def dp_compute_td(self, data: TensorDict) -> TensorDict:
6767
rank_value = torch.tensor(self.rank, device=data["input"].device, dtype=data["input"].dtype)
6868
data["output"] = data["input"] + self.value + rank_value
69+
position_ids = data.pop("position_ids")
70+
for i, position_id in enumerate(position_ids.unbind(dim=0)):
71+
assert (position_id == torch.arange(4 + rank_value * 2 + i).expand(position_id.shape)).all()
6972
return data
7073

7174

@@ -159,7 +162,16 @@ def test_decorator_dp_compute_td(ray_init_shutdown):
159162

160163
# Prepare input data (size 4, for 2 workers)
161164
input_tensor = torch.arange(4, dtype=torch.float32)
162-
data = TensorDict({"input": input_tensor}, batch_size=[4])
165+
position_ids = torch.nested.as_nested_tensor(
166+
[
167+
torch.arange(4).expand(4, 4),
168+
torch.arange(5).expand(4, 5),
169+
torch.arange(6).expand(4, 6),
170+
torch.arange(7).expand(4, 7),
171+
],
172+
layout=torch.jagged,
173+
)
174+
data = TensorDict({"input": input_tensor, "position_ids": position_ids}, batch_size=[4])
163175

164176
# Call the decorated method
165177
output = worker_group.dp_compute_td(data)

tests/special_e2e/sft/compare_sft_engine_results.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def compare_results(golden_results, other_result):
3636
grad_norm = other_result[0]["data"]["train/grad_norm"]
3737

3838
torch.testing.assert_close(golden_loss, loss, atol=1e-2, rtol=1e-2)
39-
torch.testing.assert_close(golden_grad_norm, grad_norm, atol=1e-4, rtol=2e-2)
39+
torch.testing.assert_close(golden_grad_norm, grad_norm, atol=1e-4, rtol=3e-2)
4040

4141

4242
if __name__ == "__main__":

tests/special_e2e/sft/run_sft_engine_gsm8k.sh renamed to tests/special_e2e/sft/run_sft_engine.sh

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ else
1313
COMMAND="python ${ENTRYPOINT} trainer.nnodes=${NNODES:-1} trainer.n_gpus_per_node=${NUM_GPUS:-1}"
1414
fi
1515

16-
17-
TRAIN_FILES=~/data/gsm8k_sft/train.parquet
18-
VAL_FILES=~/data/gsm8k_sft/test.parquet
16+
DATASET_DIR=${DATASET_DIR:-~/data/gsm8k_sft}
17+
TRAIN_FILES=${DATASET_DIR}/train.parquet
18+
VAL_FILES=${DATASET_DIR}/test.parquet
1919

2020
backend=${BACKEND:-fsdp}
2121

@@ -71,7 +71,8 @@ MEGATRON_ENGINE_CONFIG="\
7171
engine.tensor_model_parallel_size=${TP_SIZE} \
7272
engine.pipeline_model_parallel_size=${PP_SIZE} \
7373
engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
74-
engine.context_parallel_size=${CP_SIZE}"
74+
engine.context_parallel_size=${CP_SIZE}
75+
engine.use_mbridge=True"
7576

7677
if [ "$backend" = "fsdp" ]; then
7778
ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
@@ -88,11 +89,11 @@ mkdir -p "${ckpts_home}"
8889
$COMMAND \
8990
data.train_files="${TRAIN_FILES}" \
9091
data.val_files="${VAL_FILES}" \
91-
data.train_batch_size=256 \
92+
data.train_batch_size=128 \
9293
data.pad_mode=${PAD_MODE} \
9394
data.truncation=error \
9495
data.use_dynamic_bsz=True \
95-
data.max_token_len_per_gpu=8192 \
96+
data.max_token_len_per_gpu=2048 \
9697
data.messages_key=messages \
9798
model.path=$MODEL_PATH \
9899
model.use_remove_padding=${USE_REMOVE_PADDING} \

tests/special_e2e/sft/test_sft_engine_all.sh

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,32 @@ rm -rf ~/verl/test/log
55
mkdir -p ~/verl/test/log
66

77
export VERL_FILE_LOGGER_ROOT=~/verl/test/log
8+
VPP_SIZE=${VPP_SIZE:-2}
89

910
# test with single gpu as golden
1011
echo "run with single gpu as golden"
11-
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
12+
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine.sh
1213

1314
# test with fsdp 1
1415
echo "run with sp2 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
15-
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
16+
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine.sh
1617

1718
# test with fsdp 1 use_remove_padding and pad_mode no_padding
1819
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding use_remove_padding False"
19-
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
20+
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine.sh
2021

2122

2223
# test with fsdp 2
2324
echo "run with sp2 fsdp_size2 num_gpus8 fsdp_strategy fsdp2"
24-
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
25+
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine.sh
2526

2627
# test with megatron
2728
echo "run with tp2 pp2 vpp2 cp2 num_gpus8"
28-
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=2 CP_SIZE=2 NUM_GPUS=8 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
29+
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 bash tests/special_e2e/sft/run_sft_engine.sh
2930

3031
# test with cp in ray
3132
echo "run with tp2 pp2 vpp2 cp2 num_gpus8 mode=ray"
32-
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=2 CP_SIZE=2 NUM_GPUS=8 mode=ray bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
33+
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 mode=ray bash tests/special_e2e/sft/run_sft_engine.sh
3334

3435
python3 tests/special_e2e/sft/compare_sft_engine_results.py
3536

tests/test_protocol_v2_on_cpu.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,55 @@ def test_concat_tensordict():
741741
assert output["temp"] == 1.0
742742

743743

744+
def test_chunk_tensordict():
745+
# Qwen-VL 3d position_ids
746+
position_ids = torch.nested.as_nested_tensor(
747+
[
748+
torch.arange(4).expand(4, 4),
749+
torch.arange(5).expand(4, 5),
750+
torch.arange(6).expand(4, 6),
751+
torch.arange(7).expand(4, 7),
752+
],
753+
layout=torch.jagged,
754+
)
755+
input_ids = torch.nested.as_nested_tensor(
756+
[torch.arange(4), torch.arange(5), torch.arange(6), torch.arange(7)], layout=torch.jagged
757+
)
758+
759+
multi_modal_inputs = torch.stack(
760+
[
761+
NonTensorData({"pixel_values": torch.randn(3, 224, 224)}),
762+
NonTensorData(None),
763+
NonTensorData({"pixel_values": torch.randn(3, 128, 128)}),
764+
NonTensorData({"pixel_values": torch.randn(3, 128, 128)}),
765+
]
766+
)
767+
td = tu.get_tensordict(
768+
{
769+
"input_ids": input_ids,
770+
"position_ids": position_ids,
771+
"multi_modal_inputs": multi_modal_inputs,
772+
},
773+
)
774+
assert len(td) == 4
775+
chunks = tu.chunk_tensordict(td, chunks=2)
776+
777+
for i, chunk in enumerate(chunks):
778+
assert len(chunk) == 2
779+
for key, val in chunk.items():
780+
if isinstance(val, torch.Tensor) and val.is_nested:
781+
tensors = td[key].unbind(dim=0)
782+
expected = torch.nested.as_nested_tensor(tensors[i * 2 : (i + 1) * 2], layout=torch.jagged)
783+
assert torch.all(torch.eq(val.values(), expected.values())).item()
784+
else:
785+
expected = td[key][i * 2 : (i + 1) * 2]
786+
for tensor, expect in zip(val, expected, strict=False):
787+
if tensor.data is None:
788+
assert expect is None
789+
else:
790+
assert torch.all(torch.eq(tensor.data["pixel_values"], expect["pixel_values"])).item()
791+
792+
744793
def test_assign_non_tensor_stack_with_nested_lists():
745794
"""Test assign_non_tensor_stack with lists of lists."""
746795
td = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={})

tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from verl.utils.dataset.dataset_utils import DatasetPadMode, SFTTensorCollator
3232
from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
33-
from verl.utils.model import extract_multi_modal_inputs_tensordict
33+
from verl.utils.model import extract_multi_modal_inputs
3434

3535

3636
@pytest.mark.parametrize(
@@ -342,8 +342,8 @@ def test_multiturn_sft_vlm_dataset_on_cpu(vlm_data_file):
342342
input_ids = item["input_ids"]
343343
loss_mask = item["loss_mask"]
344344
position_ids = item["position_ids"]
345-
pixel_values = item.get("pixel_values", None)
346-
image_grid_thw = item.get("image_grid_thw", None)
345+
pixel_values = item.get("multi_modal_inputs", {}).get("pixel_values")
346+
image_grid_thw = item.get("multi_modal_inputs", {}).get("image_grid_thw")
347347

348348
assert input_ids.shape == loss_mask.shape, "Shapes of input_ids and loss_mask must be equal"
349349
assert position_ids.dim() == 2, "position_ids must be 2-dimensional"
@@ -425,7 +425,7 @@ def test_multiturn_sft_vlm_dataloader_on_cpu(vlm_data_file):
425425

426426
# 3. verify multi-modal data
427427
td = TensorDict(**batch, batch_size=batch_size)
428-
multi_modal_inputs = extract_multi_modal_inputs_tensordict(td)
428+
multi_modal_inputs = extract_multi_modal_inputs(td["multi_modal_inputs"])
429429
pixel_values = multi_modal_inputs["pixel_values"]
430430
image_grid_thw = multi_modal_inputs["image_grid_thw"]
431431

verl/models/mcore/model_forward.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import torch
1718

1819
from verl.utils.megatron_utils import unwrap_model
1920

@@ -181,7 +182,10 @@ def gptmodel_forward_no_padding(
181182
attention_mask = None
182183
if vision_model:
183184
input_ids_rmpad = input_ids.to_padded_tensor(pad_token_id)
184-
attention_mask = (input_ids_rmpad != pad_token_id).bool()
185+
seqlens_in_batch = input_ids.offsets().diff()
186+
attention_mask = torch.zeros_like(input_ids_rmpad, dtype=torch.bool)
187+
for i, seqlen in enumerate(seqlens_in_batch):
188+
attention_mask[i, :seqlen] = True
185189

186190
output_orig = model(
187191
input_ids=input_ids_rmpad,

verl/single_controller/base/decorator.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
import inspect
1615
from functools import partial, wraps
1716
from types import FunctionType
@@ -20,7 +19,7 @@
2019

2120
from verl.protocol import DataProtoFuture, _padding_size_key
2221
from verl.utils.py_functional import DynamicEnum
23-
from verl.utils.tensordict_utils import concat_tensordict
22+
from verl.utils.tensordict_utils import chunk_tensordict, concat_tensordict
2423
from verl.utils.transferqueue_utils import BatchMeta
2524

2625
# here we add a magic number of avoid user-defined function already have this attribute
@@ -78,14 +77,20 @@ def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
7877
splitted_args = []
7978
for arg in args:
8079
assert isinstance(arg, DataProto | DataProtoFuture | BatchMeta | TensorDict)
81-
chunked_arg = arg.chunk(chunks=chunks)
80+
if isinstance(arg, TensorDict):
81+
chunked_arg = chunk_tensordict(arg, chunks)
82+
else:
83+
chunked_arg = arg.chunk(chunks=chunks)
8284
assert len(chunked_arg) == chunks
8385
splitted_args.append(chunked_arg)
8486

8587
splitted_kwargs = {}
8688
for key, val in kwargs.items():
8789
assert isinstance(val, DataProto | DataProtoFuture | BatchMeta | TensorDict)
88-
chunked_kwarg = val.chunk(chunks=chunks)
90+
if isinstance(val, TensorDict):
91+
chunked_kwarg = chunk_tensordict(val, chunks)
92+
else:
93+
chunked_kwarg = val.chunk(chunks=chunks)
8994
assert len(chunked_kwarg) == chunks
9095
splitted_kwargs[key] = chunked_kwarg
9196

verl/trainer/main_ppo.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,6 @@ def add_critic_worker(self, config):
193193

194194
elif config.critic.strategy == "megatron":
195195
# TODO: switch this to TrainingWorker as well
196-
assert use_legacy_worker_impl != "disable", "Megatron critic only supports legacy worker implementation"
197196
from verl.workers.megatron_workers import CriticWorker
198197

199198
else:

0 commit comments

Comments (0)