Skip to content

Commit 605a538

Browse files
ThomasJannaudfacebook-github-bot
authored andcommitted
RMSNorm support - Executorch (#9844)
Summary: This follows D72014553 which adds support for RMSNorm (cpu backend) This is a separate diff for Executorch / Github Reviewed By: Vysarat Differential Revision: D72258890
1 parent 0844c38 commit 605a538

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@
139139
"int in_zero_point, bool channel_last=False) -> (Tensor out)"
140140
)
141141
lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)")
142+
lib.define(
143+
"rms_norm(Tensor X, float eps, Tensor W) -> (Tensor Y)"
144+
)
142145
lib.define(
143146
"transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
144147
"int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"
@@ -210,6 +213,7 @@
210213
"fully_connected.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)"
211214
)
212215
lib.define("linalg_vector_norm.out(Tensor X, *, Tensor(a!) out) -> Tensor(a!)")
216+
lib.define("rms_norm.out(Tensor X, float eps, Tensor W, *, Tensor(a!) out) -> Tensor(a!)")
213217
lib.define(
214218
"quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
215219
"Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
@@ -615,6 +619,15 @@ def linalg_vector_norm_meta(
615619
return X.new_empty([], dtype=X.dtype)
616620

617621

622+
@register_fake("cadence::rms_norm")
623+
def rms_norm_meta(
624+
X: torch.Tensor,
625+
eps: float,
626+
weight: torch.Tensor,
627+
) -> torch.Tensor:
628+
return X.new_empty(X.shape, dtype=X.dtype)
629+
630+
618631
@register_fake("cadence::requantize")
619632
def requantize_meta(
620633
input: torch.Tensor,

0 commit comments

Comments
 (0)