Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .github/workflows/e2e_sft.yml
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,13 @@ jobs:
- name: Running GSM8K E2E training tests with multiturn and various configs and compare results
run: |
bash tests/special_e2e/sft/test_sft_engine_all.sh

- name: Prepare pokemon-gpt4o-captions dataset
run: |
ray stop --force
python3 examples/data_preprocess/pokemon.py --local_dataset_path ${HOME}/models/hf_data/pokemon-gpt4o-captions
- name: Running Pokemon E2E training tests with multiturn and various configs and compare results
run: |
MODEL_ID=Qwen/Qwen3-VL-2B-Instruct DATASET_DIR=~/data/pokemon-gpt4o-captions VPP_SIZE=null bash tests/special_e2e/sft/test_sft_engine_all.sh

cleanup:
runs-on: ubuntu-latest
Expand Down
14 changes: 13 additions & 1 deletion tests/single_controller/test_decorator_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ async def async_dp_compute(self, data: DataProto) -> DataProto:
def dp_compute_td(self, data: TensorDict) -> TensorDict:
    """Add ``self.value`` plus this worker's rank to ``data["input"]`` and
    validate the jagged ``position_ids`` rows dispatched to this rank.

    The driver side builds a batch of 4 jagged position_ids entries with
    per-sample shapes (4, 4), (4, 5), (4, 6), (4, 7). After chunking the
    batch across 2 workers, rank ``r`` receives rows ``[2r, 2r + 1]``, so
    local row ``i`` has sequence length ``4 + 2 * r + i``.
    """
    # Scalar rank tensor on the same device/dtype as the input so the
    # arithmetic below stays consistent.
    rank_value = torch.tensor(self.rank, device=data["input"].device, dtype=data["input"].dtype)
    data["output"] = data["input"] + self.value + rank_value
    # Removed from the returned TensorDict; consumed row-by-row below.
    position_ids = data.pop("position_ids")
    for i, position_id in enumerate(position_ids.unbind(dim=0)):
        # Each row must equal arange(4 + 2*rank + i) broadcast to the row's
        # own (4, L) shape — proves the dispatcher gave this rank the right
        # contiguous slice of the jagged batch.
        assert (position_id == torch.arange(4 + rank_value * 2 + i).expand(position_id.shape)).all()
    return data


Expand Down Expand Up @@ -159,7 +162,16 @@ def test_decorator_dp_compute_td(ray_init_shutdown):

# Prepare input data (size 4, for 2 workers)
input_tensor = torch.arange(4, dtype=torch.float32)
data = TensorDict({"input": input_tensor}, batch_size=[4])
position_ids = torch.nested.as_nested_tensor(
[
torch.arange(4).expand(4, 4),
torch.arange(5).expand(4, 5),
torch.arange(6).expand(4, 6),
torch.arange(7).expand(4, 7),
],
layout=torch.jagged,
)
data = TensorDict({"input": input_tensor, "position_ids": position_ids}, batch_size=[4])

# Call the decorated method
output = worker_group.dp_compute_td(data)
Expand Down
2 changes: 1 addition & 1 deletion tests/special_e2e/sft/compare_sft_engine_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def compare_results(golden_results, other_result):
grad_norm = other_result[0]["data"]["train/grad_norm"]

