
Conversation

@yao-matrix
Contributor

The 4 cases below passed:
tests/trainer/test_trainer.py::TrainerIntegrationTest::test_galore_lr_display_with_scheduler
tests/trainer/test_trainer.py::TrainerIntegrationTest::test_galore_lr_display_without_scheduler
tests/trainer/test_trainer.py::TrainerIntegrationTest::test_schedulefree_radam
tests/trainer/test_trainer.py::TrainerIntegrationTest::test_use_liger_kernel_trainer

@github-actions github-actions bot marked this pull request as draft April 21, 2025 03:17
@github-actions
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@yao-matrix yao-matrix marked this pull request as ready for review April 21, 2025 04:17
  # creating log history of trainer, results don't matter
  trainer.train()
- logs = trainer.state.log_history[1:][:-1]
+ logs = trainer.state.log_history[1:-1]
Contributor Author

log_history is a list of dicts, so the indexing is 1-dimensional. The prior code's meaning is actually (trainer.state.log_history[1:])[:-1]: slice [1:] to get a 1-d list, then slice [:-1] on that result. I rewrote it as a single slice, [1:-1]. It is also quite confusing to me why the original code slices out index 0, which is actually the first step's hyper-parameters.
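A minimal sketch (editor's illustration on a toy list, not the real log_history) of the slicing equivalence described above:

```python
# Toy stand-in for trainer.state.log_history: a plain list of dicts.
history = [{"step": i} for i in range(1, 6)]

# Chained slicing first drops the first item, then drops the last item of the result...
chained = history[1:][:-1]
# ...which is exactly what a single slice does in one pass.
single = history[1:-1]

assert chained == single == [{"step": 2}, {"step": 3}, {"step": 4}]
```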

The trainer.state.log_history I see is shown below. It has 16 steps as expected, and the last item is the training summary, so removing the last item is expected. I don't know why index 0 is removed, but I didn't change that logic, in case there is some background I am not aware of.

[{'loss': 4.5726, 'grad_norm': 1.2674919366836548, 'learning_rate': 0.0, 'epoch': 0.125, 'step': 1},
{'loss': 4.5726, 'grad_norm': 1.2674919366836548, 'learning_rate': 4e-05, 'epoch': 0.25, 'step': 2},
{'loss': 4.5609, 'grad_norm': 1.2704461812973022, 'learning_rate': 8e-05, 'epoch': 0.375, 'step': 3},
{'loss': 4.5375, 'grad_norm': 1.2671568393707275, 'learning_rate': 0.00012, 'epoch': 0.5, 'step': 4},
{'loss': 4.503, 'grad_norm': 1.2401748895645142, 'learning_rate': 0.00016, 'epoch': 0.625, 'step': 5},
{'loss': 4.4594, 'grad_norm': 1.1698124408721924, 'learning_rate': 0.0002, 'epoch': 0.75, 'step': 6},
{'loss': 4.4105, 'grad_norm': 1.0548782348632812, 'learning_rate': 0.00019594929736144976, 'epoch': 0.875, 'step': 7},
{'loss': 4.3693, 'grad_norm': 0.953703761100769, 'learning_rate': 0.00018412535328311814, 'epoch': 1.0, 'step': 8},
{'loss': 4.3359, 'grad_norm': 0.8863608837127686, 'learning_rate': 0.00016548607339452853, 'epoch': 1.125, 'step': 9},
{'loss': 4.3095, 'grad_norm': 0.8462928533554077, 'learning_rate': 0.00014154150130018866, 'epoch': 1.25, 'step': 10},
{'loss': 4.2889, 'grad_norm': 0.8223782777786255, 'learning_rate': 0.00011423148382732853, 'epoch': 1.375, 'step': 11},
{'loss': 4.2733, 'grad_norm': 0.8070436120033264, 'learning_rate': 8.57685161726715e-05, 'epoch': 1.5, 'step': 12},
{'loss': 4.262, 'grad_norm': 0.7965378761291504, 'learning_rate': 5.845849869981137e-05, 'epoch': 1.625, 'step': 13},
{'loss': 4.2544, 'grad_norm': 0.7894120812416077, 'learning_rate': 3.45139266054715e-05, 'epoch': 1.75, 'step': 14},
{'loss': 4.2499, 'grad_norm': 0.7850890755653381, 'learning_rate': 1.587464671688187e-05, 'epoch': 1.875, 'step': 15},
{'loss': 4.2478, 'grad_norm': 0.7830389738082886, 'learning_rate': 4.050702638550275e-06, 'epoch': 2.0, 'step': 16},
{'train_runtime': 2.2504, 'train_samples_per_second': 56.879, 'train_steps_per_second': 7.11, 'total_flos': 313198116864.0, 'train_loss': 4.38797390460968, 'epoch': 2.0, 'step': 16}]

  # reach given learning rate peak and end with 0 lr
- self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate)
  self.assertTrue(logs[-1]["learning_rate"] == 0)
+ self.assertTrue(logs[num_warmup_steps - 1]["learning_rate"] == learning_rate)
Contributor Author

For both the A100 and XPU logs (they are actually the same), the 0.0002 learning_rate happens at step 6. I don't know T4's situation, but I suppose this is not platform-dependent behavior, so maybe the test case needs to be updated. My galore-torch version is 1.0.

Collaborator

Not sure I understand, it's the 6th item, so the index is 5, which is num_warmup_steps ..?

Collaborator

but maybe @SunMarc can check this part

Contributor Author

Yes, the 6th item's index is 5 in 0-based indexing, and then the original first item is sliced out, so the index becomes num_warmup_steps - 1.
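To make the index arithmetic concrete, here is a small sketch against the log_history pasted above, assuming num_warmup_steps = 5 as the discussion implies:

```python
# The peak learning rate (0.0002) is logged at step 6, i.e. index 5 of the
# full log_history (0-based). After logs = log_history[1:-1] drops the first
# entry and the trailing training summary, the peak sits at index 4.
num_warmup_steps = 5                           # assumed from the discussion above
full_index_of_peak = 5                         # step 6 in 0-based indexing
sliced_index_of_peak = full_index_of_peak - 1  # the first entry was sliced away

assert sliced_index_of_peak == num_warmup_steps - 1
```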

Contributor Author

> but maybe @SunMarc can check this part

@SunMarc, could you share your insights here? Maybe the log in a prior version of galore-torch was different from the current version, or there are other considerations. I am using galore-torch 1.0.

Collaborator

Our daily CI doesn't have galore-torch, so we didn't see it in the reports. Manually tested on T4: this PR indeed fixes the issue, so I will merge.

Regarding whether to add galore-torch to the corresponding CI docker images, I will wait for @SunMarc's response (here or internally).

Thank you @yao-matrix

  logs[i]["learning_rate"] < logs[i + 1]["learning_rate"]
  for i in range(len(logs))
- if i < num_warmup_steps - 2
+ if i < num_warmup_steps - 1
Contributor Author

same as line 2282
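This fragment presumably sits inside an all(...) check over the logged steps; a hedged sketch of how such a warmup monotonicity check could look (the wrapping function is assumed, not copied from the PR):

```python
# Assumed shape of the assertion the diff fragment belongs to: during warmup
# (sliced indices below num_warmup_steps - 1), the learning rate should be
# strictly increasing from one logged step to the next. The `if` filter keeps
# i + 1 in range as long as num_warmup_steps <= len(logs).
def warmup_lr_is_increasing(logs, num_warmup_steps):
    return all(
        logs[i]["learning_rate"] < logs[i + 1]["learning_rate"]
        for i in range(len(logs))
        if i < num_warmup_steps - 1
    )
```

With the log pasted above (after the [1:-1] slice) and num_warmup_steps = 5, the learning rate rises 4e-05 → 8e-05 → 0.00012 → 0.00016 → 0.0002 over indices 0–4, so the check passes.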

  logs[i]["learning_rate"] > logs[i + 1]["learning_rate"]
  for i in range(len(logs) - 1)
- if i >= num_warmup_steps - 2
+ if i >= num_warmup_steps - 1
Contributor Author

same as line 2282

- self.assertTrue(logs[-1]["learning_rate"] == 0)
  self.assertTrue(logs[num_warmup_steps - 1]["learning_rate"] == learning_rate)
+ # self.assertTrue(logs[-1]["learning_rate"] == 0)
+ self.assertTrue(np.allclose(logs[-1]["learning_rate"], 0, atol=5e-6))
Contributor Author

learning_rate is a float value, so using == does not seem appropriate. On A100 and XPU it's 4.050702638550275e-06, so I use atol=5e-6 here.
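A small illustration of why a tolerance-based comparison is used (the value is taken from the log above):

```python
import numpy as np

final_lr = 4.050702638550275e-06  # last logged learning_rate above

# Exact comparison against 0 fails for a small-but-nonzero float...
assert not (final_lr == 0)

# ...while np.allclose with an absolute tolerance of 5e-6 treats it as
# "close enough to zero", which is what the updated assertion relies on.
assert np.allclose(final_lr, 0, atol=5e-6)
```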

@Rocketknight1
Member

cc @ydshieh

Collaborator

@ydshieh ydshieh left a comment

Overall LGTM, I just have a question regarding the indices. I will wait for @SunMarc.

@Rocketknight1
Member

Actually, cc @IlyasMoutawwakil! We discussed internally and Ilyas will take on reviewing XPU PRs

@ydshieh
Collaborator

ydshieh commented Apr 23, 2025

For PRs like this, I can review too. Most of the time it's mainly about require_torch_gpu --> require_torch_accelerator and changing some expected values.

If @IlyasMoutawwakil can share the tasks, more than welcome 🤗
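For readers unfamiliar with that pattern, a minimal sketch of the kind of change described above (the test name and body are hypothetical; only the decorator swap is the point):

```python
# Hypothetical test showing the typical device-agnostic change: the CUDA-only
# gate is replaced by the accelerator gate so the test also runs on XPU.
import unittest

from transformers.testing_utils import require_torch_accelerator  # was: require_torch_gpu


class ExampleDeviceAgnosticTest(unittest.TestCase):
    @require_torch_accelerator
    def test_runs_on_any_torch_accelerator(self):
        # expected values often need a tolerance instead of exact equality,
        # since numerics can differ slightly across device types
        self.assertTrue(True)
```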

@ydshieh ydshieh merged commit 19e9079 into huggingface:main Apr 23, 2025
9 checks passed
@yao-matrix yao-matrix deleted the trainer-xpu branch April 23, 2025 22:40
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025