Skip to content

qwen3.5 rope_utils IndexError: too many indices for tensor of dimension 2 #5772

@wade0604

Description

@wade0604

System Info

package info:
vllm 0.17.0
megatron-core 0.16.0
mbridge 0.15.1
torch 2.10.0+cu128
transformers 5.2.0

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

run_scripts:

# --- Environment configuration ---
export VLLM_USE_V1=0                    # force vLLM legacy (V0) engine
export CUDA_DEVICE_MAX_CONNECTIONS=1    # required by Megatron comm/compute overlap
export VLLM_ALLREDUCE_USE_SYMM_MEM=0
export WORKING_DIR="${PWD}"
export MASTER_PORT="4444"

# Model and dataset paths (replace the /xxx placeholders with real paths).
HF_MODEL_PATH=Qwen/Qwen3.5-35B-A3B
train_path=/xxx
test_path=/xxx

# Parallelism settings; each can be overridden from the environment.
TP=${TP:-2}          # tensor parallel size (training)
PP=${PP:-1}          # pipeline parallel size
CP=${CP:-1}          # context parallel size
EP=${EP:-8}          # expert (MoE) parallel size
ETP=${ETP:-1}        # expert tensor parallel size
GEN_TP=${GEN_TP:-8}  # tensor parallel size for rollout generation

ALL_OFFLOAD=${ALL_OFFLOAD:-True}

rollout_name="vllm"
# BUGFIX: original line was `project_name=xxx'` — the stray unmatched single
# quote swallowed the following lines into the string and broke the script.
project_name='xxx'
exp_name='qwen3_5_35ba3b_megatron'
adv_estimator=grpo


########################### Parameter Arrays ###########################

# Dataset-related Hydra overrides.
# Elements containing ${VAR} expansions are quoted so paths with spaces or
# glob characters cannot word-split into multiple array elements.
DATA=(
    "data.train_files=${train_path}"
    "data.val_files=${test_path}"
    data.train_batch_size=32
    data.max_prompt_length=2048
    data.max_response_length=2048
    data.truncation='error'
    data.filter_overlong_prompts=True
)

# Model-related Hydra overrides; model path expansion quoted against splitting.
MODEL=(
    "actor_rollout_ref.model.path=${HF_MODEL_PATH}"
    actor_rollout_ref.model.trust_remote_code=True
    actor_rollout_ref.model.use_remove_padding=False
)

# Actor (policy) training overrides for the Megatron backend.
# `+key` appends a new config key; `++key` appends-or-overrides (Hydra syntax).
# Elements containing ${VAR} expansions are quoted against word splitting.
ACTOR=(
    actor_rollout_ref.actor.optim.lr=1e-6
    actor_rollout_ref.actor.ppo_mini_batch_size=32
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=4096
    actor_rollout_ref.actor.use_dynamic_bsz=False
    actor_rollout_ref.actor.use_kl_loss=True
    actor_rollout_ref.actor.kl_loss_coef=0.01
    actor_rollout_ref.actor.kl_loss_type=low_var_kl
    actor_rollout_ref.actor.entropy_coeff=0
    actor_rollout_ref.actor.megatron.use_mbridge=True
    actor_rollout_ref.actor.megatron.vanilla_mbridge=True
    actor_rollout_ref.actor.megatron.use_remove_padding=False
    "actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${TP}"
    "actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${PP}"
    "actor_rollout_ref.actor.megatron.context_parallel_size=${CP}"
    "actor_rollout_ref.actor.megatron.expert_model_parallel_size=${EP}"
    "actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ETP}"
    "actor_rollout_ref.actor.megatron.param_offload=${ALL_OFFLOAD}"
    "actor_rollout_ref.actor.megatron.optimizer_offload=${ALL_OFFLOAD}"
    "actor_rollout_ref.actor.megatron.grad_offload=${ALL_OFFLOAD}"
    actor_rollout_ref.actor.megatron.dtype=bfloat16
    ++actor_rollout_ref.actor.megatron.override_transformer_config.attention_backend=auto
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full
    +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_aux_loss_coeff=0.01
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_z_loss_coeff=0.001
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1
    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True
)