torch.testing.assert_close(golden_loss, loss, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(golden_grad_norm, grad_norm, atol=1e-4, rtol=2e-2)
torch.testing.assert_close(golden_grad_norm, grad_norm, atol=1e-4, rtol=3e-2)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ else
COMMAND="python ${ENTRYPOINT} trainer.nnodes=${NNODES:-1} trainer.n_gpus_per_node=${NUM_GPUS:-1}"
fi


TRAIN_FILES=~/data/gsm8k_sft/train.parquet
VAL_FILES=~/data/gsm8k_sft/test.parquet
DATASET_DIR=${DATASET_DIR:-~/data/gsm8k_sft}
TRAIN_FILES=${DATASET_DIR}/train.parquet
VAL_FILES=${DATASET_DIR}/test.parquet

backend=${BACKEND:-fsdp}

Expand Down Expand Up @@ -71,7 +71,8 @@ MEGATRON_ENGINE_CONFIG="\
engine.tensor_model_parallel_size=${TP_SIZE} \
engine.pipeline_model_parallel_size=${PP_SIZE} \
engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
engine.context_parallel_size=${CP_SIZE}"
engine.context_parallel_size=${CP_SIZE}
engine.use_mbridge=True"

if [ "$backend" = "fsdp" ]; then
ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
Expand All @@ -88,11 +89,11 @@ mkdir -p "${ckpts_home}"
$COMMAND \
data.train_files="${TRAIN_FILES}" \
data.val_files="${VAL_FILES}" \
data.train_batch_size=256 \
data.train_batch_size=128 \
data.pad_mode=${PAD_MODE} \
data.truncation=error \
data.use_dynamic_bsz=True \
data.max_token_len_per_gpu=8192 \
data.max_token_len_per_gpu=2048 \
data.messages_key=messages \
model.path=$MODEL_PATH \
model.use_remove_padding=${USE_REMOVE_PADDING} \
Expand Down
13 changes: 7 additions & 6 deletions tests/special_e2e/sft/test_sft_engine_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,32 @@ rm -rf ~/verl/test/log
mkdir -p ~/verl/test/log

export VERL_FILE_LOGGER_ROOT=~/verl/test/log
VPP_SIZE=${VPP_SIZE:-2}

# test with single gpu as golden
echo "run with single gpu as golden"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine.sh

# test with fsdp 1
echo "run with sp2 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine.sh

# test with fsdp 1 use_remove_padding and pad_mode no_padding
echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding use_remove_padding False"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine.sh


# test with fsdp 2
echo "run with sp2 fsdp_size2 num_gpus8 fsdp_strategy fsdp2"
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine.sh

# test with megatron
echo "run with tp2 pp2 vpp2 cp2 num_gpus8"
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=2 CP_SIZE=2 NUM_GPUS=8 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 bash tests/special_e2e/sft/run_sft_engine.sh

# test with cp in ray
echo "run with tp2 pp2 vpp2 cp2 num_gpus8 mode=ray"
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=2 CP_SIZE=2 NUM_GPUS=8 mode=ray bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 mode=ray bash tests/special_e2e/sft/run_sft_engine.sh

python3 tests/special_e2e/sft/compare_sft_engine_results.py

Expand Down
49 changes: 49 additions & 0 deletions tests/test_protocol_v2_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,55 @@ def test_concat_tensordict():
assert output["temp"] == 1.0


def test_chunk_tensordict():
    """chunk_tensordict must split jagged nested tensors and NonTensorData
    columns into contiguous, equally-sized chunks."""
    # Qwen-VL style 3-D position_ids: one (4, L) block per sample, jagged in L.
    seq_lens = [4, 5, 6, 7]
    position_ids = torch.nested.as_nested_tensor(
        [torch.arange(n).expand(4, n) for n in seq_lens], layout=torch.jagged
    )
    input_ids = torch.nested.as_nested_tensor([torch.arange(n) for n in seq_lens], layout=torch.jagged)

    # Mixed vision/text batch: sample 1 carries no multi-modal payload.
    mm_payloads = [
        {"pixel_values": torch.randn(3, 224, 224)},
        None,
        {"pixel_values": torch.randn(3, 128, 128)},
        {"pixel_values": torch.randn(3, 128, 128)},
    ]
    multi_modal_inputs = torch.stack([NonTensorData(payload) for payload in mm_payloads])

    td = tu.get_tensordict(
        {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "multi_modal_inputs": multi_modal_inputs,
        },
    )
    assert len(td) == 4

    chunks = tu.chunk_tensordict(td, chunks=2)
    for chunk_idx, chunk in enumerate(chunks):
        assert len(chunk) == 2
        lo, hi = chunk_idx * 2, (chunk_idx + 1) * 2
        for key, val in chunk.items():
            if isinstance(val, torch.Tensor) and val.is_nested:
                # Rebuild the expected jagged slice from the source rows and
                # compare the flattened value buffers.
                rows = td[key].unbind(dim=0)
                expected = torch.nested.as_nested_tensor(list(rows[lo:hi]), layout=torch.jagged)
                assert torch.all(torch.eq(val.values(), expected.values())).item()
            else:
                expected = td[key][lo:hi]
                for got, want in zip(val, expected, strict=False):
                    if got.data is None:
                        assert want is None
                    else:
                        assert torch.all(torch.eq(got.data["pixel_values"], want["pixel_values"])).item()


def test_assign_non_tensor_stack_with_nested_lists():
"""Test assign_non_tensor_stack with lists of lists."""
td = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={})
Expand Down
8 changes: 4 additions & 4 deletions tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from verl.utils.dataset.dataset_utils import DatasetPadMode, SFTTensorCollator
from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
from verl.utils.model import extract_multi_modal_inputs_tensordict
from verl.utils.model import extract_multi_modal_inputs


