
Switch to PyTorch's built-in RMSNorm #2054


Merged
merged 2 commits into pytorch:main from rms_norm on Nov 25, 2024

Conversation

calvinpelletier
Contributor

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)
  • zoom zoom
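
For context, the switch replaces torchtune's hand-rolled RMSNorm forward with a call to PyTorch's built-in F.rms_norm. A rough before/after sketch (not the exact diff; self.dim is an assumed attribute name):

import torch
import torch.nn.functional as F

# before: manual RMSNorm, computed in fp32 and cast back before applying the scale
def forward(self, x):
    x_fp32 = x.float()
    x_normed = (x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)).to(x.dtype)
    return x_normed * self.scale

# after: delegate to the built-in op (note the weight=self.scale, eps=self.eps
# arguments that show up in the review comments below)
def forward(self, x):
    return F.rms_norm(x.float(), (self.dim,), weight=self.scale, eps=self.eps).to(x.dtype)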

Results

Inference:
Uncompiled: 31x faster
Compiled: compilation and first batch = 5x faster, post-first batches = 1.5x faster

Train (forward + backward):
Uncompiled: 22x faster
Compiled: compilation and first batch = 1.5x faster, post-first batches = 1.2x faster

Parity:
Uncompiled: exact match (MSE = 0)
Compiled: 2.76e-6 MSE

Test plan

Test code (output is below):
1 = original RMSNorm, 2 = new RMSNorm

from time import time

import torch
import torch.nn.functional as F

from torchtune.modules.rms_norm import RMSNorm, RMSNorm2


class Time:
    """Minimal wall-clock timer, used as a context manager."""

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.start = time()

    def __exit__(self, *args, **kwargs):
        print(f"TIME_{self.name}", time() - self.start)


shape = [1024, 512, 4096]
xs = [torch.randn(shape, dtype=torch.bfloat16, device="cuda") for _ in range(10)]

with torch.no_grad():
    n1 = RMSNorm(shape[-1], eps=1e-6).cuda().eval()
    n2 = RMSNorm2(shape[-1], eps=1e-6).cuda().eval()

    with Time("inference_uncompiled1"):
        for x in xs:
            y1 = n1(x)
    with Time("inference_uncompiled2"):
        for x in xs:
            y2 = n2(x)
    print("inference uncompiled mse", F.mse_loss(y1, y2).item())

    with Time("inference_initial_compile1"):
        n1 = torch.compile(n1)
        n1(xs[0])
    with Time("inference_initial_compile2"):
        n2 = torch.compile(n2)
        n2(xs[0])

    with Time("inference_compiled1"):
        for x in xs:
            y1 = n1(x)
    with Time("inference_compiled2"):
        for x in xs:
            y2 = n2(x)
    print("inference compiled mse", F.mse_loss(y1, y2).item())


n1 = RMSNorm(shape[-1], eps=1e-6).cuda().train()
n2 = RMSNorm2(shape[-1], eps=1e-6).cuda().train()

with Time("train_uncompiled1"):
    for x in xs:
        y1 = n1(x)
        y1.mean().backward()
with Time("train_uncompiled2"):
    for x in xs:
        y2 = n2(x)
        y2.mean().backward()
print("train uncompiled mse", F.mse_loss(y1, y2).item())

with Time("train_initial_compile1"):
    n1 = torch.compile(n1)
    y1 = n1(xs[0])
    y1.mean().backward()
with Time("train_initial_compile2"):
    n2 = torch.compile(n2)
    y2 = n2(xs[0])
    y2.mean().backward()

with Time("train_compiled1"):
    for x in xs:
        y1 = n1(x)
        y1.mean().backward()
with Time("train_compiled2"):
    for x in xs:
        y2 = n2(x)
        y2.mean().backward()
print("train compiled mse", F.mse_loss(y1, y2).item())

Output on 1xH100:

TIME_inference_uncompiled1 0.057599782943725586
TIME_inference_uncompiled2 0.0018465518951416016
inference uncompiled mse 0.0
TIME_inference_initial_compile1 0.7398524284362793
TIME_inference_initial_compile2 0.1506185531616211
TIME_inference_compiled1 0.0005228519439697266
TIME_inference_compiled2 0.0003452301025390625
inference compiled mse 2.760749339358881e-06
TIME_train_uncompiled1 0.07179093360900879
TIME_train_uncompiled2 0.0031974315643310547
train uncompiled mse 0.0
TIME_train_initial_compile1 0.2178797721862793
TIME_train_initial_compile2 0.15280580520629883
TIME_train_compiled1 0.005212545394897461
TIME_train_compiled2 0.004248380661010742
train compiled mse 2.760749339358881e-06


pytorch-bot bot commented Nov 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2054

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ecf9748 with merge base a9aadf5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Nov 22, 2024

  assert_expected(output_fp16, expected_fp16, atol=1e-7, rtol=1e-3)
- assert output_fp16.dtype == torch.float32
+ assert output_fp16.dtype == torch.float16
Contributor

do you know why this wasn't failing before?

Contributor Author

yeah, the model wasn't cast to fp16, which means the scale parameter was still fp32. And since x_normed * self.scale occurred after the cast back to fp16, the output ended up in fp32.
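
(A quick illustration of the promotion that caused this; standard PyTorch type promotion, not code from the PR:)

import torch

x_normed = torch.randn(4, dtype=torch.float16)  # already cast back to fp16
scale = torch.ones(4)                           # un-cast parameter, still fp32
print((x_normed * scale).dtype)  # torch.float32 -- the fp32 scale silently upcasts the result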

Comment on lines +41 to +42
weight=self.scale,
eps=self.eps,
Contributor

@felipemello1 felipemello1 Nov 22, 2024

noob question: when we load the model in bf16, will self.eps and self.scale also become bf16, or do they stay float32?

If they are cast to bf16, it might be worth digging a bit.

Contributor Author

eps is just a Python float, not a PyTorch tensor, so it'll simply adapt to whichever dtype it's applied to.

scale will become bf16, but F.rms_norm will cast it to fp32 when it gets multiplied by the fp32 output.
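
(A hypothetical sanity check of that, assuming F.rms_norm accepts a weight whose dtype differs from the input's and relies on type promotion, as described above:)

import torch
import torch.nn.functional as F

x = torch.randn(2, 4096, dtype=torch.bfloat16)
scale = torch.ones(4096, dtype=torch.bfloat16)  # what self.scale becomes after a bf16 load
out = F.rms_norm(x.float(), (4096,), weight=scale, eps=1e-6)
print(out.dtype)  # expected torch.float32: the fp32 input promotes the bf16 weight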

Contributor

@felipemello1 felipemello1 left a comment

nice finding! Approving to unblock, just left a few comments to make sure the test and the bf16 behavior are ok.

also, fyi, i think using the torch.cuda events API would be the way to go here:


class Time:
    def __init__(self, name):
        self.name = name
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        self.start_event.record()

    def __exit__(self, *args, **kwargs):
        self.end_event.record()
        torch.cuda.synchronize()  # Wait for the events to be recorded
        elapsed_time = self.start_event.elapsed_time(
            self.end_event
        )  # Time in milliseconds
        print(f"TIME_{self.name}: {elapsed_time:.3f} ms")

using shape = [8, 512, 4096], I got:

TIME_inference_uncompiled1: 129.710 ms
TIME_inference_uncompiled2: 11.391 ms
inference uncompiled mse 0.0
TIME_inference_initial_compile1: 2834.852 ms
TIME_inference_initial_compile2: 403.543 ms
TIME_inference_compiled1: 1.052 ms
TIME_inference_compiled2: 0.869 ms
inference compiled mse 2.7247249363426818e-06
TIME_train_uncompiled1: 40.218 ms
TIME_train_uncompiled2: 10.865 ms
train uncompiled mse 0.0
TIME_train_initial_compile1: 928.825 ms
TIME_train_initial_compile2: 708.388 ms
TIME_train_compiled1: 6.275 ms
TIME_train_compiled2: 5.314 ms
train compiled mse 2.7247249363426818e-06

@calvinpelletier calvinpelletier merged commit 1450d61 into pytorch:main Nov 25, 2024
17 checks passed
@calvinpelletier calvinpelletier deleted the rms_norm branch November 25, 2024 18:26