# Rollout (generation) overrides for the vLLM engine.
# Elements containing ${VAR} expansions are quoted against word splitting.
ROLLOUT=(
    "actor_rollout_ref.rollout.name=${rollout_name}"
    "actor_rollout_ref.rollout.tensor_model_parallel_size=${GEN_TP}"
    actor_rollout_ref.rollout.data_parallel_size=1
    "actor_rollout_ref.rollout.expert_parallel_size=${GEN_TP}"
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6
    actor_rollout_ref.rollout.n=4
    actor_rollout_ref.rollout.mode=async
    actor_rollout_ref.rollout.dtype=bfloat16
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=False
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=4096
    ++actor_rollout_ref.rollout.engine_kwargs.mm_encoder_tp_mode=data
    ++actor_rollout_ref.rollout.engine_kwargs.mm_processor_cache_type=shm
    ++actor_rollout_ref.rollout.engine_kwargs.reasoning_parser=qwen3
)

# Reference-policy overrides; parallelism mirrors the actor's Megatron layout.
# Elements containing ${VAR} expansions are quoted against word splitting.
REF=(
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=False
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=4096
    "actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${TP}"
    "actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${PP}"
    "actor_rollout_ref.ref.megatron.context_parallel_size=${CP}"
    "actor_rollout_ref.ref.megatron.expert_model_parallel_size=${EP}"
    "actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${ETP}"
    "actor_rollout_ref.ref.megatron.param_offload=${ALL_OFFLOAD}"
)

# Advantage-estimation algorithm overrides; expansion quoted against splitting.
ALGORITHM=(
    "algorithm.adv_estimator=${adv_estimator}"
    algorithm.use_kl_in_reward=False
)

# Trainer/logging overrides.
# Elements containing ${VAR} expansions are quoted against word splitting.
TRAINER=(
    trainer.critic_warmup=0
    trainer.logger='["console","tensorboard"]'
    "trainer.project_name=${project_name}"
    "trainer.experiment_name=${exp_name}"
    trainer.n_gpus_per_node=8
    trainer.nnodes=1
    trainer.log_val_generations=5
    trainer.ray_wait_register_center_timeout=500
    trainer.save_freq=10
    trainer.val_before_train=False
    trainer.test_freq=10
    trainer.total_epochs=15
)

########################### Launch ###########################

# Launch PPO training via Hydra; the parameter arrays above supply the
# config overrides. Trailing "$@" forwards any extra command-line overrides
# given to this script. Replace the xxx.py reward-function placeholder with
# a real path before running.
python3 -m verl.trainer.main_ppo \
    --config-path=config \
    --config-name='ppo_megatron_trainer.yaml' \
    "${DATA[@]}" \
    "${ALGORITHM[@]}" \
    "${MODEL[@]}" \
    "${ROLLOUT[@]}" \
    "${ACTOR[@]}" \
    "${REF[@]}" \
    "${TRAINER[@]}" \
    custom_reward_function.path=xxx.py \
    custom_reward_function.name=compute_score \
    "$@"

Expected behavior

traceback:

