Skip to content

Add Op(embedding_bag_padding_idx) | torchlib(feat) #1022

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Aug 29, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class TestDeduceTypeConstraints(unittest.TestCase):
"_aten_as_strided_onnx",
"_aten_unfold_onnx",
"_aten_embedding_bag_onnx",
"_aten_embedding_bag_1d_padding_idx_onnx",
)
_SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = ()

Expand Down
193 changes: 156 additions & 37 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2364,42 +2364,12 @@ def aten_embedding_bag(
# Dtype of per_sample_weights is the same as weight
per_sample_weights = op.CastLike(per_sample_weights, weight)

result = _aten_embedding_bag_onnx(
result, offset2bag, bag_size, max_indices = _aten_embedding_bag_onnx(
weight, indices, offsets, mode, per_sample_weights, include_last_offset
)
offset2bag, bag_size, max_indices = _compute_output_others_shape(
weight, indices, offsets, mode, include_last_offset
)
return result, offset2bag, bag_size, max_indices


# This Python function only computes the shapes of the other outputs, not their values; they are filled with 0
def _compute_output_others_shape(weight, indices, offsets, mode, include_last_off):
    """Build zero-filled placeholders for the three auxiliary embedding_bag outputs.

    Only the shapes are derived from the inputs; every produced element is 0.

    Args:
        weight: Tensor whose second dimension is read in the "max" mode.
        indices: Index tensor; only its shape (first dimension) is read.
        offsets: Offsets tensor; only its shape is read.
        mode: 0 selects "sum", 1 selects "mean", any other value is treated as "max".
        include_last_off: When it is exactly ``True``, ``bag_size`` in the "mean"
            and "max" modes gets one element fewer than ``offsets``. The "sum"
            mode ignores this flag.

    Returns:
        The tuple ``(offset2bag, bag_size, max_indices)``.
    """
    if mode == 0:  # sum
        # Slicing the shape with start == end yields an empty tensor.
        empty_offset2bag = op.Shape(indices, start=0, end=0)
        sum_bag_size = op.Expand(0, op.Shape(offsets))
        sum_max_indices = op.Expand(0, op.Shape(offsets))
        return empty_offset2bag, sum_bag_size, sum_max_indices

    # "mean" and "max" share offset2bag and bag_size; only max_indices differs.
    offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
    bag_shape = (op.Shape(offsets) - 1) if include_last_off is True else op.Shape(offsets)
    bag_size = op.Expand(0, bag_shape)

    if mode == 1:  # mean
        max_indices = op.Expand(0, op.Shape(bag_size))
    else:  # max
        # shape = (bag_size.dim[0], weight.dim[1])
        rows = op.Shape(bag_size, start=0, end=1)
        cols = op.Shape(weight, start=1, end=2)
        max_indices = op.Expand(0, op.Concat(rows, cols, axis=0))

    return offset2bag, bag_size, max_indices


@torch_op("aten::embedding_bag", private=True)
def _aten_embedding_bag_onnx(
weight: TFloat,
Expand All @@ -2408,7 +2378,7 @@ def _aten_embedding_bag_onnx(
mode: int,
per_sample_weights: TFloat,
include_last_offset: bool,
) -> TFloat:
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
neg_1 = op.Constant(value_ints=[-1])
# Assume indices is shape(5,2), indices_1d is shape(10,)
indices_1d = op.Reshape(indices, neg_1)
Expand All @@ -2419,9 +2389,9 @@ def _aten_embedding_bag_onnx(
weight_dim_1 = op.Reshape(op.Shape(weight, start=1), neg_1)
indices_size = op.Shape(indices_1d)

# Assume indices is shape(5,2), offsets=[0,2,3], include_last_offset = False
# Assume indices is shape(5,2) reshape to (10,), offsets=[0,2,3], include_last_offset = False
# [0,2,3] -> [0:2], [2:3], [3:10]
num_bag = op.Reshape(op.Size(offsets), neg_1) # 3 bags, means 15 is the last index
num_bag = op.Reshape(op.Size(offsets), neg_1) # 3 bags, means 10 is the last index
if op.Equal(include_last_offset, True):
num_bag = num_bag - 1 # 2 bags, means 3 is the last index
else:
Expand Down Expand Up @@ -2452,8 +2422,8 @@ def _aten_embedding_bag_onnx(
weight_rows = op.Slice(new_weight, start, end)
if op.Equal(index_tensor, num_bag - 1): # The last bag
row_result = op.ReduceSum(weight_rows, axes=[0])
# When include_last_offset=False, offsets=[0,2,3], denominator=5-3=2
# When include_last_offset=True, offsets=[0,2,3], denominator=5-2=3
# When include_last_offset=False, offsets=[0,2,3] -> [0,2,3,10], denominator=10-3=7
# When include_last_offset=True, offsets=[0,2,3], denominator=10-2=8
denominator = op.Sub(op.Shape(indices, start=0, end=1), start)
if op.Greater(denominator, 0):
row_result = op.Div(row_result, op.CastLike(denominator, new_weight))
Expand All @@ -2469,8 +2439,157 @@ def _aten_embedding_bag_onnx(
result = op.SequenceInsert(result, row_result)
index_tensor = index_tensor + 1
cond = index_tensor < num_bag

result = op.ConcatFromSequence(result, axis=0)
return op.CastLike(result, weight)
result = op.CastLike(result, weight)

# Only compute the shapes of the other 3 outputs; their values do not matter
if mode == 0: # sum
offset2bag = op.Shape(indices, start=0, end=0) # Generate empty tensor
if op.Equal(include_last_offset, True):
bag_size = op.Expand(0, op.Shape(offsets))
else:
bag_size = op.Expand(0, op.Shape(offsets) - 1)
max_indices = op.Expand(0, op.Shape(bag_size))
elif mode == 1: # mean
offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
bag_size = op.Expand(0, op.Shape(offsets) - 1)
max_indices = op.Expand(0, op.Shape(bag_size))
else: # max
offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
bag_size = op.Expand(0, op.Shape(offsets) - 1)
# shape = (bag_size.dim[0], weight.dim[1])
dim_0 = op.Shape(bag_size, start=0, end=1)
dim_1 = op.Shape(weight, start=1, end=2)
max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0))

return result, offset2bag, bag_size, max_indices


@torch_op("aten::embedding_bag.padding_idx", trace_only=True)
def aten_embedding_bag_padding_idx(
    weight: TFloat,
    indices: INT64,
    offsets: INT64 = None,  # Could be None according to the doc, go 2d branch
    scale_grad_by_freq: bool = False,  # pylint: disable=unused-argument
    mode: int = 1,  # [0,1,2] indicate ["sum", "mean", "max"], default is "mean"
    sparse: bool = False,  # pylint: disable=unused-argument
    per_sample_weights: Optional[TFloat] = None,
    include_last_offset: bool = False,
    padding_idx: Optional[int] = None,
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
    """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)

    Args:
        weight: Embedding table.
        indices: Indices into ``weight``.
        offsets: Start offset of each bag inside ``indices``.
        scale_grad_by_freq: Unused.
        mode: 0, 1 or 2 for "sum", "mean" or "max".
        sparse: Unused.
        per_sample_weights: Optional weight per index; when omitted a tensor of
            ones shaped like ``indices`` (cast like ``weight``) is used.
        include_last_offset: Forwarded unchanged to the ONNX implementation.
        padding_idx: Index value to leave out of every bag. A negative value
            counts from the end of ``weight``. ``None`` means no index is
            excluded.

    Returns:
        The tuple ``(result, offset2bag, bag_size, max_indices)``.
    """
    if per_sample_weights is None:
        per_sample_weights = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(indices))
        per_sample_weights = op.CastLike(per_sample_weights, weight)

    if padding_idx is None:
        # ``padding_idx`` is optional in the schema. Comparing ``None < 0`` would
        # raise a TypeError, so with nothing to exclude fall back to the regular
        # embedding_bag implementation.
        return _aten_embedding_bag_onnx(
            weight, indices, offsets, mode, per_sample_weights, include_last_offset
        )

    # Change padding_idx to positive value, -1 means the last index
    if padding_idx < 0:
        padding_idx = weight.shape[0] + padding_idx

    return _aten_embedding_bag_1d_padding_idx_onnx(
        weight, indices, offsets, mode, per_sample_weights, include_last_offset, padding_idx
    )


@torch_op("aten::embedding_bag.padding_idx", private=True)
def _aten_embedding_bag_1d_padding_idx_onnx(
    weight: TFloat,
    indices: INT64,
    offsets: INT64,
    mode: int,
    per_sample_weights: TFloat,
    include_last_offset: bool,
    padding_idx: int,
) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
    """Reduce rows of ``weight`` per bag, skipping positions whose index equals ``padding_idx``.

    Rows are gathered by ``indices`` and scaled by ``per_sample_weights``. Each
    bag covers the positions between two consecutive entries of the adjusted
    ``offsets``. Positions whose index equals ``padding_idx`` are left out, and
    the remaining rows are reduced along axis 0 with sum (mode 0), mean (mode 1)
    or max (any other mode). A bag with no remaining position produces a row of
    zeros.

    Returns:
        ``(result, offset2bag, bag_size, max_indices)``: ``result`` is the
        concatenation of the per-bag rows, cast like ``weight``; the other three
        are zero-filled tensors whose shapes depend on ``mode``.
    """
    neg_1 = op.Constant(value_ints=[-1])
    # Get weight out according to indices,
    # e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
    indices_weight = op.Gather(weight, indices)
    # This happens after the first Gather step, because Shape(indices)==Shape(per_sample_weights)
    indices_weight = op.Mul(indices_weight, op.Unsqueeze(per_sample_weights, axes=1))

    # The element in sequence must be FLOAT32 dtype due to ORT bug
    indices_weight = op.Cast(indices_weight, to=FLOAT.dtype)
    # FIXME: https://github.com/microsoft/onnxruntime/issues/16846
    result = op.SequenceEmpty()

    # Both sizes are reshaped to 1-D so they can be scattered / concatenated below.
    num_bag = op.Reshape(op.Size(offsets), neg_1)
    idx_size = op.Reshape(op.Size(indices), neg_1)

    if op.Equal(include_last_offset, True):
        num_bag = num_bag - 1
        # Change(by ScatterElement setting) the last element to 'end'
        # [0,2,3] -> [0,2,end]
        offsets = op.ScatterElements(offsets, [-1], idx_size)
    else:
        # Change [0,2,3] -> [0,2,3,end], means [0:2],[2:3],[3:end]
        offsets = op.Concat(offsets, idx_size, axis=0)

    # Process each bag
    i = op.Reshape(op.Constant(value_int=0), neg_1)  # Used for iterator
    cond_1 = i < num_bag
    while cond_1:
        # Bag i spans positions [start_pos, end_pos) of indices.
        start_pos = op.Gather(offsets, i)
        end_pos = op.Gather(offsets, i + 1)
        # empty tensor
        curr_offsets = op.Shape(indices, start=0, end=0)
        j = start_pos
        cond_2 = j < end_pos
        while cond_2:
            index = op.Gather(indices, j)
            if not op.Equal(index, padding_idx):
                # Something like the 'append' operation
                curr_offsets = op.Concat(curr_offsets, op.Reshape(j, neg_1), axis=0)
            j = j + 1
            cond_2 = j < end_pos

        # An empty bag yields a zero-valued row of shape (1, weight.dim[1]), not an empty output
        if op.Size(curr_offsets) == 0:
            dim_1 = op.Shape(weight, start=1, end=2)
            expand_shape = op.Concat([1], dim_1, axis=0)
            row_result = op.Expand([0.0], expand_shape)
        else:
            # Rows of the already weighted embeddings at the kept positions.
            row_weight = op.Gather(indices_weight, curr_offsets)
            if mode == 0:  # sum
                row_result = op.ReduceSum(row_weight, axes=[0])
            elif mode == 1:  # mean
                row_result = op.ReduceMean(row_weight, axes=[0])
            else:
                row_result = op.ReduceMax(row_weight, axes=[0])

        result = op.SequenceInsert(result, row_result)

        i = i + 1
        cond_1 = i < num_bag

    result = op.ConcatFromSequence(result, axis=0)
    # Undo the FLOAT cast applied above so the result matches the dtype of weight.
    result = op.CastLike(result, weight)

    # The remaining three outputs are zero-filled; only their shapes vary by mode.
    # NOTE: `offsets` below is the adjusted tensor from above (it has one extra
    # element when include_last_offset is False).
    if mode == 0:  # sum
        offset2bag = op.Expand(0, op.Shape(indices))
        if op.Equal(include_last_offset, True):
            bag_size = op.Expand(0, op.Shape(offsets))
        else:
            bag_size = op.Expand(0, op.Shape(offsets) - 1)
        max_indices = op.Expand(0, op.Shape(bag_size))
    elif mode == 1:  # mean
        offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
        bag_size = op.Expand(0, op.Shape(offsets) - 1)
        max_indices = op.Expand(0, op.Shape(bag_size))
    else:  # mode == 2, max
        offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
        bag_size = op.Expand(0, op.Shape(offsets) - 1)
        # shape = (bag_size.dim[0], weight.dim[1])
        dim_0 = op.Shape(bag_size, start=0, end=1)
        dim_1 = op.Shape(weight, start=1, end=2)
        max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0))

    return result, offset2bag, bag_size, max_indices


def aten_embedding_dense_backward(
Expand Down
129 changes: 129 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,128 @@ def make_per_sample_weight(flag, idx):
)


def sample_inputs_embedding_bag_padding_idx(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInputs for ``embedding_bag.padding_idx`` using 1-D index tensors.

    Every combination of ``include_last_offset``, reduction mode, optional
    per-sample weights and ``padding_idx`` in ``(-1, 0, 1, 2, 3)`` is produced
    twice: first with a contiguous index tensor, then with a non-contiguous one.
    Per-sample weights are generated only for mode 0 ("sum").
    """
    del op_info
    del kwargs

    def make_input(shape):
        return common_methods_invocations.make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    def make_long_input(shape, *, low, high, noncontiguous=False):
        return common_methods_invocations.make_tensor(
            shape,
            device=device,
            dtype=torch.long,
            low=low,
            high=high,
            noncontiguous=noncontiguous,
        )

    def make_per_sample_weight(flag, idx):
        # Either a tensor of float / double weights, or None
        # to indicate that all weights should be taken to be 1.
        return make_input(idx.reshape(-1).shape) if flag else None

    offsets = [
        torch.tensor([0, 2, 3], device=device, dtype=torch.long),
        # The cases below do not work in FullGraph mode, presumably due to an op.While() bug:
        # when the initial condition is False, it still executes the loop body once.
        # torch.tensor([0, 0, 2], device=device, dtype=torch.long),
        # torch.tensor([0, 2, 2, 4], device=device, dtype=torch.long),
    ]

    # Modes 0, 1, 2 stand for ('sum', 'mean', 'max').
    # per_sample_weights only support mode='sum', so those combinations are dropped.
    configurations = [
        (offset, include_last_offset, with_weights, mode)
        for offset in offsets
        for include_last_offset in (True, False)
        for with_weights in (True, False)
        for mode in (0, 1, 2)
        if not (with_weights and mode != 0)
    ]

    # NOTE: 2-D index tensors are not exercised here. An earlier, disabled draft
    # of those cases noted that "max" mode with a 2-D index tensor makes the
    # aten function crash.
    for offset, include_last_offset, with_weights, mode in configurations:
        for padding_idx in (-1, 0, 1, 2, 3):
            # 1-D index tensor: contiguous first, then non-contiguous.
            for noncontiguous in (False, True):
                indices = make_long_input(
                    (S,), low=0, high=M, noncontiguous=noncontiguous
                )
                per_sample_weights = make_per_sample_weight(with_weights, indices)
                yield common_methods_invocations.SampleInput(
                    make_input((M, S)),
                    args=(indices,),
                    kwargs={
                        "offsets": offset,
                        "scale_grad_by_freq": False,
                        "mode": mode,
                        "sparse": False,
                        "per_sample_weights": per_sample_weights,
                        "include_last_offset": include_last_offset,
                        "padding_idx": padding_idx,
                    },
                )


def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs):
del op_info
# Case `target_end == 1`, where `target_end = (input.size(dimension) - size) // step + 1`.
Expand Down Expand Up @@ -844,6 +966,13 @@ def sample_inputs__softmax(
sample_inputs_func=sample_inputs_embedding_bag,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.embedding_bag.padding_idx",
aten_name="embedding_bag.padding_idx",
dtypes=common_dtype.floating_types_and_half(),
sample_inputs_func=sample_inputs_embedding_bag_padding_idx,
supports_out=False,
),
opinfo_core.OpInfo(
"nn.functional.conv3d",
aten_name="conv3d",
Expand Down
Loading