rmsnorm not trainable #1

@chenlinear

Hi, it seems that from the beginning of this project the RMSNorm has had no trainable weight, which is unusual compared to other implementations, I think. How does this affect the training speed and the final loss?

What I mean by trainable is:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, n_embd, eps=1e-6):
        super().__init__()
        self.eps = eps
        # trainable gain; nn.Parameter has requires_grad=True by default
        self.weight = nn.Parameter(torch.ones(n_embd))

    def forward(self, x):
        # normalize by the RMS over the last dim, then scale by the learned gain
        x = self.weight * x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x

x = torch.rand(1, 196, 768)  # batch size, sequence length, embedding dimensionality (n_embd)
rmsnorm = RMSNorm(n_embd=768)
y = rmsnorm(x)
assert x.shape == y.shape
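
For comparison, here is a minimal sketch of the non-trainable variant the issue refers to, assuming the project's norm simply rescales by the RMS with no learnable gain (the function name is just for illustration):

import torch

def rmsnorm_no_weight(x, eps=1e-6):
    # plain RMS normalization: no learnable gain, so there is nothing to train
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.rand(1, 196, 768)  # batch size, sequence length, n_embd
y = rmsnorm_no_weight(x)
assert x.shape == y.shape

With the module version above, rmsnorm.weight shows up in model.parameters() and gets updated by the optimizer; the weightless version adds no parameters at all.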