ray.exceptions.RayTaskError(IndexError): ray::WorkerDict.actor_rollout_compute_log_prob() (pid=19099, ip=9.1.16.9,                       
  actor_id=e2723632cd2714dad0f14b9501000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f0d9222ffa0>)                   
    File "/opt/conda/envs/python3.10.13/lib/python3.10/concurrent/futures/_base.py", line 458, in result                                   
      return self.__get_result()                                                                                                           
    File "/opt/conda/envs/python3.10.13/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result                             
      raise self._exception                                                                                                                
    File "/data/oss_bucket_0/zhuye/verl/verl/single_controller/ray/base.py", line 932, in func                                             
      return getattr(self.worker_dict[key], name)(*args, **kwargs)                                                                         
    File "/data/oss_bucket_0/zhuye/verl/verl/single_controller/base/decorator.py", line 427, in inner                                      
      return func(*args, **kwargs)                                                                                                         
    File "/data/oss_bucket_0/zhuye/verl/verl/utils/profiler/performance.py", line 105, in f                                                
      return self.log(decorated_function, *args, **kwargs)                                                                                 
    File "/data/oss_bucket_0/zhuye/verl/verl/utils/profiler/performance.py", line 118, in log                                              
      output = func(*args, **kwargs)                                                                                                       
    File "/data/oss_bucket_0/zhuye/verl/verl/utils/profiler/profile.py", line 173, in wrapper                                              
      return func(self_instance, *args, **kwargs_inner)                                                                                    
    File "/data/oss_bucket_0/zhuye/verl/verl/workers/megatron_workers.py", line 891, in compute_log_prob                                   
      output, entropys, layers_topk_idx = self.actor.compute_log_prob(data=data, calculate_entropy=not is_lora)                            
    File "/data/oss_bucket_0/zhuye/verl/verl/utils/profiler/performance.py", line 105, in f                                                
      return self.log(decorated_function, *args, **kwargs)                                                                                 
    File "/data/oss_bucket_0/zhuye/verl/verl/utils/profiler/performance.py", line 118, in log                                              
      output = func(*args, **kwargs)                                                                                                       
    File "/data/oss_bucket_0/zhuye/verl/verl/workers/actor/megatron_actor.py", line 259, in compute_log_prob                               
      output = self.forward_backward_batch(                                                                                                
    File "/data/oss_bucket_0/zhuye/verl/verl/workers/actor/megatron_actor.py", line 725, in forward_backward_batch                         
      losses_reduced = forward_backward_func(                                                                                              
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/megatron/core/pipeline_parallel/schedules.py", line 636, in           
  forward_backward_no_pipelining                                                                                                           
      output_tensor, num_tokens = forward_step(                                                                                            
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/megatron/core/pipeline_parallel/schedules.py", line 423, in           
  forward_step                                                                                                                             
      output_tensor, loss_func = forward_step_func(data_iterator, model)                                                                   
    File "/data/oss_bucket_0/zhuye/verl/verl/workers/actor/megatron_actor.py", line 675, in forward_step                                   
      output = forward_fn(                                                                                                                 
    File "/data/oss_bucket_0/zhuye/verl/verl/models/mcore/model_forward.py", line 132, in model_forward                                    
      output_orig = model(                                                                                                                 
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl         
      return self._call_impl(*args, **kwargs)                                                                                              
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl                 
      return forward_call(*args, **kwargs)                                                                                                 
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/megatron/core/distributed/data_parallel_base.py", line 22, in forward 
      return self.module(*inputs, **kwargs)                                                                                                
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl         
      return self._call_impl(*args, **kwargs)                                                                                              
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl                 
      return forward_call(*args, **kwargs)                                                                                                 
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/megatron/core/transformer/module.py", line 489, in forward            
      outputs = self.module(*inputs, **kwargs)                                                                                             
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl         
      return self._call_impl(*args, **kwargs)                                                                                              
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl                 
      return forward_call(*args, **kwargs)                                                                                                 
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/mbridge/models/qwen3_5/model.py", line 367, in forward                
      output = self.language_model(                                                                                                        
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl         
      return self._call_impl(*args, **kwargs)                                                                                              
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl                 
      return forward_call(*args, **kwargs)                                                                                                 
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/megatron/core/models/gpt/gpt_model.py", line 504, in forward          
      preproc_output = self._preprocess(                                                                                                   
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/megatron/core/models/gpt/gpt_model.py", line 388, in _preprocess      
      rotary_pos_emb = self.rotary_pos_emb(                                                                                                
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl         
      return self._call_impl(*args, **kwargs)                                                                                              
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl                 
      return forward_call(*args, **kwargs)                                                                                                 
    File "/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/mbridge/models/qwen3_vl/rope_utils.py", line 130, in forward          
      seq_expanded = seq[:, :, None, :].float()                                                                                            
  IndexError: too many indices for tensor of dimension 2

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions