diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 9e604ae42aa..dec6feb1b8d 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -139,6 +139,7 @@
     "int in_zero_point, bool channel_last=False) -> (Tensor out)"
 )
 lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)")
+lib.define("rms_norm(Tensor X, float eps, Tensor W) -> (Tensor Y)")
 lib.define(
     "transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
     "int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"
@@ -210,6 +211,9 @@
     "fully_connected.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.define("linalg_vector_norm.out(Tensor X, *, Tensor(a!) out) -> Tensor(a!)")
+lib.define(
+    "rms_norm.out(Tensor X, float eps, Tensor W, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
@@ -615,6 +619,15 @@ def linalg_vector_norm_meta(
     return X.new_empty([], dtype=X.dtype)
 
 
+@register_fake("cadence::rms_norm")
+def rms_norm_meta(
+    X: torch.Tensor,
+    eps: float,
+    weight: torch.Tensor,
+) -> torch.Tensor:
+    return X.new_empty(X.shape, dtype=X.dtype)
+
+
 @register_fake("cadence::requantize")
 def requantize_meta(
     input: torch.Tensor,
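
Note for reviewers: the patch only registers the `cadence::rms_norm` schema and a `register_fake` meta kernel, which fixes the output's shape and dtype (it returns an empty tensor matching `X`) but says nothing about numerics; the real kernel lives in the backend. Below is a minimal eager-mode sketch of what the op is presumably meant to compute, assuming the conventional RMSNorm definition over the last dimension with an elementwise scale `W`. The helper name `rms_norm_reference` is hypothetical and not part of this patch.

```python
import torch

def rms_norm_reference(X: torch.Tensor, eps: float, W: torch.Tensor) -> torch.Tensor:
    # Assumed semantics (not specified by the schema): normalize over the last
    # dim by the root-mean-square, with eps added inside the sqrt, then scale
    # elementwise by W (W must broadcast against X's last dim).
    rms = torch.sqrt(X.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (X / rms) * W

# Mirrors the only invariant the fake kernel actually promises to export
# tracing: output shape and dtype equal the input's.
X = torch.randn(2, 8)
W = torch.ones(8)
Y = rms_norm_reference(X, 1e-6, W)
assert Y.shape == X.shape and Y.dtype == X.dtype
```

If the backend kernel uses a different eps placement or normalization axis, the meta kernel above is unaffected, since only shape and dtype propagation matter at trace time.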