@pytest.mark.parametrize(
Expand Down Expand Up @@ -342,8 +342,8 @@ def test_multiturn_sft_vlm_dataset_on_cpu(vlm_data_file):
input_ids = item["input_ids"]
loss_mask = item["loss_mask"]
position_ids = item["position_ids"]
pixel_values = item.get("pixel_values", None)
image_grid_thw = item.get("image_grid_thw", None)
pixel_values = item.get("multi_modal_inputs", {}).get("pixel_values")
image_grid_thw = item.get("multi_modal_inputs", {}).get("image_grid_thw")

assert input_ids.shape == loss_mask.shape, "Shapes of input_ids and loss_mask must be equal"
assert position_ids.dim() == 2, "position_ids must be 2-dimensional"
Expand Down Expand Up @@ -425,7 +425,7 @@ def test_multiturn_sft_vlm_dataloader_on_cpu(vlm_data_file):

# 3. verify multi-modal data
td = TensorDict(**batch, batch_size=batch_size)
multi_modal_inputs = extract_multi_modal_inputs_tensordict(td)
multi_modal_inputs = extract_multi_modal_inputs(td["multi_modal_inputs"])
pixel_values = multi_modal_inputs["pixel_values"]
image_grid_thw = multi_modal_inputs["image_grid_thw"]

Expand Down
6 changes: 5 additions & 1 deletion verl/models/mcore/model_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from verl.utils.megatron_utils import unwrap_model

Expand Down Expand Up @@ -181,7 +182,10 @@ def gptmodel_forward_no_padding(
attention_mask = None
if vision_model:
input_ids_rmpad = input_ids.to_padded_tensor(pad_token_id)
attention_mask = (input_ids_rmpad != pad_token_id).bool()
seqlens_in_batch = input_ids.offsets().diff()
attention_mask = torch.zeros_like(input_ids_rmpad, dtype=torch.bool)
for i, seqlen in enumerate(seqlens_in_batch):
attention_mask[i, :seqlen] = True

output_orig = model(
input_ids=input_ids_rmpad,
Expand Down
13 changes: 9 additions & 4 deletions verl/single_controller/base/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from functools import partial, wraps
from types import FunctionType
Expand All @@ -20,7 +19,7 @@

from verl.protocol import DataProtoFuture, _padding_size_key
from verl.utils.py_functional import DynamicEnum
from verl.utils.tensordict_utils import concat_tensordict
from verl.utils.tensordict_utils import chunk_tensordict, concat_tensordict
from verl.utils.transferqueue_utils import BatchMeta

# here we add a magic number of avoid user-defined function already have this attribute
Expand Down Expand Up @@ -78,14 +77,20 @@ def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
splitted_args = []
for arg in args:
assert isinstance(arg, DataProto | DataProtoFuture | BatchMeta | TensorDict)
chunked_arg = arg.chunk(chunks=chunks)
if isinstance(arg, TensorDict):
chunked_arg = chunk_tensordict(arg, chunks)
else:
chunked_arg = arg.chunk(chunks=chunks)
assert len(chunked_arg) == chunks
splitted_args.append(chunked_arg)

splitted_kwargs = {}
for key, val in kwargs.items():
assert isinstance(val, DataProto | DataProtoFuture | BatchMeta | TensorDict)
chunked_kwarg = val.chunk(chunks=chunks)
if isinstance(val, TensorDict):
chunked_kwarg = chunk_tensordict(val, chunks)
else:
chunked_kwarg = val.chunk(chunks=chunks)
assert len(chunked_kwarg) == chunks
splitted_kwargs[key] = chunked_kwarg

Expand Down
1 change: 0 additions & 1 deletion verl/trainer/main_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ def add_critic_worker(self, config):

elif config.critic.strategy == "megatron":
# TODO: switch this to TrainingWorker as well
assert use_legacy_worker_impl != "disable", "Megatron critic only supports legacy worker implementation"
from verl.workers.megatron_workers import CriticWorker

else:
Expand Down
23 changes: 17 additions & 6 deletions verl/trainer/sft_trainer_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,21 @@ def _build_engine(self):
def _build_dataset(self):
config = self.config
tokenizer = self.model_config.tokenizer
processor = self.model_config.processor
train_dataset = create_sft_dataset(
config.data.train_files, config.data, tokenizer, max_samples=config.data.get("train_max_samples", -1)
config.data.train_files,
config.data,
tokenizer,
processor=processor,
max_samples=config.data.get("train_max_samples", -1),
)
if config.data.val_files:
val_dataset = create_sft_dataset(
config.data.val_files, config.data, tokenizer, max_samples=config.data.get("val_max_samples", -1)
config.data.val_files,
config.data,
tokenizer,
processor=processor,
max_samples=config.data.get("val_max_samples", -1),
)
else:
val_dataset = None
Expand Down Expand Up @@ -157,7 +166,7 @@ def _build_dataloader(self):
sampler=self.train_sampler,
collate_fn=self.collate_fn,
num_workers=8,
pin_memory=True,
pin_memory=False,
drop_last=True,
pin_memory_device=device_name,
)
Expand All @@ -172,7 +181,7 @@ def _build_dataloader(self):
sampler=self.val_sampler,
collate_fn=self.collate_fn,
num_workers=8,
pin_memory=True,
pin_memory=False,
drop_last=True,
pin_memory_device=device_name,
)
Expand Down Expand Up @@ -327,7 +336,7 @@ def main(config):
run_sft(config)


def create_sft_dataset(data_paths, data_config, tokenizer, max_samples=-1):
def create_sft_dataset(data_paths, data_config, tokenizer, processor, max_samples=-1):
"""Create a dataset."""
# build dataset
# First check if a custom dataset class is specified
Expand All @@ -340,7 +349,9 @@ def create_sft_dataset(data_paths, data_config, tokenizer, max_samples=-1):
dataset_cls = MultiTurnSFTDataset

# Create datasets based on the selected class
dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config, max_samples=max_samples)
dataset = dataset_cls(
parquet_files=data_paths, tokenizer=tokenizer, config=data_config, processor=processor, max_samples=max_samples
)
return dataset


