139 | 139 |     "int in_zero_point, bool channel_last=False) -> (Tensor out)"
140 | 140 | )
141 | 141 | lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)")
| 142 | +lib.define(
| 143 | +    "rms_norm(Tensor X, float eps, Tensor W) -> (Tensor Y)"
| 144 | +)
142 | 145 | lib.define(
143 | 146 |     "transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
144 | 147 |     "int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"
210 | 213 | "fully_connected.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)"
|
211 | 214 | )
|
212 | 215 | lib.define("linalg_vector_norm.out(Tensor X, *, Tensor(a!) out) -> Tensor(a!)")
|
| 216 | +lib.define("rms_norm.out(Tensor X, float eps, Tensor W, *, Tensor(a!) out) -> Tensor(a!)") |
213 | 217 | lib.define(
|
214 | 218 | "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
|
215 | 219 | "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
@@ -615,6 +619,15 @@ def linalg_vector_norm_meta(
615 | 619 |     return X.new_empty([], dtype=X.dtype)
616 | 620 |
617 | 621 |
| 622 | +@register_fake("cadence::rms_norm")
| 623 | +def rms_norm_meta(
| 624 | +    X: torch.Tensor,
| 625 | +    eps: float,
| 626 | +    weight: torch.Tensor,
| 627 | +) -> torch.Tensor:
| 628 | +    return X.new_empty(X.shape, dtype=X.dtype)
| 629 | +
| 630 | +
618 | 631 | @register_fake("cadence::requantize")
619 | 632 | def requantize_meta(
620 | 633 |     input: torch.Tensor,
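Once this module is imported, the fake kernel registered above can be exercised without a device kernel. A quick smoke-test sketch, assuming the ops land under the `cadence` namespace (as the `cadence::rms_norm` registration suggests) and a PyTorch build with `FakeTensorMode`:

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        X = torch.randn(4, 16)  # fake tensors carry only shape/dtype, no data
        W = torch.ones(16)
        Y = torch.ops.cadence.rms_norm(X, 1e-6, W)
        # rms_norm_meta promises an output matching X's shape and dtype.
        assert Y.shape == X.shape and Y.dtype == X.dtype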