Skip to content

Commit a112108

Browse files
committed
Allow setting ACTOR_TP.
Signed-off-by: Jonas Yang <joyang@nvidia.com>
1 parent 0901205 commit a112108

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

.github/workflows/e2e_ppo_grpo_trainer_trtllm.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ jobs:
168 168   run: |
169 169   ray stop --force
170 170   DATADIR=${HOME}/data \
    171 + ACTOR_TP=2 \
171 172   bash examples/grpo_trainer/run_qwen2-7b_math_megatron_trtllm.sh 2 \
172 173   trainer.total_training_steps=1 \
173 174   data.train_files="['${HOME}/data/gsm8k/train.parquet']" \

examples/grpo_trainer/run_qwen2-7b_math_megatron_trtllm.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ export RAY_DEDUP_LOGS=0
 14  14   # Config
 15  15   # -----
 16  16   TP=${1:-4}
     17 + ACTOR_TP=${ACTOR_TP:-4}
 17  18   PROJECT_NAME=${PROJECT_NAME:-"verl_grpo_example_gsm8k_math"}
 18  19   EXP_NAME=megatron-trtllm-qwen2-7b-tp${TP}-8gpus
 19  20

@@ -58,7 +59,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
 58  59   actor_rollout_ref.actor.ppo_mini_batch_size=256 \
 59  60   actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
 60  61   actor_rollout_ref.actor.megatron.use_mbridge=True \
 61     - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \
     62 + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${ACTOR_TP} \
 62  63   actor_rollout_ref.actor.use_kl_loss=True \
 63  64   actor_rollout_ref.actor.kl_loss_coef=0.001 \
 64  65   actor_rollout_ref.actor.kl_loss_type=low_var_kl \
@@ -72,7 +73,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
 72  73   actor_rollout_ref.rollout.max_batch_size=${MAX_BATCH_SIZE} \
 73  74   actor_rollout_ref.rollout.max_num_batched_tokens=32768 \
 74  75   actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
 75     - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \
     76 + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${ACTOR_TP} \
 76  77   +actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_timeout_iters=32 \
 77  78   +actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_max_tokens_ratio=0.5 \
 78  79   actor_rollout_ref.rollout.calculate_log_probs=True \

0 commit comments

Comments
 (0)