examples/stack_llama/scripts/rl_training.py (23 additions, 17 deletions)
@@ -26,11 +26,6 @@
from trl.core import LengthSampler


DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"

tqdm.pandas()


@@ -65,13 +60,21 @@ class ScriptArguments:
save_freq: Optional[int] = field(default=None, metadata={"help": "n steps to save the model"})
output_dir: Optional[str] = field(default="runs/", metadata={"help": "the output directory"})
seed: Optional[int] = field(default=0, metadata={"help": "the seed"})
steps: Optional[int] = field(default=20000, metadata={"help": "number of training steps"})
init_kl_coef: Optional[float] = field(
default=0.2,
metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},
)

adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})


parser = HfArgumentParser(ScriptArguments)
script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0]
reward_model_name = script_args.reward_model_name
dataset_name = "lvwerra/stack-exchange-paired"
config = PPOConfig(
steps=script_args.steps,
model_name=script_args.model_name,
learning_rate=script_args.learning_rate,
log_with=script_args.log_with,
@@ -83,36 +86,35 @@ class ScriptArguments:
target_kl=script_args.target_kl,
ppo_epochs=script_args.ppo_epochs,
seed=script_args.seed,
init_kl_coef=script_args.init_kl_coef,
adap_kl_ctrl=script_args.adap_kl_ctrl,
)
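# A minimal sketch of what adaptive KL control does with the new options (an assumption
# based on the Ziegler et al. controller; the exact class used by the trainer is not part
# of this diff). With `adap_kl_ctrl=True`, the penalty coefficient starting at `init_kl_coef`
# is nudged so the measured KL stays near `target_kl`; with `adap_kl_ctrl=False` the
# coefficient follows a fixed/linear schedule instead.
class AdaptiveKLSketch:
    def __init__(self, init_kl_coef: float, target: float, horizon: int = 10000):
        self.value = init_kl_coef  # current coefficient, starts at --init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        # Clipped proportional error: grow the penalty when KL overshoots the target,
        # shrink it when the policy hugs the reference model too tightly.
        error = min(max(current_kl / self.target - 1.0, -0.2), 0.2)
        self.value *= 1.0 + error * n_steps / self.horizon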

train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/rl", split="train")
train_dataset = train_dataset.select(range(100000))
# We then define the arguments to pass to the reward model pipeline.
# We set `return_all_scores` to True to get the score for every class rather than only the
# top one, and `function_to_apply="none"` so the raw model output is used as the reward.
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16, "truncation": True}
sent_kwargs = {
"return_all_scores": True,
"function_to_apply": "none",
"batch_size": 16,
"truncation": True,
}

tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name)
# Some base tokenizers (e.g. GPT-2 and LLaMA) do not define a pad token by default.
# If the loaded tokenizer has no pad token, reuse the eos token for padding.

if "llama" in script_args.tokenizer_name:
tokenizer.add_special_tokens(
{
"eos_token": DEFAULT_EOS_TOKEN,
"bos_token": DEFAULT_BOS_TOKEN,
"unk_token": DEFAULT_UNK_TOKEN,
"pad_token": DEFAULT_PAD_TOKEN,
}
)
else:
if getattr(tokenizer, "pad_token", None) is None:
tokenizer.pad_token = tokenizer.eos_token
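# Quick illustration of why the fallback above matters (an aside, not part of the script):
# batching prompts of different lengths requires a pad token, and the base LLaMA tokenizer
# does not define one, so the eos token is reused for padding.
_padded = tokenizer(
    ["How do I sort a list in Python?", "Why?"],
    padding=True,
    return_tensors="pt",
)
# _padded["attention_mask"] is 0 on the padded positions so the model ignores them.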


# Below is an example function to build the dataset. In our case, we use the
# `lvwerra/stack-exchange-paired` dataset from the Hugging Face Hub. One should customize
# this function to train the model on one's own dataset.
def build_dataset(
tokenizer, dataset_name="lvwerra/stack-exchange-paired", input_min_text_length=2, input_max_text_length=8
tokenizer,
dataset_name="lvwerra/stack-exchange-paired",
):
"""
Build dataset for training. This builds the dataset from `load_dataset`, one should
@@ -217,6 +219,7 @@ def collator(data):
device_map={"": current_device},
model_kwargs={"load_in_8bit": True},
tokenizer=tokenizer,
return_token_type_ids=False,
)
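# Sketch of how this pipeline and `sent_kwargs` are typically combined inside the PPO loop
# (variable names below are illustrative assumptions, not lines from this file):
#     texts = [q + r for q, r in zip(batch["query"], batch["response"])]
#     outputs = reward_pipe(texts, **sent_kwargs)
#     rewards = [torch.tensor(out[0]["score"]) for out in outputs]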

# We then define the arguments to pass to the `generate` function. These arguments
@@ -235,6 +238,9 @@ def collator(data):
output_length_sampler = LengthSampler(output_min_length, output_max_length)
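# LengthSampler draws a fresh target length on every call, e.g. `output_length_sampler()`
# returns an int between output_min_length and output_max_length, so each generation in
# the PPO loop below uses a randomly sampled response length (behavior assumed from `trl.core`).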

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
if epoch >= config.total_ppo_epochs:
break

question_tensors = batch["input_ids"]

response_tensors = ppo_trainer.generate(