Expand Down
11 changes: 5 additions & 6 deletions verl/utils/dataset/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,15 @@ def collate_variable_batch(self, batch: list[dict[str, any]]) -> dict[str, any]:

final_batch = {}

tensor_keys = [key for key in batch[0].keys() if isinstance(batch[0][key], torch.Tensor)]
multi_modal_keys = {"pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"}
tensor_keys = set().union(*(d.keys() for d in batch))

# Handle tensor values by creating a NestedTensor.
for key in tensor_keys:
if key in multi_modal_keys:
tensors = [NonTensorData(item.get(key)) for item in batch]
final_batch[key] = torch.stack(tensors, dim=0)
else:
if isinstance(batch[0][key], torch.Tensor):
tensors = [item[key] for item in batch]
final_batch[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
else:
tensors = [NonTensorData(item.get(key)) for item in batch]
final_batch[key] = torch.stack(tensors, dim=0)
Comment on lines +68 to +73
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The logic isinstance(batch[0][key], torch.Tensor) is not robust and can lead to a KeyError. The tensor_keys set is a union of keys from all samples in the batch. If a key (e.g., multi_modal_inputs) is present in some samples but not in batch[0], accessing batch[0][key] will cause a crash. This is likely to happen when a batch mixes vision-language and text-only data.

Checking for the key's existence in batch[0] before checking its type will prevent this crash.

Suggested change
if isinstance(batch[0][key], torch.Tensor):
tensors = [item[key] for item in batch]
final_batch[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
else:
tensors = [NonTensorData(item.get(key)) for item in batch]
final_batch[key] = torch.stack(tensors, dim=0)
if key in batch[0] and isinstance(batch[0][key], torch.Tensor):
tensors = [item[key] for item in batch]
final_batch[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
else:
tensors = [NonTensorData(item.get(key)) for item in batch]
final_batch[key] = torch.stack(tensors, dim=0)


return final_batch
12 changes: 8 additions & 4 deletions verl/utils/dataset/multiturn_sft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,15 @@ def __getitem__(self, item):
else:
raise ValueError(f"Unknown truncation method {self.truncation}")

return {
res = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"loss_mask": loss_mask,
**multi_modal_inputs,
}
if len(multi_modal_inputs) > 0:
res["multi_modal_inputs"] = multi_modal_inputs
return res
elif self.pad_mode == DatasetPadMode.NO_PADDING:
# truncate input_ids if it is longer than max_length
if len(input_ids) > self.max_length:
Expand All @@ -360,12 +362,14 @@ def __getitem__(self, item):
position_ids = position_ids[..., : self.max_length]

# return nested tensor with out padding
return {
res = {
"input_ids": input_ids,
"position_ids": position_ids,
"loss_mask": loss_mask,
**multi_modal_inputs,
}
if len(multi_modal_inputs) > 0:
res["multi_modal_inputs"] = multi_modal_inputs
return res
else:
raise ValueError(f"Unknown pad mode {self.pad_mode}")

Expand Down
Loading
Loading