Skip to content

Commit 0035390

Browse files
authored
Add embedding_renorm code | feat(torchlib) (#1098)
1 parent 2a610c4 commit 0035390

File tree

3 files changed

+95
-0
lines changed

3 files changed

+95
-0
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3000,6 +3000,57 @@ def aten_embedding_dense_backward(
30003000
raise NotImplementedError()
30013001

30023002

3003+
@torch_op("aten::embedding_renorm", trace_only=True)
def aten_embedding_renorm(
    weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
) -> TFloat:
    """embedding_renorm(Tensor weight, Tensor indices, float max_norm, float norm_type) -> Tensor"""

    # The renormalization itself is delegated to the private
    # aten_embedding_renorm_onnx function because op.SequenceAt(unique_indices, 0)
    # cannot pass the module checker.
    # The error message is:
    # onnx.onnx_cpp2py_export.shape_inference.InferenceError:
    # [ShapeInferenceError] Shape inference error(s): (op_type:aten_embedding_renorm,
    # node name: aten_embedding_renorm_0): [ShapeInferenceError] (op_type:SequenceAt,
    # node name: n2): input_sequence typestr: S, has unsupported type: tensor(int64)
    unique_outputs = op.Unique(indices)
    # Element 0 of the Unique result is what gets passed on as the indices to renormalize.
    deduplicated_indices = op.SequenceAt(unique_outputs, 0)
    return aten_embedding_renorm_onnx(weight, deduplicated_indices, max_norm, norm_type)
3018+
3019+
3020+
# Private implementation of embedding_renorm: gathers the rows of `weight` selected
# by `indices`, rescales every gathered row whose p-norm (p = norm_type) is greater
# than `max_norm`, and scatters the rows back into `weight`.
@torch_op("aten::embedding_renorm", private=True)
def aten_embedding_renorm_onnx(
    weight: TFloat, indices: INT64, max_norm: float, norm_type: float = 2.0
) -> TFloat:
    selected_rows = op.Gather(weight, indices)
    # row_norms = sum(|w|^p)^(1/p), reduced over axis 1 with the axis kept
    if norm_type == 1.0:
        # Shortcut only: the generic branch below would also work, but
        # op.ReduceL1 is faster than the op chain in 'else'
        row_norms = op.ReduceL1(selected_rows, axes=[1], keepdims=True)
    elif norm_type == 2.0:
        # Shortcut only: the generic branch below would also work, but
        # op.ReduceL2 is faster than the op chain in 'else'
        row_norms = op.ReduceL2(selected_rows, axes=[1], keepdims=True)
    else:
        # Generic p-norm: Abs -> Pow -> ReduceSum -> Pow
        abs_rows = op.Abs(selected_rows)
        powered_rows = op.Pow(abs_rows, op.Constant(value_float=norm_type))
        summed_powers = op.ReduceSum(powered_rows, axes=[1], keepdims=True)
        inverse_exponent = op.CastLike(1.0 / norm_type, weight)
        row_norms = op.Pow(summed_powers, inverse_exponent)

    # Bring the scalar attributes into the dtype of `weight`
    max_norm_tensor = op.CastLike(op.Constant(value_float=max_norm), weight)
    # Small offset added to the norm before dividing, to guard against zero weights
    epsilon = op.CastLike(op.Constant(value_float=1e-7), weight)
    safe_row_norms = op.Add(row_norms, epsilon)
    scale_factors = op.Div(max_norm_tensor, safe_row_norms)
    rescaled_rows = op.Mul(selected_rows, scale_factors)
    # Take the rescaled row where norm > max_norm; keep the original row where
    # norm <= max_norm. The comparison uses the norm without the offset.
    exceeds_limit = op.Greater(row_norms, max_norm_tensor)
    updated_rows = op.Where(exceeds_limit, rescaled_rows, selected_rows)
    # ScatterND is given `indices` with a trailing axis of size 1 added
    scatter_indices = op.Unsqueeze(indices, [1])
    result = op.ScatterND(weight, scatter_indices, updated_rows)
    return result
3052+
3053+
30033054
def aten_embedding_sparse_backward(
30043055
grad: TensorType,
30053056
indices: TensorType,

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,36 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra
818818
yield opinfo_core.SampleInput(t, kwargs={"p": p})
819819

820820

821+
def sample_inputs_embedding_renorm(op_info, device, dtype, requires_grad, **kwargs):
    """Yield one embedding_renorm SampleInput per (max_norm, norm_type) combination.

    Each sample carries an (S, S) weight tensor of ``dtype`` multiplied by 2, a
    6-element ``torch.long`` index tensor created with ``low=0, high=S``, and the
    ``max_norm`` / ``norm_type`` keyword arguments.
    """
    del op_info
    del kwargs

    make_tensor = common_methods_invocations.make_tensor
    # Same iteration order as nesting norm_type inside max_norm.
    norm_settings = [
        (max_norm, norm_type)
        for max_norm in (0.5, 1.0, 5.0)
        for norm_type in (0.8, 1.0, 2.0, 2.5)
    ]

    for max_norm, norm_type in norm_settings:
        # Indices are created before the weights for every sample.
        idx = make_tensor(
            (6,),
            device=device,
            dtype=torch.long,
            low=0,
            high=S,
            noncontiguous=False,
        )
        weights = (
            make_tensor((S, S), device=device, dtype=dtype, requires_grad=requires_grad)
            * 2
        )
        yield common_methods_invocations.SampleInput(
            weights,
            args=(idx,),
            kwargs={"max_norm": max_norm, "norm_type": norm_type},
        )
849+
850+
821851
def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs):
822852
del op_info
823853
del kwargs
@@ -1240,6 +1270,13 @@ def sample_inputs_scaled_dot_product_flash_attention(
12401270
sample_inputs_func=sample_inputs_embedding_bag_padding_idx,
12411271
supports_out=False,
12421272
),
1273+
opinfo_core.OpInfo(
1274+
"ops.aten.embedding_renorm",
1275+
aten_name="embedding_renorm",
1276+
dtypes=common_dtype.floating_types_and_half(),
1277+
sample_inputs_func=sample_inputs_embedding_renorm,
1278+
supports_out=False,
1279+
),
12431280
opinfo_core.OpInfo(
12441281
"nn.functional.conv3d",
12451282
aten_name="conv3d",

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,6 +1044,13 @@ def _where_input_wrangler(
10441044
tolerance={torch.float16: (1e-2, 1e-2)},
10451045
compare_shape_only_for_output=(1, 2, 3),
10461046
),
1047+
TorchLibOpInfo(
1048+
"ops.aten.embedding_renorm",
1049+
core_ops.aten_embedding_renorm,
1050+
tolerance={torch.float16: (1e-2, 1e-2)},
1051+
trace_only=True,
1052+
compare_shape_only_for_output=(1, 2, 3),
1053+
),
10471054
TorchLibOpInfo(
10481055
"nn.functional.embedding",
10491056
core_ops.aten_embedding,

0 commit comments

Comments
 (0)