# grpo.py (324 lines, 282 loc, 11.6 KB)
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import re
import sys
from dataclasses import dataclass, field
import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import GRPOConfig
from open_r1.rewards import (
accuracy_reward as open_r1_accuracy_reward,
# format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
reasoning_steps_reward,
)
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config, GRPOTrainer
# from grpo_trainer import GRPOTrainer
from my_modeling_qwen2 import Qwen2ForCausalLM as MyQwen2ForCausalLM
import deepspeed
import torch.distributed as dist
# Module-level logger for this script; main() sets its level from the
# training arguments' process log level.
logger = logging.getLogger(__name__)
def accuracy_reward(completions, answer, **kwargs):
    r"""Score completions against reference answers wrapped in ``\boxed{...}``.

    Each reference answer string is placed inside a ``\boxed{}`` wrapper and
    the resulting list is handed, together with the completions, to open_r1's
    ``accuracy_reward``.

    Args:
        completions: Completions to score; passed through unchanged.
        answer: Iterable of reference answer strings, one per completion.
        **kwargs: Accepted for reward-function call compatibility; unused.

    Returns:
        Whatever ``open_r1_accuracy_reward`` returns for the boxed solutions.
    """
    boxed_template = "\\boxed{answer}"
    boxed_solutions = []
    for reference in answer:
        # Substitute the placeholder word with the actual reference answer.
        boxed_solutions.append(boxed_template.replace("answer", reference))
    return open_r1_accuracy_reward(completions, boxed_solutions)
def format_reward(completions, **kwargs):
    r"""Reward completions that follow the expected reasoning layout.

    A completion earns 1.0 when its first message's content starts with a
    ``<think>...</think>`` section and later contains a non-empty
    ``\boxed{...}`` expression; otherwise it earns 0.0.

    Args:
        completions: Sequence of completions, each a sequence whose first
            element is a mapping with a ``"content"`` string.
        **kwargs: Accepted for reward-function call compatibility; unused.

    Returns:
        A list of floats (1.0 or 0.0), one per completion.
    """
    layout = re.compile(
        r"^<think>.*?</think>.*\\boxed\{.+?\}.*$",
        re.DOTALL | re.MULTILINE,
    )
    # Pull out every text first so a malformed completion fails before any
    # matching is attempted.
    texts = []
    for completion in completions:
        texts.append(completion[0]["content"])

    rewards = []
    for text in texts:
        if layout.match(text):
            rewards.append(1.0)
        else:
            rewards.append(0.0)
    return rewards
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'.
        cosine_min_value_wrong (`float`):
            Minimum reward for cosine scaling for wrong answers.
        cosine_max_value_wrong (`float`):
            Maximum reward for cosine scaling for wrong answers.
        cosine_min_value_correct (`float`):
            Minimum reward for cosine scaling for correct answers.
        cosine_max_value_correct (`float`):
            Maximum reward for cosine scaling for correct answers.
        cosine_max_len (`int`):
            Maximum length for cosine scaling.
        repetition_n_grams (`int`):
            N-gram size passed to the repetition penalty reward.
        repetition_max_penalty (`float`):
            Maximum (negative) penalty passed to the repetition penalty reward.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={
            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
        },
    )
    # NOTE(review): the "wrong answer" defaults have min (0.0) greater than
    # max (-0.5); confirm this ordering is what get_cosine_scaled_reward expects.
    cosine_min_value_wrong: float = field(
        default=0.0,
        metadata={"help": "Minimum reward for wrong answers"},
    )
    cosine_max_value_wrong: float = field(
        default=-0.5,
        metadata={"help": "Maximum reward for wrong answers"},
    )
    cosine_min_value_correct: float = field(
        default=0.5,
        metadata={"help": "Minimum reward for correct answers"},
    )
    cosine_max_value_correct: float = field(
        default=1.0,
        metadata={"help": "Maximum reward for correct answers"},
    )
    cosine_max_len: int = field(
        default=1000,
        metadata={"help": "Maximum length for scaling"},
    )
    repetition_n_grams: int = field(
        default=3,
        metadata={"help": "Number of n-grams for repetition penalty reward"},
    )
    repetition_max_penalty: float = field(
        default=-1.0,
        metadata={
            "help": "Maximum (negative) penalty for repetition penalty reward"
        },
    )
