Skip to content

Commit 9c0324d

Browse files
minor update
1 parent 9b26ff5 commit 9c0324d

1 file changed

Lines changed: 9 additions & 6 deletions

File tree

train_model.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,8 @@ def train(use_softmax1: bool = False, test_pipeline: bool = False):
4747

4848
# Load the raw dataset
4949
data_dir = Path.cwd() / "data"
50-
if test_pipeline:
51-
dataset = load_dataset(path=str(data_dir), split='train[:2]')
52-
else:
53-
dataset = load_dataset(path=str(data_dir), split='train')
50+
split = 'train[:2]' if test_pipeline else 'train'
51+
dataset = load_dataset(path=str(data_dir), split=split)
5452

5553
# We'll build our dataset by applying our tokenizer.json to our text file.
5654
def process_data(examples):
@@ -81,7 +79,13 @@ def process_data(examples):
8179
save_strategy="no",
8280
)
8381

84-
trainer = Trainer(model=model, args=training_args, data_collator=data_collator, train_dataset=tokenized_dataset)
82+
trainer = Trainer(
83+
model=model,
84+
args=training_args,
85+
tokenizer=tokenizer,
86+
data_collator=data_collator,
87+
train_dataset=tokenized_dataset
88+
)
8589

8690
# Start training
8791
trainer.train()
@@ -95,7 +99,6 @@ def process_data(examples):
9599
try:
96100
login(token=getenv("HUGGINGFACE_TOKEN"))
97101
trainer.push_to_hub()
98-
trainer.tokenizer.push_to_hub(output_dir)
99102
except (ValueError, RuntimeError, OSError, FileNotFoundError, TypeError) as e: # I've seen it all XD
100103
warn(f"Unable to upload model due to, {e}. Trying to write to disk instead.", UserWarning)
101104
trainer.save_model()

0 commit comments

Comments
 (0)