
Commit 455e44c

[fsdp,megatron,vllm,trainer,algo] feat: On-Policy Distillation (#5041)
# What does this PR do?

Adds on-policy distillation support across the FSDP and Megatron backends. Collaboration with @wuxibin89, including design guidance, restructuring, and added support for parallelism.

Supports:

- FSDP and Megatron engines
- top-k distillation loss and KL-estimator distillation losses
- supervised and policy-gradient-style updates
- teacher logprob computation using a vLLM teacher server
- LLM and VLM distillation
- FSDP sequence parallel
- Megatron context parallel and tensor parallel

## Losses

1. Top-k distillation loss: **forward** KL estimated using the top-k logits **from the teacher**.
2. KL-estimator distillation losses: **reverse** KL estimated using only the log prob of the sampled token, via the same estimators used by the reference model (e.g., k1, k3).

## Updates

1. Supervised: the distillation loss is backpropagated directly, as in https://arxiv.org/abs/2306.13649
2. Policy gradient: the negative distillation loss is used as a reward, as in https://thinkingmachines.ai/blog/on-policy-distillation/

# Test

- LLM distillation with FSDP: `examples/on_policy_distillation_trainer/run_qwen_gsmk8k.sh`
- VLM distillation with FSDP: `examples/on_policy_distillation_trainer/run_qwen3_vl_geo3k.sh`
- LLM distillation with Megatron: `examples/on_policy_distillation_trainer/run_qwen_gsmk8k_megatron.sh`

## Main results

### LLM Distillation

These experiments compare three training runs with student model Qwen2.5-0.5B using `examples/on_policy_distillation_trainer/run_qwen_gsmk8k.sh`:

1. Forward top-k KL with Qwen2.5-3B-Instruct teacher (gold)
2. Forward top-k KL with Qwen2.5-7B-Instruct teacher (green)
3.
k3 estimator KL with Qwen2.5-7B-Instruct teacher (red)

#### GSM8K eval acc

<img width="969" height="641" alt="image" src="https://github.com/user-attachments/assets/19c1bee7-b688-4d24-a41e-4426761a26f1" />

#### GSM8K train acc

<img width="972" height="639" alt="image" src="https://github.com/user-attachments/assets/8f932649-5d45-4964-9d59-18a3706004d5" />

#### Distillation loss

<img width="963" height="633" alt="image" src="https://github.com/user-attachments/assets/609f817b-2247-42df-86a8-5e07d637ea7c" />

### VLM Distillation

- Data: Geometry3K
- Student: Qwen3-VL-2B-Instruct
- Teacher: Qwen3-VL-4B-Instruct
- OPD algo: k1 KL estimator as reward with policy-gradient loss

#### Geo3K eval acc

<img width="967" height="640" alt="image" src="https://github.com/user-attachments/assets/e511fedb-8cf1-4576-b214-e992984c7550" />

#### Geo3K train acc

<img width="963" height="630" alt="image" src="https://github.com/user-attachments/assets/0eb0e281-218d-4283-825e-2d0b1aa095d4" />

#### Distillation loss

<img width="964" height="633" alt="image" src="https://github.com/user-attachments/assets/7b539c60-2af8-4d30-a9d3-3cb4cf485847" />

## LLM Distillation: Top-k training stability

Clamping the top-k forward KL loss was needed for training stability. These experiments compare three types of clamping:

1. No clamping (grey)
2. Clamping the distillation loss to a maximum value of 10 (blue)
3.
Clamping the student and teacher log probs to a minimum value of -10 (gold)

### Distillation loss

<img width="947" height="633" alt="image" src="https://github.com/user-attachments/assets/e90815f5-f745-4a07-bc41-dbb3eb16b1dc" />

### GSM8K eval acc

<img width="971" height="639" alt="image" src="https://github.com/user-attachments/assets/d5cd1833-7e66-40f3-b0a3-fae6a9bc0d7b" />

### GSM8K train acc

<img width="958" height="636" alt="image" src="https://github.com/user-attachments/assets/53742e8d-cd41-4483-943e-cea2b5f0c27b" />

## LLM Distillation: Policy-gradient results

While the VLM results in this PR use the k1 KL estimator with policy-gradient updates, all LLM distillation results outside this section rely on supervised updates. This section validates LLM distillation with policy-gradient updates:

1. Forward top-k KL with supervised update (green)
2. k1 estimator KL with policy-gradient update (purple)
3. k3 estimator KL with supervised update (red)

While purple looks best, it also generates responses that exceed the maximum response length of 512.

### Distillation loss

<img width="971" height="627" alt="image" src="https://github.com/user-attachments/assets/208b8625-23d7-46d5-982d-6ab1b5049b21" />

### GSM8K eval acc

<img width="976" height="632" alt="image" src="https://github.com/user-attachments/assets/dd46c6d3-651a-45cb-8bc5-75a765a9a38e" />

### GSM8K train acc

<img width="969" height="634" alt="image" src="https://github.com/user-attachments/assets/f2db7e16-5426-4fa2-b017-fb78b58b8dd4" />

### Response length

<img width="979" height="640" alt="image" src="https://github.com/user-attachments/assets/f61fadd7-d1a8-4110-ba26-a5a2735f8107" />

## LLM Distillation: Megatron

To verify parity of the Megatron engine with FSDP, these experiments compare four training runs with student model Qwen2.5-0.5B:

1.
Forward top-k KL with Qwen2.5-7B-Instruct teacher + clamping log probs to a minimum value of -10.0 (teal)
2. Forward top-k KL with Qwen2.5-3B-Instruct teacher + clamping log probs to a minimum value of -10.0 (red)
3. Forward top-k KL with Qwen2.5-3B-Instruct teacher + clamping loss to a maximum value of 10.0 (blue)
4. k3 estimator reverse KL with Qwen2.5-7B-Instruct teacher + clamping loss to a maximum value of 10.0 (green)

The solid lines use the Megatron engine with TP=2; the dotted lines use FSDP.

### GSM8K Eval Acc.

<img width="1656" height="839" alt="image" src="https://github.com/user-attachments/assets/e43a31a9-f012-452f-b606-99ede49a5fce" />

### GSM8K Train Acc.

<img width="1653" height="838" alt="image" src="https://github.com/user-attachments/assets/609b5917-dbb4-4a72-bbb8-acd4a7f4224a" />

### Distillation Loss

<img width="1651" height="846" alt="image" src="https://github.com/user-attachments/assets/6842295a-aa9a-4b11-9153-b9f83130e352" />

### Grad Norm

<img width="988" height="636" alt="image" src="https://github.com/user-attachments/assets/5883e67c-71e3-4e65-87e3-79e0a2527757" />

## LLM Distillation: Note on reverse KL

Initially, this PR included top-k reverse KL and top-k Jensen-Shannon divergences (JSD interpolates between forward and reverse KL). For the student distribution $q$ and teacher distribution $p$, the top-k reverse KL is given by

$$
\mathrm{KL}_{\text{top-}k}(q\,\|\,p) = \sum_i \mathbf{1}(q_i \in \text{top-}k)\, q_i \log\frac{q_i}{p_i}.
$$

Unfortunately, this was unstable: one way to make this loss small is to make $q_i$ as small as possible for all $q_i \in \text{top-}k$.
This can be seen from the logs tracking the amount of mass captured in the top-$k$ probabilities:

<img width="1118" height="1286" alt="image" src="https://github.com/user-attachments/assets/6e335263-7f4d-48d1-96f4-181f89b24e21" />

## LLM Distillation: Ablation: performance with a more lenient parser

Note that the only loss used is the distillation loss (no rewards for correctness on GSM8K). Any increase in the logged reward (= GSM8K accuracy) is an indirect result of minimizing the distillation loss. The base model has Pass@1 ≈ 0 because the default GSM8K answer formatting (`#### 42`) is OOD for the model: the base model answers the questions correctly but with incorrect formatting, so none of the answers can be parsed. The base model can be evaluated using a reward function that is more lenient on formatting by adding the following to the script:

```bash
...
    reward_model.reward_manager=remote \
    custom_reward_function.path=tests/experimental/reward_loop/reward_fn.py \
    custom_reward_function.name=compute_score_math_verify \
    trainer.val_only=True
```

The results are:

```
(TaskRunner pid=904198) ("Initial validation metrics: {'val-aux/openai/gsm8k/reward/mean@1': "
(TaskRunner pid=904198) "np.float64(0.31766489764973466), 'val-core/openai/gsm8k/acc/mean@1': "
(TaskRunner pid=904198) "np.float64(0.31766489764973466), 'val-aux/num_turns/min': np.int32(2), "
(TaskRunner pid=904198) "'val-aux/num_turns/max': np.int32(2), 'val-aux/num_turns/mean': "
(TaskRunner pid=904198) 'np.float64(2.0)}')
```

# Design & Code Changes

- Teacher workers are used in the agent loop, similar to the generative reward model: after a student worker finishes its rollout, the teacher worker obtains logprobs.
- In the initial version of this PR (#4897), requests were submitted to the `vLLMHttpServer` via the `v1/completions` endpoint, which does not support multi-modal data.
  While `v1/chat/completions` does support multi-modal inputs, text must be passed as raw text instead of token IDs, preventing exact scoring of student generations: the round trip `student gen IDs -> student gen text -> teacher input IDs via v1/chat/completions tokenization` will not always yield `student gen IDs == teacher input IDs` (https://vllm.ai/blog/agent-lightning).
- This PR instead follows a path similar to how rollout replicas directly call the `generate` method on the `vLLMHttpServer`. This enables multi-modal inputs while representing text as token IDs. Requests to the teacher server now call the newly added `compute_logprobs` method of `vLLMHttpServer`.

---------

Co-authored-by: wuxibin <wuxibin@bytedance.com>
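To make the two loss families above concrete, here is a minimal per-token numerical sketch in plain Python. The function names are hypothetical (the PR's real implementations live in the FSDP/Megatron engine code), and it includes the two clamping variants ablated in the stability section:

```python
import math

def topk_forward_kl(teacher_topk_logps, student_logps_at_topk,
                    logp_min_clamp=-10.0, loss_max_clamp=10.0):
    """Per-token forward KL truncated to the teacher's top-k tokens:
    sum_i p_i * (log p_i - log q_i), with a min-clamp on log probs and
    a max-clamp on the loss (the two stability tricks ablated above)."""
    loss = 0.0
    for lp_t, lp_s in zip(teacher_topk_logps, student_logps_at_topk):
        lp_t = max(lp_t, logp_min_clamp)
        lp_s = max(lp_s, logp_min_clamp)
        loss += math.exp(lp_t) * (lp_t - lp_s)  # p_i * log(p_i / q_i)
    return min(loss, loss_max_clamp)

def k1_reverse_kl(student_logp, teacher_logp):
    """k1 estimator of reverse KL(q||p) from one sampled token: log q - log p."""
    return student_logp - teacher_logp

def k3_reverse_kl(student_logp, teacher_logp):
    """k3 estimator: with r = log p - log q, exp(r) - r - 1 (always >= 0)."""
    r = teacher_logp - student_logp
    return math.exp(r) - r - 1.0
```

In the supervised update such a per-token loss is backpropagated directly; in the policy-gradient update its negation is treated as a per-token reward fed to the usual advantage estimator (`grpo` in the example scripts).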
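The instability described in the reverse-KL note above can be reproduced numerically: because the sum is truncated to the student's top-k tokens, the student can drive the loss below zero by spreading its mass thinly (so that $q_i < p_i$ on its own top-k tokens) instead of matching the teacher. A toy sketch, assuming made-up 6-token distributions:

```python
import math

def topk_reverse_kl(q, p, k):
    """Truncated reverse KL: sum over the student's top-k tokens of q_i * log(q_i / p_i)."""
    topk = sorted(range(len(q)), key=lambda i: q[i], reverse=True)[:k]
    return sum(q[i] * math.log(q[i] / p[i]) for i in topk)

p = [0.45, 0.45, 0.025, 0.025, 0.025, 0.025]    # teacher: peaked
q_match = list(p)                                # student matches the teacher
q_spread = [0.2, 0.2, 0.15, 0.15, 0.15, 0.15]    # student: thinly spread

matched = topk_reverse_kl(q_match, p, k=2)       # exactly 0: every log ratio is log 1
degenerate = topk_reverse_kl(q_spread, p, k=2)
# degenerate < 0 < matched even though q_spread is far from p: on its
# top-2 tokens q_i < p_i, so each term q_i * log(q_i / p_i) is negative.
```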
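The detokenize-then-retokenize hazard behind the `v1/chat/completions` limitation can be shown with a toy greedy tokenizer (purely illustrative; real servers use BPE tokenizers, but the failure mode is the same): round-tripping the student's token IDs through text need not reproduce the same IDs, which is why the teacher is scored directly on token IDs via the new `compute_logprobs` path.

```python
# Toy vocabulary with an overlapping merge, and a greedy longest-match tokenizer.
VOCAB = {"ab": 0, "a": 1, "b": 2}
INV = {v: k for k, v in VOCAB.items()}

def detokenize(ids):
    return "".join(INV[i] for i in ids)

def tokenize(text):
    ids, i = [], 0
    while i < len(text):
        for piece in sorted(VOCAB, key=len, reverse=True):  # longest match first
            if text.startswith(piece, i):
                ids.append(VOCAB[piece])
                i += len(piece)
                break
        else:
            raise ValueError(f"untokenizable text at position {i}")
    return ids

student_gen_ids = [1, 2]                          # the student sampled "a" then "b"
roundtrip_ids = tokenize(detokenize(student_gen_ids))
# roundtrip_ids == [0] ("ab" as one merged token): the teacher would score
# different tokens than the ones the student actually generated.
```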
1 parent a4f53df commit 455e44c

37 files changed (+2432 / −68 lines)

.github/workflows/gpu_unit_tests.yml

Lines changed: 3 additions & 0 deletions
```diff
@@ -117,6 +117,9 @@ jobs:
       - name: Testing LinearCrossEntropyTP Correctness, Computation Time and Memory Consumption
         run: |
           LOW_MEMORY=True torchrun --standalone --nnodes=1 --nproc-per-node=8 tests/utils/test_special_linear_cross_entropy_tp.py
+      - name: Testing Megatron KL Loss TP Correctness
+        run: |
+          torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/utils/test_special_megatron_kl_loss_tp.py
       - name: Testing FSDP2 actor functionality
         run: |
           torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/workers/actor/test_special_dp_actor.py
```
Lines changed: 159 additions & 0 deletions
```bash
#!/usr/bin/env bash
set -xeuo pipefail

############################ Quick Config ############################

ROLLOUT_NAME="vllm" # sglang or vllm

FAMILY="Qwen"
STUDENT_MODEL=Qwen3-VL-2B-Instruct
TEACHER_MODEL=Qwen3-VL-4B-Instruct

# USE_POLICY_GRADIENT=False
# DISTILLATION_LOSS_MODE="k3"
# DISTILLATION_LOSS_MODE="forward_kl_topk"
# USE_FUSED_KERNELS=False

USE_POLICY_GRADIENT=True
DISTILLATION_LOSS_MODE="k1"
USE_FUSED_KERNELS=True

DISTILLATION_LOSS_MAX_CLAMP=10.0
DISTILLATION_LOG_PROB_MIN_CLAMP=-10.0

PROJECT_NAME='verl_on_policy_distillation_example_geo3k'

MAX_PROMPT=1024
MAX_RESPONSE_LENGTH=2048
MAX_NUM_TOKENS=$(( MAX_PROMPT + MAX_RESPONSE_LENGTH + 1 ))
TRAIN_PROMPT_BSZ=128
STUDENT_MICRO_BATCH_SIZE_PER_GPU=1
STUDENT_MAX_TOKEN_LEN_PER_GPU=$(( STUDENT_MICRO_BATCH_SIZE_PER_GPU * (MAX_PROMPT + MAX_RESPONSE_LENGTH) ))
USE_DYNAMIC_BSZ=False

STUDENT_WORLD_SIZE=4

TEACHER_RESOURCE_POOL=True
TEACHER_WORLD_SIZE=4

SP=1

EXP_NAME="fsdp/student-${STUDENT_MODEL}/teacher-${TEACHER_MODEL}/loss-${DISTILLATION_LOSS_MODE}/pg-${USE_POLICY_GRADIENT}"

ENFORCE_EAGER=False # true for faster debugging

############################ Paths ############################

geo3k_train_path=$DATA_PATH/geo3k/train.parquet
geo3k_test_path=$DATA_PATH/geo3k/test.parquet

TRAIN_FILES="['$geo3k_train_path']"
TEST_FILES="['$geo3k_test_path']"

############################ Parameter Groups ############################

DATA=(
    data.train_files="$TRAIN_FILES"
    data.val_files="$TEST_FILES"
    data.max_prompt_length=$MAX_PROMPT
    data.max_response_length=$MAX_RESPONSE_LENGTH
    data.train_batch_size=$TRAIN_PROMPT_BSZ
    data.filter_overlong_prompts=True
    data.truncation='error'
    data.shuffle=False
    data.image_key=images
)

MODEL=(
    actor_rollout_ref.model.path="${FAMILY}/${STUDENT_MODEL}"
    actor_rollout_ref.model.enable_gradient_checkpointing=True
    actor_rollout_ref.model.use_remove_padding=True
    actor_rollout_ref.model.use_fused_kernels=$USE_FUSED_KERNELS
    actor_rollout_ref.actor.use_torch_compile=True
    actor_rollout_ref.rollout.enforce_eager=$ENFORCE_EAGER
)

DISTILLATION=(
    distillation.enabled=True
    distillation.num_workers=8
    distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL
    distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE
    distillation.teacher_model.nnodes=1
    distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}"
    distillation.teacher_model.inference.tensor_model_parallel_size=1
    distillation.teacher_model.inference.name=$ROLLOUT_NAME
    distillation.teacher_model.inference.gpu_memory_utilization=0.8
    distillation.teacher_model.inference.enforce_eager=$ENFORCE_EAGER
    distillation.teacher_model.inference.max_model_len=$MAX_NUM_TOKENS
    distillation.teacher_model.inference.max_num_batched_tokens=$MAX_NUM_TOKENS
    distillation.teacher_model.inference.max_num_seqs=$MAX_NUM_TOKENS
    +distillation.teacher_model.inference.engine_kwargs.vllm.disable_mm_preprocessor_cache=True
    distillation.distillation_loss.loss_mode=$DISTILLATION_LOSS_MODE
    distillation.distillation_loss.topk=64
    distillation.distillation_loss.use_task_rewards=False
    distillation.distillation_loss.use_policy_gradient=$USE_POLICY_GRADIENT
    distillation.distillation_loss.loss_max_clamp=$DISTILLATION_LOSS_MAX_CLAMP
    distillation.distillation_loss.log_prob_min_clamp=$DISTILLATION_LOG_PROB_MIN_CLAMP
)

STUDENT=(
    actor_rollout_ref.actor.optim.lr=1e-6
    actor_rollout_ref.actor.ppo_mini_batch_size=$TRAIN_PROMPT_BSZ
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$STUDENT_MICRO_BATCH_SIZE_PER_GPU
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$STUDENT_MAX_TOKEN_LEN_PER_GPU
    actor_rollout_ref.actor.use_dynamic_bsz=$USE_DYNAMIC_BSZ
    actor_rollout_ref.actor.fsdp_config.param_offload=True
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1
)

ROLLOUT=(
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$STUDENT_MICRO_BATCH_SIZE_PER_GPU
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$STUDENT_MAX_TOKEN_LEN_PER_GPU
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=$USE_DYNAMIC_BSZ
    actor_rollout_ref.rollout.tensor_model_parallel_size=1
    actor_rollout_ref.rollout.name=$ROLLOUT_NAME
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8
    actor_rollout_ref.rollout.calculate_log_probs=False
    actor_rollout_ref.rollout.max_model_len=$MAX_NUM_TOKENS
    actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_TOKENS
    actor_rollout_ref.rollout.max_num_seqs=$MAX_NUM_TOKENS
    actor_rollout_ref.rollout.n=1
    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True
)

ALGORITHM=(
    algorithm.adv_estimator=grpo
    algorithm.use_kl_in_reward=False
)

TRAINER=(
    trainer.logger='["console","wandb"]'
    trainer.project_name=$PROJECT_NAME
    trainer.experiment_name=$EXP_NAME
    trainer.n_gpus_per_node=$STUDENT_WORLD_SIZE
    trainer.nnodes=1
    trainer.save_freq=200
    trainer.test_freq=5
    trainer.total_epochs=15
    trainer.val_before_train=True
    trainer.use_legacy_worker_impl=disable
    trainer.resume_mode=disable
    trainer.log_val_generations=5
)

############################ Launch ############################

python3 -m verl.trainer.main_ppo \
    --config-path=config \
    --config-name='ppo_trainer.yaml' \
    "${DATA[@]}" \
    "${ALGORITHM[@]}" \
    "${MODEL[@]}" \
    "${DISTILLATION[@]}" \
    "${ROLLOUT[@]}" \
    "${STUDENT[@]}" \
    "${TRAINER[@]}" \
    "$@"
```
Lines changed: 156 additions & 0 deletions
```bash
#!/usr/bin/env bash
set -xeuo pipefail

############################ Quick Config ############################

ROLLOUT_NAME="vllm" # sglang or vllm

FAMILY="Qwen"
STUDENT_MODEL=Qwen2.5-0.5B
TEACHER_MODEL=Qwen2.5-0.5B

# USE_POLICY_GRADIENT=False
# DISTILLATION_LOSS_MODE="k3"
# DISTILLATION_LOSS_MODE="forward_kl_topk"
# USE_FUSED_KERNELS=False

USE_POLICY_GRADIENT=True
DISTILLATION_LOSS_MODE="k1"
USE_FUSED_KERNELS=False

DISTILLATION_LOSS_MAX_CLAMP=10.0
DISTILLATION_LOG_PROB_MIN_CLAMP=-10.0

PROJECT_NAME='verl_on_policy_distillation_example_gsm8k'

MAX_PROMPT=256
MAX_RESPONSE_LENGTH=512
MAX_NUM_TOKENS=$(( MAX_PROMPT + MAX_RESPONSE_LENGTH + 1 ))
TRAIN_PROMPT_BSZ=128
STUDENT_MICRO_BATCH_SIZE_PER_GPU=2
STUDENT_MAX_TOKEN_LEN_PER_GPU=$(( STUDENT_MICRO_BATCH_SIZE_PER_GPU * (MAX_PROMPT + MAX_RESPONSE_LENGTH) ))
USE_DYNAMIC_BSZ=True

STUDENT_WORLD_SIZE=2

TEACHER_RESOURCE_POOL=False
TEACHER_WORLD_SIZE=4

SP=1

EXP_NAME="fsdp/student-${STUDENT_MODEL}/teacher-${TEACHER_MODEL}/loss-${DISTILLATION_LOSS_MODE}/pg-${USE_POLICY_GRADIENT}"

ENFORCE_EAGER=True # true for faster debugging

############################ Paths ############################

gsm8k_train_path=$DATA_PATH/gsm8k/train.parquet
gsm8k_test_path=$DATA_PATH/gsm8k/test.parquet

TRAIN_FILES="['$gsm8k_train_path']"
TEST_FILES="['$gsm8k_test_path']"

############################ Parameter Groups ############################

DATA=(
    data.train_files="$TRAIN_FILES"
    data.val_files="$TEST_FILES"
    data.max_prompt_length=$MAX_PROMPT
    data.max_response_length=$MAX_RESPONSE_LENGTH
    data.train_batch_size=$TRAIN_PROMPT_BSZ
    data.filter_overlong_prompts=True
    data.truncation='error'
    data.shuffle=False
)

MODEL=(
    actor_rollout_ref.model.path="${FAMILY}/${STUDENT_MODEL}"
    actor_rollout_ref.model.enable_gradient_checkpointing=True
    actor_rollout_ref.model.use_remove_padding=True
    actor_rollout_ref.model.use_fused_kernels=$USE_FUSED_KERNELS
    actor_rollout_ref.actor.use_torch_compile=True
    actor_rollout_ref.rollout.enforce_eager=$ENFORCE_EAGER
)

DISTILLATION=(
    distillation.enabled=True
    distillation.num_workers=8
    distillation.teacher_model.enable_resource_pool=$TEACHER_RESOURCE_POOL
    distillation.teacher_model.n_gpus_per_node=$TEACHER_WORLD_SIZE
    distillation.teacher_model.nnodes=1
    distillation.teacher_model.model_path="${FAMILY}/${TEACHER_MODEL}"
    distillation.teacher_model.inference.tensor_model_parallel_size=1
    distillation.teacher_model.inference.name=$ROLLOUT_NAME
    distillation.teacher_model.inference.gpu_memory_utilization=0.3
    distillation.teacher_model.inference.enforce_eager=$ENFORCE_EAGER
    distillation.teacher_model.inference.max_model_len=$MAX_NUM_TOKENS
    distillation.teacher_model.inference.max_num_batched_tokens=$MAX_NUM_TOKENS
    distillation.teacher_model.inference.max_num_seqs=$MAX_NUM_TOKENS
    distillation.distillation_loss.loss_mode=$DISTILLATION_LOSS_MODE
    distillation.distillation_loss.topk=64
    distillation.distillation_loss.use_task_rewards=False
    distillation.distillation_loss.use_policy_gradient=$USE_POLICY_GRADIENT
    distillation.distillation_loss.loss_max_clamp=$DISTILLATION_LOSS_MAX_CLAMP
    distillation.distillation_loss.log_prob_min_clamp=$DISTILLATION_LOG_PROB_MIN_CLAMP
)

STUDENT=(
    actor_rollout_ref.actor.optim.lr=1e-6
    actor_rollout_ref.actor.ppo_mini_batch_size=$TRAIN_PROMPT_BSZ
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$STUDENT_MICRO_BATCH_SIZE_PER_GPU
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$STUDENT_MAX_TOKEN_LEN_PER_GPU
    actor_rollout_ref.actor.use_dynamic_bsz=$USE_DYNAMIC_BSZ
    actor_rollout_ref.actor.fsdp_config.param_offload=True
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=$SP
)

ROLLOUT=(
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$STUDENT_MICRO_BATCH_SIZE_PER_GPU
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$STUDENT_MAX_TOKEN_LEN_PER_GPU
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=$USE_DYNAMIC_BSZ
    actor_rollout_ref.rollout.tensor_model_parallel_size=1
    actor_rollout_ref.rollout.name=$ROLLOUT_NAME
    actor_rollout_ref.rollout.gpu_memory_utilization=0.3
    actor_rollout_ref.rollout.calculate_log_probs=False
    actor_rollout_ref.rollout.max_model_len=$MAX_NUM_TOKENS
    actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_TOKENS
    actor_rollout_ref.rollout.max_num_seqs=$MAX_NUM_TOKENS
    actor_rollout_ref.rollout.n=1
)

ALGORITHM=(
    algorithm.adv_estimator=grpo
    algorithm.use_kl_in_reward=False
)

TRAINER=(
    trainer.logger='["console","wandb"]'
    trainer.project_name=$PROJECT_NAME
    trainer.experiment_name=$EXP_NAME
    trainer.n_gpus_per_node=$STUDENT_WORLD_SIZE
    trainer.nnodes=1
    trainer.save_freq=200
    trainer.test_freq=5
    trainer.total_epochs=15
    trainer.val_before_train=False
    trainer.use_legacy_worker_impl=disable
    trainer.resume_mode=disable
    trainer.log_val_generations=5
)

############################ Launch ############################

python3 -m verl.trainer.main_ppo \
    --config-path=config \
    --config-name='ppo_trainer.yaml' \
    "${DATA[@]}" \
    "${ALGORITHM[@]}" \
    "${MODEL[@]}" \
    "${DISTILLATION[@]}" \
    "${ROLLOUT[@]}" \
    "${STUDENT[@]}" \
    "${TRAINER[@]}" \
    "$@"
```
