@@ -3000,6 +3000,57 @@ def aten_embedding_dense_backward(
    raise NotImplementedError()


+@torch_op("aten::embedding_renorm", trace_only=True)
+def aten_embedding_renorm(
+    weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
+) -> TFloat:
+    """embedding_renorm(Tensor weight, Tensor indices, float max_norm, float norm_type) -> Tensor"""
+
+    unique_indices = op.Unique(indices)
+    unique_indices_Y = op.SequenceAt(unique_indices, 0)
+    # Use the private _onnx function because op.SequenceAt(unique_indices, 0) cannot pass the model checker.
+    # The error message is:
+    # onnx.onnx_cpp2py_export.shape_inference.InferenceError:
+    # [ShapeInferenceError] Shape inference error(s): (op_type:aten_embedding_renorm,
+    # node name: aten_embedding_renorm_0): [ShapeInferenceError] (op_type:SequenceAt,
+    # node name: n2): input_sequence typestr: S, has unsupported type: tensor(int64)
+    return aten_embedding_renorm_onnx(weight, unique_indices_Y, max_norm, norm_type)
+
+
+@torch_op("aten::embedding_renorm", private=True)
+def aten_embedding_renorm_onnx(
+    weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
+) -> TFloat:
+    partial_weight = op.Gather(weight, indices)
+    # partial_weight_norm = sum(|w|^p)^(1/p)
+    if norm_type == 1.0:
+        # Not strictly necessary, but op.ReduceL1 is faster than the op sequence in 'else'
+        partial_weight_norm = op.ReduceL1(partial_weight, axes=[1], keepdims=True)
+    elif norm_type == 2.0:
+        # Not strictly necessary, but op.ReduceL2 is faster than the op sequence in 'else'
+        partial_weight_norm = op.ReduceL2(partial_weight, axes=[1], keepdims=True)
+    else:
+        # Abs -> Pow -> ReduceSum -> Pow
+        partial_weight_abs = op.Abs(partial_weight)
+        partial_weight_pow = op.Pow(partial_weight_abs, op.Constant(value_float=norm_type))
+        partial_weight_norm = op.ReduceSum(partial_weight_pow, axes=[1], keepdims=True)
+        pow_value = op.CastLike(1.0 / norm_type, weight)
+        partial_weight_norm = op.Pow(partial_weight_norm, pow_value)
+
+    max_norm = op.CastLike(op.Constant(value_float=max_norm), weight)
+    # Add a small epsilon so the division below is safe when a row of weight is all zeros
+    err = op.CastLike(op.Constant(value_float=1e-7), weight)
+    partial_weight_norm_ = op.Add(partial_weight_norm, err)
+    scales = op.Div(max_norm, partial_weight_norm_)
+    partial_weight_renorm = op.Mul(partial_weight, scales)
+    # Use the renormed values where weight_norm > max_norm, but keep the original values where weight_norm <= max_norm
+    partial_weight_renorm = op.Where(
+        op.Greater(partial_weight_norm, max_norm), partial_weight_renorm, partial_weight
+    )
+    value = op.ScatterND(weight, op.Unsqueeze(indices, [1]), partial_weight_renorm)
+    return value
+
+
def aten_embedding_sparse_backward(
    grad: TensorType,
    indices: TensorType,
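For intuition, here is a minimal NumPy sketch of the renorm logic this decomposition implements: gather the unique indexed rows, compute their p-norms, rescale only the rows whose norm exceeds `max_norm`, and scatter the results back. The helper name `embedding_renorm_reference` and the sample values are illustrative assumptions, not part of this change.

```python
import numpy as np

def embedding_renorm_reference(weight, indices, max_norm, norm_type=2.0):
    # Hypothetical reference mirroring the ONNX decomposition above.
    weight = weight.copy()
    unique = np.unique(indices)              # mirrors op.Unique
    rows = weight[unique]                    # mirrors op.Gather
    # p-norm per row: sum(|w|^p)^(1/p), keeping the reduced axis for broadcasting
    norms = np.sum(np.abs(rows) ** norm_type, axis=1, keepdims=True) ** (1.0 / norm_type)
    scales = max_norm / (norms + 1e-7)       # epsilon mirrors `err`, guarding all-zero rows
    # Rescale only rows whose norm exceeds max_norm; mirrors op.Where + op.ScatterND
    weight[unique] = np.where(norms > max_norm, rows * scales, rows)
    return weight

# Rows 1 and 3 are touched; afterwards their L2 norms are at most max_norm.
rng = np.random.default_rng(0)
w = rng.normal(size=(5, 4)).astype(np.float32)
out = embedding_renorm_reference(w, np.array([1, 3, 3]), max_norm=1.0)
assert np.all(np.linalg.norm(out[[1, 3]], axis=1) <= 1.0 + 1e-5)
```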