Switch to PyTorch's built-in RMSNorm #2054

Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2054. Note: links to docs will display an error until the docs builds have completed.
✅ No failures as of commit ecf9748 with merge base a9aadf5. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
```diff
     assert_expected(output_fp16, expected_fp16, atol=1e-7, rtol=1e-3)
-    assert output_fp16.dtype == torch.float32
+    assert output_fp16.dtype == torch.float16
```
Do you know why this wasn't failing before?
Yeah, the model wasn't cast to fp16, which means the scale parameter was still fp32. And since `x_normed * self.scale` occurred after the cast back to fp16, type promotion meant the output ended up in fp32.
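For illustration, a minimal standalone repro of that promotion behavior (not the torchtune code itself):

```python
import torch

x_normed = torch.ones(4, dtype=torch.float16)  # already cast back to fp16
scale = torch.ones(4, dtype=torch.float32)     # parameter was never cast to fp16

out = x_normed * scale
print(out.dtype)  # torch.float32 -- promotion silently upcasts the result
```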
```python
weight=self.scale,
eps=self.eps,
```
Noob question: when we load the model in bf16, will self.eps and self.scale also become bf16, or do they stay float32? If they are cast to bf16, it might be worth digging a bit deeper.
`eps` is just a Python float, not a PyTorch tensor, so it'll basically just adapt to whichever dtype it is applied to. `scale` will become bf16, but `F.rms_norm` will cast it to fp32 when it gets multiplied by the fp32 output.
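A quick way to check this, using `torch.nn.RMSNorm` (available in recent PyTorch) as a stand-in for torchtune's module — an assumption for illustration:

```python
import torch
import torch.nn as nn

norm = nn.RMSNorm(4096, eps=1e-6).to(torch.bfloat16)

print(norm.weight.dtype)  # torch.bfloat16 -- parameters follow the module cast
print(type(norm.eps))     # float -- a plain Python attribute, untouched by .to()
```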
Nice finding! Approving to unblock; I just left a few comments to make sure the test and the bf16 behavior are OK.
Also, FYI, I think using the torch.cuda event API would be the way to go here:
```python
import torch

class Time:
    """Context manager that times a CUDA region using events."""

    def __init__(self, name):
        self.name = name
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        self.start_event.record()

    def __exit__(self, *args, **kwargs):
        self.end_event.record()
        torch.cuda.synchronize()  # Wait for the events to be recorded
        elapsed_time = self.start_event.elapsed_time(self.end_event)  # Time in milliseconds
        print(f"TIME_{self.name}: {elapsed_time:.3f} ms")
```
Using shape = [8, 512, 4096], I got:
```
TIME_inference_uncompiled1: 129.710 ms
TIME_inference_uncompiled2: 11.391 ms
inference uncompiled mse 0.0
TIME_inference_initial_compile1: 2834.852 ms
TIME_inference_initial_compile2: 403.543 ms
TIME_inference_compiled1: 1.052 ms
TIME_inference_compiled2: 0.869 ms
inference compiled mse 2.7247249363426818e-06
TIME_train_uncompiled1: 40.218 ms
TIME_train_uncompiled2: 10.865 ms
train uncompiled mse 0.0
TIME_train_initial_compile1: 928.825 ms
TIME_train_initial_compile2: 708.388 ms
TIME_train_compiled1: 6.275 ms
TIME_train_compiled2: 5.314 ms
train compiled mse 2.7247249363426818e-06
```
This reverts commit 1450d61.
Context
What is the purpose of this PR?
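The gist of the change, as a hedged sketch (the exact torchtune code may differ; the old path is reconstructed from the discussion above):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Sketch of the module after the switch (details may differ)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Old path: upcast to fp32, multiply by rsqrt(mean(x^2) + eps),
        # cast back to the input dtype, then multiply by scale.
        # New path: one call into PyTorch's built-in implementation, with an
        # assumed fp32 upcast for numerical parity and a cast back at the end.
        return F.rms_norm(
            x.float(), (x.size(-1),), weight=self.scale, eps=self.eps
        ).to(x.dtype)
```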
Results
Inference:
- Uncompiled: 31x faster
- Compiled: compilation + first batch = 5x faster; post-first batches = 1.5x faster

Train (forward + backward):
- Uncompiled: 22x faster
- Compiled: compilation + first batch = 1.5x faster; post-first batches = 1.2x faster

Parity:
- Uncompiled: perfect (0.0 MSE)
- Compiled: 2.76e-6 MSE
Test plan
Test code (output is shown in the thread above): 1 = original RMSNorm, 2 = new RMSNorm.
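The full test script isn't reproduced in this thread; below is a hedged sketch of such a comparison harness. Module names, the torchtune import path, and the use of `torch.nn.RMSNorm` as the "new" path are assumptions, not the PR's actual code:

```python
import torch
from torchtune.modules import RMSNorm  # original implementation (assumed import path)

dim = 4096
x = torch.randn(8, 512, dim, device="cuda")

norm1 = RMSNorm(dim).cuda()                     # 1 = original RMSNorm
norm2 = torch.nn.RMSNorm(dim, eps=1e-6).cuda()  # 2 = built-in RMSNorm (stand-in)

# Uncompiled inference timing and parity check.
with Time("inference_uncompiled1"):
    out1 = norm1(x)
with Time("inference_uncompiled2"):
    out2 = norm2(x)
print("inference uncompiled mse", (out1 - out2).pow(2).mean().item())

# Compiled variants: the first call includes compilation time.
norm1_c, norm2_c = torch.compile(norm1), torch.compile(norm2)
with Time("inference_initial_compile1"):
    out1 = norm1_c(x)
with Time("inference_initial_compile2"):
    out2 = norm2_c(x)
print("inference compiled mse", (out1 - out2).pow(2).mean().item())
```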
All timings above were collected on 1xH100.