Skip to content

Commit 2f39b94

Browse files
AddOp(embedding bag) | feat(torchlib) (#909)
- This PR only covers the aten_embedding_bag function. - The function has 4 outputs; we only care about the first one. For the other 3, we just produce tensors of the correct shape filled with zeros. - aten_embedding_bag_padding_idx will be handled in a separate PR. - max_norm: I think this is a rare case, but I am not sure. If given, each embedding vector with a norm larger than max_norm is renormalized to have norm max_norm. Note: this modifies weight in-place. It is not clear whether we need to implement the ```embedding_renorm``` function. --------- Co-authored-by: Justin Chu <[email protected]>
1 parent c8959ff commit 2f39b94

File tree

3 files changed

+263
-9
lines changed

3 files changed

+263
-9
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 134 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2266,19 +2266,144 @@ def aten_embedding_backward(
22662266
raise NotImplementedError()
22672267

22682268

2269+
@torch_op(
    (
        "aten::embedding_bag",
        "aten::_embedding_bag",
        "aten::_embedding_bag_forward_only",
    ),
    trace_only=True,
)
def aten_embedding_bag(
    weight: TFloat,
    indices: INT64,
    # NOTE(review): the doc allows offsets=None (2-D indices), but nothing below
    # special-cases None before offsets is handed to the helpers - confirm.
    offsets: INT64 = None,
    scale_grad_by_freq: bool = False,  # pylint: disable=unused-argument
    # 0 = "sum", 1 = "mean", 2 = "max".
    # NOTE(review): the schema in the docstring says mode=0, this default is 1 - confirm.
    mode: int = 1,
    sparse: bool = False,  # pylint: disable=unused-argument
    per_sample_weights: Optional[TFloat] = None,
    include_last_offset: bool = False,
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
    """embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)"""

    # Returns (result, offset2bag, bag_size, max_indices). Only `result` carries
    # real values; the other three come from _compute_output_others_shape, which
    # only fills zeros.
    #
    # Preconditions that are expected but not checked here:
    #   rank(indices) in [1, 2]
    #   rank(offsets) == 1
    #   op.Size(per_sample_weights) == op.Size(indices)
    if per_sample_weights is None:
        # None cannot be tested inside the scripted ONNX function, so substitute
        # a 1-D tensor of ones with one entry per (flattened) index, cast to the
        # dtype of weight.
        flat_indices = op.Reshape(indices, [-1])
        unit_weights = op.Expand(1, op.Shape(flat_indices))
        per_sample_weights = op.CastLike(unit_weights, weight)

    bagged = _aten_embedding_bag_onnx(
        weight, indices, offsets, mode, per_sample_weights, include_last_offset
    )
    offset2bag, bag_size, max_indices = _compute_output_others_shape(
        weight, indices, offsets, mode, include_last_offset
    )
    return bagged, offset2bag, bag_size, max_indices
2307+
2308+
2309+
# Plain Python helper (runs while tracing): it only produces tensors of the
# right shape for the three secondary outputs; their values are all 0.
def _compute_output_others_shape(weight, indices, offsets, mode, include_last_off):
    """Build zero-filled stand-ins for offset2bag, bag_size and max_indices.

    Args:
        weight: Embedding matrix; only its dim 1 is read (max mode).
        indices: Index tensor; only its dim 0 is read.
        offsets: Bag offsets; only its shape is read.
        mode: 0 for sum, 1 for mean, anything else is treated as max.
        include_last_off: When it is exactly True, mean/max modes size
            bag_size as Shape(offsets) - 1.

    Returns:
        The tuple (offset2bag, bag_size, max_indices).
    """

    def zeros(shape):
        # Broadcast the scalar 0 to the requested shape.
        return op.Expand(0, shape)

    if mode == 0:  # sum
        # offset2bag is an empty tensor here: an empty slice of Shape(indices).
        empty = op.Shape(indices, start=0, end=0)
        return empty, zeros(op.Shape(offsets)), zeros(op.Shape(offsets))

    # mean and max share offset2bag and bag_size.
    offset2bag = zeros(op.Shape(indices, start=0, end=1))
    if include_last_off is True:
        bag_size = zeros(op.Shape(offsets) - 1)
    else:
        bag_size = zeros(op.Shape(offsets))

    if mode == 1:  # mean
        max_indices = zeros(op.Shape(bag_size))
    else:  # max
        # shape = (bag_size.dim[0], weight.dim[1])
        rows = op.Shape(bag_size, start=0, end=1)
        cols = op.Shape(weight, start=1, end=2)
        max_indices = zeros(op.Concat(rows, cols, axis=0))

    return offset2bag, bag_size, max_indices
2334+
2335+
2336+
# Scripted helper behind aten_embedding_bag. It gathers the embedding rows
# selected by the flattened indices, scales them by per_sample_weights, then
# walks the bags one by one and reduces each bag to a single row (sum, mean or
# max depending on `mode`). The rows are collected in an ONNX sequence and
# concatenated along axis 0 at the end.
@torch_op("aten::embedding_bag", private=True)
def _aten_embedding_bag_onnx(
    weight: TFloat,
    indices: INT64,
    offsets: INT64,
    mode: int,
    per_sample_weights: TFloat,
    include_last_offset: bool,
) -> TFloat:
    minus_one = op.Constant(value_ints=[-1])
    # Flatten indices, e.g. shape (5, 2) becomes shape (10,)
    flat_indices = op.Reshape(indices, minus_one)
    # Pick one row of weight per flattened index
    gathered = op.Gather(weight, flat_indices)
    # Scaling happens after the Gather: per_sample_weights has one entry per
    # index, so it lines up with the gathered rows once unsqueezed on axis 1
    gathered = op.Mul(gathered, op.Unsqueeze(per_sample_weights, axes=1))
    embed_dims = op.Reshape(op.Shape(weight, start=1), minus_one)
    num_indices = op.Shape(flat_indices)

    # Example: indices of shape (5, 2), offsets=[0,2,3], include_last_offset=False
    # gives the bags [0:2], [2:3], [3:10]
    bag_count = op.Reshape(op.Size(offsets), minus_one)  # one bag per offset
    if op.Equal(include_last_offset, True):
        # The last offset only marks the end of the previous bag
        bag_count = bag_count - 1
    else:
        # Append the number of flattened indices so the last bag has an end
        offsets = op.Concat(offsets, num_indices, axis=0)

    # The element in sequence must be FLOAT32 dtype due to ORT bug
    gathered = op.Cast(gathered, to=FLOAT.dtype)
    # FIXME: https://github.com/microsoft/onnxruntime/issues/16846
    bag_rows = op.SequenceEmpty()

    bag_idx = op.Reshape(op.Constant(value_int=0), minus_one)  # Loop counter
    cond = bag_idx < bag_count
    # Reduce one bag per iteration
    while cond:
        lo = op.Slice(offsets, bag_idx, bag_idx + 1)
        hi = op.Slice(offsets, bag_idx + 1, bag_idx + 2)
        # An empty bag yields a (1, N) row of zeros
        if lo == hi:
            pooled = op.Expand(
                op.Constant(value_floats=[0.0]),
                op.Concat(op.Constant(value_ints=[1]), embed_dims, axis=0),
            )
        else:
            if mode == 0:  # sum
                members = op.Slice(gathered, lo, hi)
                pooled = op.ReduceSum(members, axes=[0])
            elif mode == 1:  # mean
                members = op.Slice(gathered, lo, hi)
                if op.Equal(bag_idx, bag_count - 1):  # The last bag
                    pooled = op.ReduceSum(members, axes=[0])
                    # The last bag divides by (dim 0 of indices - start):
                    # When include_last_offset=False, offsets=[0,2,3], count=5-3=2
                    # When include_last_offset=True, offsets=[0,2,3], count=5-2=3
                    # NOTE(review): this reads dim 0 of `indices`, not of the
                    # flattened indices; the two differ for 2-D indices - confirm.
                    count = op.Sub(op.Shape(indices, start=0, end=1), lo)
                    if op.Greater(count, 0):
                        pooled = op.Div(pooled, op.CastLike(count, gathered))
                else:
                    pooled = op.ReduceMean(members, axes=[0])
            else:  # max
                if op.Equal(bag_idx, bag_count - 1):  # The last bag
                    # The last bag always extends to the end of the gathered rows
                    members = op.Slice(gathered, lo, num_indices)
                else:
                    members = op.Slice(gathered, lo, hi)
                pooled = op.ReduceMax(members, axes=[0])

        bag_rows = op.SequenceInsert(bag_rows, pooled)
        bag_idx = bag_idx + 1
        cond = bag_idx < bag_count
    bag_rows = op.ConcatFromSequence(bag_rows, axis=0)
    # Undo the FLOAT32 cast applied for the sequence workaround
    return op.CastLike(bag_rows, weight)
22822407

22832408

22842409
def aten_embedding_dense_backward(

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torch.testing._internal.opinfo import core as opinfo_core
1414

1515
S = 5
16+
M = 10
1617

1718

1819
def sample_inputs__local_scalar_dense(op_info, device, dtype, requires_grad, **kwargs):
@@ -622,6 +623,117 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra
622623
yield opinfo_core.SampleInput(t, kwargs={"p": p})
623624

624625

626+
def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInputs for ops.aten.embedding_bag.

    Every sample uses a fresh (M, S) weight tensor as input, the index tensor
    as the single positional arg, and passes offsets, mode, per_sample_weights
    and include_last_offset as keyword arguments.

    The sweep covers three offsets tensors, both values of
    include_last_offset, with and without per_sample_weights, the three modes
    (0 "sum", 1 "mean", 2 "max"), and 1-D / 2-D index tensors in contiguous and
    non-contiguous layouts.
    """
    del op_info
    del kwargs

    def float_tensor(shape):
        return common_methods_invocations.make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    def index_tensor(shape, *, low, high, noncontiguous=False):
        return common_methods_invocations.make_tensor(
            shape,
            device=device,
            dtype=torch.long,
            low=low,
            high=high,
            noncontiguous=noncontiguous,
        )

    def sample_weights_for(flag, idx):
        # A float tensor with one weight per index, or None to indicate that
        # all weights should be taken to be 1.
        return float_tensor(idx.reshape(-1).shape) if flag else None

    offsets = [
        torch.tensor(values, device=device, dtype=torch.long)
        for values in ([0, 2, 3], [0, 0, 2], [0, 2, 2, 4])
    ]
    for offset in offsets:
        for include_last_offset in (True, False):
            for generate_per_sample_weight in (True, False):
                # per_sample_weights only support mode='sum' (0), so the
                # 'mean' (1) and 'max' (2) modes are skipped when they are used.
                modes = (0,) if generate_per_sample_weight else (0, 1, 2)
                for mode in modes:
                    # (index shape, noncontiguous) in the order samples are yielded:
                    # 1-D first, then 2-D.
                    index_specs = [((S,), False), ((S,), True)]
                    if mode != 2:  # "max" mode in 2-D index tensor make aten func crash
                        index_specs.append(((S, S), False))
                        index_specs.append(((S, S), True))

                    for shape, noncontiguous in index_specs:
                        # Creation order matters for reproducibility of the
                        # random tensors: indices, then weights, then input.
                        indices = index_tensor(
                            shape, low=0, high=M, noncontiguous=noncontiguous
                        )
                        per_sample_weights = sample_weights_for(
                            generate_per_sample_weight, indices
                        )
                        yield common_methods_invocations.SampleInput(
                            float_tensor((M, S)),
                            args=(indices,),
                            kwargs={
                                "offsets": offset,
                                "mode": mode,
                                "per_sample_weights": per_sample_weights,
                                "include_last_offset": include_last_offset,
                            },
                        )
735+
736+
625737
# NOTE: How to create an OpInfo:
626738
# 1. Create a function that generates sample inputs for the op.
627739
# This function should yield SampleInputs.
@@ -651,6 +763,13 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra
651763
sample_inputs_func=sample_inputs_col2im,
652764
supports_out=False,
653765
),
766+
opinfo_core.OpInfo(
767+
"ops.aten.embedding_bag",
768+
aten_name="embedding_bag",
769+
dtypes=common_dtype.floating_types_and_half(),
770+
sample_inputs_func=sample_inputs_embedding_bag,
771+
supports_out=False,
772+
),
654773
opinfo_core.OpInfo(
655774
"nn.functional.conv3d",
656775
aten_name="conv3d",

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,13 @@ def _where_input_wrangler(
969969
dtypes=(torch.float16,),
970970
reason="fixme: ONNX Runtime aborted",
971971
),
972+
TorchLibOpInfo(
973+
"ops.aten.embedding_bag",
974+
core_ops.aten_embedding_bag,
975+
trace_only=True,
976+
# Output[0] is OK, but other 3 outputs just have the same shape with zero values
977+
nondeterministic=True,
978+
),
972979
TorchLibOpInfo(
973980
"nn.functional.embedding",
974981
core_ops.aten_embedding,
@@ -1872,6 +1879,9 @@ def _where_input_wrangler(
18721879
ops_test_common.duplicate_opinfo(OPS_DB, "new_full", ("new_full_dtype",))
18731880
ops_test_common.duplicate_opinfo(OPS_DB, "new_ones", ("new_ones_dtype",))
18741881
ops_test_common.duplicate_opinfo(OPS_DB, "new_zeros", ("new_zeros_dtype",))
1882+
# ops_test_common.duplicate_opinfo(
1883+
# OPS_DB, "nn.functional.embedding_bag", ("nn.functional.embedding_bag.padding_idx",)
1884+
# )
18751885
ops_test_common.duplicate_opinfo(
18761886
OPS_DB, "nn.functional.linear", ("nn.functional.linear_bias",)
18771887
)

0 commit comments

Comments
 (0)