Skip to content

Commit e83725e

Browse files
cherry pick pull/2879 to release/2.3 branch (#2882)
1 parent 9643560 commit e83725e

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

py/torch_tensorrt/dynamo/conversion/impl/embedding.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,12 @@ def embedding_bag_with_traversable_offsets(
9494
# however, pytorch doc says if `include_last_offset` is True, the size of offsets
9595
# is equal to the number of bags + 1. The last element is the size of the input,
9696
# or the ending index position of the last bag (sequence).
97-
offsets.itemset(-1, len_embed)
97+
# Notes: here offsets should always be 1d array
98+
if len(offsets.shape) != 1:
99+
raise TypeError(
100+
f"The offsets should be in 1d array, here offset shape is {offsets.shape}."
101+
)
102+
offsets[-1] = len_embed
98103
else:
99104
# add the end index to offsets
100105
offsets = np.append(offsets, len_embed)

0 commit comments

Comments
 (0)