Hi, it seems that from the beginning of this project RMSNorm has not been trainable, which is unusual compared to other implementations, I guess. How does this affect the training speed and the final loss?
What I mean by trainable is:
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, n_embd, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(n_embd))  # trainable, requires_grad=True by default

    def forward(self, x):
        x = self.weight * x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x

x = torch.rand(1, 196, 768)  # batch size, sequence length, embedding dimensionality (n_embd)
rmsnorm = RMSNorm(n_embd=768)
y = rmsnorm(x)
assert x.shape == y.shape
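For contrast, here is a minimal sketch of what I understand the non-trainable version to look like: the same normalization, but with no learnable gain at all (this is my assumption about how the project implements it, not a quote from its code):

```python
import torch

def rmsnorm(x, eps=1e-6):
    # Non-trainable variant: pure normalization with no learnable scale.
    # Each token vector is rescaled so its root-mean-square is ~1.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.rand(1, 196, 768)  # batch size, sequence length, n_embd
y = rmsnorm(x)
assert x.shape == y.shape
```

Without the `weight` parameter, the model loses the per-channel scale that lets each feature dimension recover a different magnitude after normalization, which is what I would expect to matter for the final loss.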