# System message placed first in every prompt built by make_conversation()
# inside main(). It asks for reasoning inside <think>...</think> followed by a
# final \boxed{} answer, which is the layout format_reward checks for.
SYSTEM_PROMPT = r"You are a helpful AI Assistant. First, think through the reasoning inside <think>...</think>. Then, always present the final answer in \boxed{}."
def main(script_args, training_args, model_args):
    """Train only the wait-token embedding parameter with GRPO and save it.

    The function seeds the run, configures logging, detects an existing
    checkpoint, loads and formats the dataset, loads the custom Qwen2 model
    with every parameter frozen except ``wait_token_embedding_parameter``,
    runs GRPO training, saves the model and metrics, optionally evaluates and
    pushes to the hub, and finally writes the trained wait-token parameter to
    ``<output_dir>/wait_token.pt``.

    Args:
        script_args: Parsed ``GRPOScriptArguments`` (dataset name/config/splits,
            selected reward functions and their hyper-parameters).
        training_args: Parsed ``GRPOConfig``. Its ``beta`` is overwritten
            with ``0.0`` before the trainer is built.
        model_args: Parsed ``ModelConfig`` (revision, dtype, attention
            implementation, PEFT settings).
    """
    # Set seed for reproducibility
    set_seed(training_args.seed)

    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process a small summary
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Script parameters {script_args}")
    logger.info(f"Training parameters {training_args}")

    # Check for last checkpoint
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    if "wandb" in training_args.report_to:
        init_wandb_training(training_args)

    # Load the dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Get reward functions
    REWARD_FUNCS_REGISTRY = {
        "accuracy": accuracy_reward,
        "format": format_reward,
        "reasoning_steps": reasoning_steps_reward,
        "cosine": get_cosine_scaled_reward(
            min_value_wrong=script_args.cosine_min_value_wrong,
            max_value_wrong=script_args.cosine_max_value_wrong,
            min_value_correct=script_args.cosine_min_value_correct,
            max_value_correct=script_args.cosine_max_value_correct,
            max_len=script_args.cosine_max_len,
        ),
        "repetition_penalty": get_repetition_penalty_reward(
            ngram_size=script_args.repetition_n_grams,
            max_penalty=script_args.repetition_max_penalty,
        ),
        "length": len_reward,
    }
    # An unknown name in script_args.reward_funcs raises KeyError here.
    reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

    # Format into conversation
    def make_conversation(example):
        """Build the chat-style prompt (system + user turn) for one example."""
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    dataset = dataset.map(make_conversation)
    for split in dataset:
        if "messages" in dataset[split].column_names:
            dataset[split] = dataset[split].remove_columns("messages")

    logger.info("*** Initializing model kwargs ***")
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )

    # Our experiments are text only, so instead of adding a new token we used
    # <|vision_start|> as our <|continue-thinking|> token.
    wait_token_id = 151652
    # NOTE(review): the checkpoint name is hard-coded rather than taken from
    # model_args.model_name_or_path; confirm this is intentional.
    model = MyQwen2ForCausalLM.from_pretrained(
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        wait_token_id,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
    )

    # Initialise the wait-token weights. Both tensors are gathered so rank 0
    # can edit them in full; modifier_rank=0 publishes rank 0's edits.
    with deepspeed.zero.GatheredParameters(
        [model.model.embed_tokens.weight, model.model.wait_token_embedding_parameter],
        modifier_rank=0,
    ):
        if dist.get_rank() == 0:  # Only rank 0 modifies
            with torch.no_grad():
                model.model.wait_token_embedding_parameter.zero_()
                model.model.embed_tokens.weight[wait_token_id].data.copy_(
                    model.model.embed_tokens.weight[14190].data
                )  # "Wait" token

    # Freeze everything, then re-enable gradients for the single trainable
    # parameter.
    for param in model.parameters():
        param.requires_grad_(False)
    model.model.wait_token_embedding_parameter.requires_grad_(True)

    # NOTE(review): presumably disables the KL term of GRPO; confirm against
    # the GRPOConfig in use.
    training_args.beta = 0.0

    #############################
    # Initialize the GRPO trainer
    #############################
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]
            if training_args.eval_strategy != "no"
            else None
        ),
        peft_config=get_peft_config(model_args),
        callbacks=get_callbacks(training_args, model_args),
    )
    # NOTE(review): private trainer API; presumably pushes the modified
    # embedding weights to the vLLM generation engine before training starts.
    trainer._move_model_to_vllm()

    ###############
    # Training loop
    ###############
    logger.info("*** Train ***")
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    # Save everything else on main process
    kwargs = {
        "dataset_name": script_args.dataset_name,
        "tags": ["open-r1"],
    }
    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    ##########
    # Evaluate
    ##########
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    #############
    # push to hub
    #############
    if training_args.push_to_hub:
        logger.info("Pushing to hub...")
        trainer.push_to_hub(**kwargs)

    ################################
    # Save the trained wait token
    ################################
    # Gather the parameter that is actually written to disk, so rank 0 holds
    # the full tensor even when parameters are partitioned across ranks.
    # The access is read-only, so no modifier rank is needed.
    # NOTE(review): this runs after push_to_hub, so wait_token.pt is not part
    # of the upload above; confirm that is intended.
    with deepspeed.zero.GatheredParameters(
        [model.model.wait_token_embedding_parameter]
    ):
        if dist.get_rank() == 0 and trainer.accelerator.is_main_process:
            torch.save(
                model.model.wait_token_embedding_parameter,
                os.path.join(training_args.output_dir, "wait_token.pt"),
            )
if __name__ == "__main__":
    # Parse CLI / config-file arguments into the three argument groups.
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    # main() is a plain synchronous function that returns None, so it is
    # called directly. Passing its result to asyncio.run() would raise
    # ValueError ("a coroutine was expected") once training had finished.
    main(script_args, training_args, model_args)