@@ -2266,19 +2266,144 @@ def aten_embedding_backward(
     raise NotImplementedError()


+@torch_op(
+    (
+        "aten::embedding_bag",
+        "aten::_embedding_bag",
+        "aten::_embedding_bag_forward_only",
+    ),
+    trace_only=True,
+)
 def aten_embedding_bag(
-    weight: TensorType,
-    indices: TensorType,
-    offsets: TensorType,
-    scale_grad_by_freq: bool = False,
-    mode: int = 0,
-    sparse: bool = False,
-    per_sample_weights: Optional[TensorType] = None,
+    weight: TFloat,
+    indices: INT64,
+    offsets: INT64 = None,  # Could be None according to the doc (then go the 2-D branch)
+    scale_grad_by_freq: bool = False,  # pylint: disable=unused-argument
+    mode: int = 1,  # [0,1,2] indicate ["sum", "mean", "max"], default is "mean"
+    sparse: bool = False,  # pylint: disable=unused-argument
+    per_sample_weights: Optional[TFloat] = None,
     include_last_offset: bool = False,
-) -> tuple[TensorType, TensorType, TensorType, TensorType]:
+) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
     """embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)"""

-    raise NotImplementedError()
+    # assert(rank(indices) in [1,2])
+    # assert(rank(offsets) == 1)
+    # assert(op.Size(per_sample_weights) == op.Size(indices))
+    if per_sample_weights is None:
+        # Set per_sample_weights to 1.0, because we cannot check for 'None' in ONNX-Script
+        # The size of per_sample_weights is the same as indices, and it should be a 1-D tensor
+        indices_1d = op.Reshape(indices, [-1])
+        per_sample_weights = op.Expand(1, op.Shape(indices_1d))
+        # Dtype of per_sample_weights is the same as weight
+        per_sample_weights = op.CastLike(per_sample_weights, weight)
+
+    result = _aten_embedding_bag_onnx(
+        weight, indices, offsets, mode, per_sample_weights, include_last_offset
+    )
+    offset2bag, bag_size, max_indices = _compute_output_others_shape(
+        weight, indices, offsets, mode, include_last_offset
+    )
+    return result, offset2bag, bag_size, max_indices
+
+
+# This Python function only computes the shapes of the outputs (filled with 0), not their values
+def _compute_output_others_shape(weight, indices, offsets, mode, include_last_off):
+    if mode == 0:  # sum
+        offset2bag = op.Shape(indices, start=0, end=0)  # Generate an empty tensor
+        bag_size = op.Expand(0, op.Shape(offsets))
+        max_indices = op.Expand(0, op.Shape(offsets))
+    elif mode == 1:  # mean
+        offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
+        if include_last_off is True:
+            bag_size = op.Expand(0, op.Shape(offsets) - 1)
+        else:
+            bag_size = op.Expand(0, op.Shape(offsets))
+        max_indices = op.Expand(0, op.Shape(bag_size))
+    else:  # max
+        offset2bag = op.Expand(0, op.Shape(indices, start=0, end=1))
+        if include_last_off is True:
+            bag_size = op.Expand(0, op.Shape(offsets) - 1)
+        else:
+            bag_size = op.Expand(0, op.Shape(offsets))
+        # shape = (bag_size.dim[0], weight.dim[1])
+        dim_0 = op.Shape(bag_size, start=0, end=1)
+        dim_1 = op.Shape(weight, start=1, end=2)
+        max_indices = op.Expand(0, op.Concat(dim_0, dim_1, axis=0))
+
+    return offset2bag, bag_size, max_indices
+
+
+@torch_op("aten::embedding_bag", private=True)
+def _aten_embedding_bag_onnx(
+    weight: TFloat,
+    indices: INT64,
+    offsets: INT64,
+    mode: int,
+    per_sample_weights: TFloat,
+    include_last_offset: bool,
+) -> TFloat:
+    neg_1 = op.Constant(value_ints=[-1])
+    # Assume indices has shape (5,2); then indices_1d has shape (10,)
+    indices_1d = op.Reshape(indices, neg_1)
+    # Gather the weight rows selected by indices_1d
+    new_weight = op.Gather(weight, indices_1d)
+    # This happens after the Gather step, because Shape(indices) == Shape(per_sample_weights)
+    new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=1))
+    weight_dim_1 = op.Reshape(op.Shape(weight, start=1), neg_1)
+    indices_size = op.Shape(indices_1d)
+
+    # Assume indices has shape (5,2), offsets = [0,2,3], include_last_offset = False
+    # [0,2,3] -> [0:2], [2:3], [3:10]
+    num_bag = op.Reshape(op.Size(offsets), neg_1)  # 3 bags, meaning 10 is the last index
+    if op.Equal(include_last_offset, True):
+        num_bag = num_bag - 1  # 2 bags, meaning 3 is the last index
+    else:
+        offsets = op.Concat(offsets, indices_size, axis=0)  # Append the total index count as the final offset
+
+    # The elements in the sequence must be FLOAT32 dtype due to an ORT bug
+    new_weight = op.Cast(new_weight, to=FLOAT.dtype)
+    # FIXME: https://github.com/microsoft/onnxruntime/issues/16846
+    result = op.SequenceEmpty()
+
+    index_tensor = op.Reshape(op.Constant(value_int=0), neg_1)  # Used as the loop counter
+    cond = index_tensor < num_bag
+    # Process each bag
+    while cond:
+        start = op.Slice(offsets, index_tensor, index_tensor + 1)
+        end = op.Slice(offsets, index_tensor + 1, index_tensor + 2)
+        # row_result should be 0: generate a (1, N) tensor filled with zeros
+        if start == end:
+            row_result = op.Expand(
+                op.Constant(value_floats=[0.0]),
+                op.Concat(op.Constant(value_ints=[1]), weight_dim_1, axis=0),
+            )
+        else:
+            if mode == 0:  # sum
+                weight_rows = op.Slice(new_weight, start, end)
+                row_result = op.ReduceSum(weight_rows, axes=[0])
+            elif mode == 1:  # mean
+                weight_rows = op.Slice(new_weight, start, end)
+                if op.Equal(index_tensor, num_bag - 1):  # The last bag
+                    row_result = op.ReduceSum(weight_rows, axes=[0])
+                    # When include_last_offset=False, offsets=[0,2,3], denominator=5-3=2
+                    # When include_last_offset=True, offsets=[0,2,3], denominator=5-2=3
+                    denominator = op.Sub(op.Shape(indices, start=0, end=1), start)
+                    if op.Greater(denominator, 0):
+                        row_result = op.Div(row_result, op.CastLike(denominator, new_weight))
+                else:
+                    row_result = op.ReduceMean(weight_rows, axes=[0])
+            else:  # max
+                if op.Equal(index_tensor, num_bag - 1):  # The last bag
+                    weight_rows = op.Slice(new_weight, start, indices_size)
+                else:
+                    weight_rows = op.Slice(new_weight, start, end)
+                row_result = op.ReduceMax(weight_rows, axes=[0])
+
+        result = op.SequenceInsert(result, row_result)
+        index_tensor = index_tensor + 1
+        cond = index_tensor < num_bag
+    result = op.ConcatFromSequence(result, axis=0)
+    return op.CastLike(result, weight)


 def aten_embedding_dense_backward(
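Note: a minimal sketch (not part of the diff) of how the per-bag reduction built by _aten_embedding_bag_onnx can be cross-checked against torch.nn.functional.embedding_bag; the shapes, mode, and tolerance below are illustrative assumptions, not values from the change.

# Illustrative check only: compare PyTorch's embedding_bag ("mean" mode, matching
# mode=1 above) with a manual per-bag reduction that mirrors the Slice/Reduce loop.
import torch

weight = torch.randn(10, 4)                 # embedding table (num_embeddings, embedding_dim)
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 2, 3])           # bags: indices[0:2], indices[2:3], indices[3:8]

expected = torch.nn.functional.embedding_bag(indices, weight, offsets, mode="mean")

# Manual reduction: append the total index count as the final offset, then average each slice.
ends = torch.cat([offsets[1:], torch.tensor([indices.numel()])])
rows = [weight[indices[s:e]].mean(dim=0) for s, e in zip(offsets.tolist(), ends.tolist())]
manual = torch.stack(rows)

assert torch.allclose(expected, manual, atol=1e-6)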