[BREAKING][vllm, fsdp] feat: add Rollout-Training Mismatch Fix -- Truncated importance sampling #2953
Merged: zhaochenyang20 merged 17 commits into verl-project:main from yaof20:truncated_importance_sampling on Aug 26, 2025.
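For context, here is my paraphrase of the idea behind the fix, following the linked blog (https://fengyao.notion.site/off-policy-rl), not text from this PR: the rollout engine (vLLM) and the training engine (FSDP) compute numerically different token probabilities, so rollout samples are slightly off-policy for the trainer. TIS corrects the policy gradient with a per-token importance ratio truncated at a cap $C$ (the `imp_ratio_cap` parameter), roughly:

$$
\nabla_\theta J \approx \mathbb{E}_{y \sim \pi_{\text{rollout}}}\left[\sum_t \min\left(\frac{\pi_\theta(y_t \mid y_{<t})}{\pi_{\text{rollout}}(y_t \mid y_{<t})},\ C\right) \hat{A}_t\, \nabla_\theta \log \pi_\theta(y_t \mid y_{<t})\right]
$$

Truncating the ratio (rather than using full importance sampling) bounds each weight by $C$, trading a small bias for much lower variance.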
Changes from 9 commits

Commits (17 total):
- `5e2181b` add truncated importance sampling (yaof20)
- `3d98e74` Update core_algos.py (zdhNarsil)
- `cd03fd6` add check for rollout_prob (yaof20)
- `6d8a9e1` Tiny fix (yaof20)
- `5b49a5b` disable TIS by default in actor.yaml (yaof20)
- `4b5d04f` disable calculate_log_probs in rollout.yaml (yaof20)
- `cca83bb` add example scripts & dict key check (yaof20)
- `ec987b3` add blog link (yaof20)
- `259edc8` delete deprecated code change (yaof20)
- `aaf4511` remove redundant argument (yaof20)
- `abea330` added doc, following naming convention, updated example, updated data… (LiyuanLucasLiu)
- `3fa967e` updated example script (LiyuanLucasLiu)
- `cb0686e` Merge branch 'main' into truncated_importance_sampling (yaof20)
- `114145d` updated docstring and format (LiyuanLucasLiu)
- `3ceb77c` updated docstring and format (LiyuanLucasLiu)
- `3a55325` Merge remote-tracking branch 'feng/truncated_importance_sampling' int… (LiyuanLucasLiu)
- `38d2391` formatting (LiyuanLucasLiu)
New example script (the DAPO Qwen2.5-32B recipe with TIS enabled):

```bash
#!/usr/bin/env bash
set -xeuo pipefail

project_name='DAPO'
exp_name='DAPO-Qwen2.5-32B-TIS' # Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl

adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 20))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"

enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=512
gen_prompt_bsz=$((train_prompt_bsz * 3))
n_resp_per_prompt=16
train_prompt_mini_bsz=32

# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-16}

# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance-related parameters
sp_size=8
use_dynamic_bsz=True
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
offload=True
gen_tp=4

# Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl
#
# Note: server mode (agent loop) does not return rollout_log_probs yet,
# so server mode is currently not supported for TIS.
#
# To turn on TIS, set the following parameters:
#   1. rollout.calculate_log_probs=True
#   2. rollout.imp_ratio_cap > 0 (the value can be tuned)

ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
    --working-dir "${WORKING_DIR}" \
    -- python3 -m recipe.dapo.main_dapo \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.gen_batch_size=${gen_prompt_bsz} \
    data.train_batch_size=${train_prompt_bsz} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    algorithm.filter_groups.enable=${enable_filter_groups} \
    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
    algorithm.filter_groups.metric=${filter_groups_metric} \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.top_p=${top_p} \
    actor_rollout_ref.rollout.top_k="${top_k}" \
    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
    reward_model.reward_manager=dapo \
    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
    reward_model.overlong_buffer.len=${overlong_buffer_len} \
    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
    trainer.logger='["console","wandb"]' \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes="${NNODES}" \
    trainer.val_before_train=True \
    trainer.test_freq=5 \
    trainer.save_freq=5 \
    trainer.total_epochs=1 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=auto \
    rollout.calculate_log_probs=True \
    +rollout.imp_ratio_cap=2.0 # remember to turn on calculate_log_probs=True first, and set imp_ratio_cap > 0; the value can be tuned
```
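The correction the script enables can be sketched as follows. This is an illustrative re-implementation under my own naming, not verl's actual `compute_policy_loss`: the PPO ratio is computed against the training-engine log-probs (`old_log_prob`), while a second ratio, truncated at `imp_ratio_cap`, reweights the loss to correct for the mismatch between the FSDP training policy and the vLLM rollout policy.

```python
import torch

def tis_weighted_pg_loss(log_prob, old_log_prob, rollout_log_probs,
                         advantages, response_mask, imp_ratio_cap=2.0):
    """Illustrative sketch of Truncated Importance Sampling (TIS).

    All tensors are (batch, response_len). Names mirror the PR's diff,
    but the function itself is a simplified stand-in, with PPO clipping
    omitted for brevity.
    """
    # Standard policy-gradient surrogate against the training engine's
    # recomputed log-probs.
    ratio = torch.exp(log_prob - old_log_prob)
    pg_loss = -advantages * ratio

    # TIS: reweight by the training/rollout probability ratio, truncated
    # at imp_ratio_cap so a large engine mismatch cannot blow up the
    # variance of the estimator.
    if rollout_log_probs is not None and imp_ratio_cap > 0:
        tis_weight = torch.exp(old_log_prob - rollout_log_probs)
        tis_weight = torch.clamp(tis_weight, max=imp_ratio_cap)
        pg_loss = pg_loss * tis_weight

    # "token-mean" aggregation over valid response tokens.
    return (pg_loss * response_mask).sum() / response_mask.sum().clamp(min=1)
```

With `rollout_log_probs=None` (TIS off, as in the default config) the function reduces to the plain surrogate, which is why the feature is opt-in.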
```diff
@@ -376,6 +376,8 @@ def update_policy(self, data: DataProto):
         ]
         if self.config.use_kl_loss:
             select_keys.append("ref_log_prob")
+        if self.config.imp_ratio_cap > 0 and "rollout_log_probs" in data.batch.keys():
+            select_keys.append("rollout_log_probs")

         has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
         non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
@@ -405,6 +407,7 @@ def update_policy(self, data: DataProto):
             model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
             response_mask = model_inputs["response_mask"]
             old_log_prob = model_inputs["old_log_probs"]
+            rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.imp_ratio_cap > 0 else None
             advantages = model_inputs["advantages"]

             entropy_coeff = self.config.entropy_coeff
@@ -435,6 +438,8 @@ def update_policy(self, data: DataProto):
                 response_mask=response_mask,
                 loss_agg_mode=loss_agg_mode,
                 config=self.config,
+                rollout_log_probs=rollout_log_probs,
+                imp_ratio_cap=self.config.imp_ratio_cap,
             )

             if entropy_coeff != 0:
```

Collaborator comment (on the `imp_ratio_cap=self.config.imp_ratio_cap` line): no need to pass it, as it's included in the config already.
Server mode (agent loop) doesn't return `rollout_log_probs` for now.
Hi, I have added a check here before adding `rollout_log_probs`.
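The check discussed above (guarding key selection so that rollouts which do not produce `rollout_log_probs`, such as server mode, do not break `update_policy`) can be sketched like this. `FakeBatch` and the function name are illustrative stand-ins, not verl's `DataProto` API:

```python
from dataclasses import dataclass, field

@dataclass
class FakeBatch:
    """Simplified stand-in for DataProto.batch: a dict of named tensors."""
    tensors: dict = field(default_factory=dict)

    def keys(self):
        return self.tensors.keys()

def select_keys_for_update(imp_ratio_cap: float, batch: FakeBatch) -> list:
    """Only request rollout_log_probs when TIS is enabled AND the
    rollout actually produced them, mirroring the guard in the diff."""
    select_keys = ["responses", "response_mask", "old_log_probs", "advantages"]
    if imp_ratio_cap > 0 and "rollout_log_probs" in batch.keys():
        select_keys.append("rollout_log_probs")
    return select_keys
```

Either condition failing (TIS disabled, or a rollout backend that never computed the log-probs) silently falls back to the vanilla key set instead of raising a missing-key error.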