diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 9541491c07..832d5ffe7a 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2646,7 +2646,7 @@ def aten_embedding_backward( def aten_embedding_bag( weight: TFloat, indices: INT64, - offsets: INT64, + offsets: INT64 = None, scale_grad_by_freq: bool = False, # pylint: disable=unused-argument mode: int = 0, # [0,1,2] indicate ["sum", "mean", "max"] sparse: bool = False, # pylint: disable=unused-argument @@ -2658,18 +2658,54 @@ def aten_embedding_bag( # assert(rank(indices) in [1,2]) # assert(rank(offsets) == 1) # assert(op.Size(per_sample_weights) == op.Size(indices)) - if per_sample_weights is None: - # Set per_sample_weights to 1.0, because cannot check 'None' in ONNX-Script - # Size of persample_weights is the same as indices, and should be 1d tensor - indices_1d = op.Reshape(indices, [-1]) - per_sample_weights = op.Expand(1, op.Shape(indices_1d)) - # Dtype of per_sample_weights is the same as weight - per_sample_weights = op.CastLike(per_sample_weights, weight) + # if per_sample_weights is None: + # # Set per_sample_weights to 1.0, because cannot check 'None' in ONNX-Script + # # Size of persample_weights is the same as indices, and should be 1d tensor + # indices_1d = op.Reshape(indices, [-1]) + # per_sample_weights = op.Expand(1, op.Shape(indices_1d)) + # # Dtype of per_sample_weights is the same as weight + # per_sample_weights = op.CastLike(per_sample_weights, weight) + + if offsets is None: + if per_sample_weights is None: + # Set per_sample_weights to 1.0, because cannot check 'None' in ONNX-Script + # Size of persample_weights is the same as indices, and should be 1d tensor + per_sample_weights = op.Expand(1, op.Shape(indices)) + per_sample_weights = op.CastLike(per_sample_weights, weight) + result = _aten_embedding_bag_2d_onnx(weight, indices, mode, per_sample_weights) + else: + if 
@torch_op("aten::embedding_bag", private=True)
def _aten_embedding_bag_2d_onnx(
    weight: TFloat,
    indices: INT64,
    mode: int,
    per_sample_weights: TFloat,
) -> TFloat:
    """Compute embedding_bag for a 2-D (batch, bag) index tensor with no offsets.

    Each row of ``indices`` is one bag: the gathered embedding rows are scaled
    by ``per_sample_weights`` and reduced along the bag axis (axis 1) according
    to ``mode``: 0 = sum, 1 = mean, 2 = max.

    Args:
        weight: Embedding matrix, shape (num_embeddings, embedding_dim).
        indices: 2-D int64 index tensor — assumed shape (batch, bag);
            TODO(review): confirm callers guarantee rank 2.
        mode: Reduction selector (0/1/2 for sum/mean/max).
        per_sample_weights: Weights with the same shape as ``indices``.

    Returns:
        Reduced embeddings, shape (batch, embedding_dim).
    """
    # Gather the embedding rows: result has shape (batch, bag, embedding_dim).
    new_weight = op.Gather(weight, indices)
    # Broadcast the per-sample weights over the embedding dimension.
    # FIX: ONNX Unsqueeze (opset >= 13) takes `axes` as a 1-D int64 tensor;
    # the original passed the scalar 2, which is the wrong rank.
    new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=[2]))
    if mode == 1:  # mean
        result = op.ReduceMean(new_weight, axes=[1], keepdims=False)
    elif mode == 2:  # max
        result = op.ReduceMax(new_weight, axes=[1], keepdims=False)
    else:  # mode == 0: sum
        result = op.ReduceSum(new_weight, axes=[1], keepdims=False)

    return result
+ for mode in ('sum', 'mean', 'max'): # per_sample_weights only support mode='sum' - if generate_per_sample_weight and mode in (1, 2): # ('mean', 'max'): + if generate_per_sample_weight and mode in ("mean", "max"): continue # 1-D index tensor @@ -699,7 +695,7 @@ def make_per_sample_weight(flag, idx): }, ) - if mode != 2: # "max" mode in 2-D index tensor make aten func crash + if mode != "max": # "max" mode in 2-D index tensor make aten func crash # 2-D index tensor indices = make_long_input((S, S), low=0, high=M) per_sample_weights = make_per_sample_weight( @@ -1003,13 +999,13 @@ def sample_inputs_scaled_dot_product_flash_attention( sample_inputs_func=sample_inputs_col2im, supports_out=False, ), - opinfo_core.OpInfo( - "ops.aten.embedding_bag", - aten_name="embedding_bag", - dtypes=common_dtype.floating_types_and_half(), - sample_inputs_func=sample_inputs_embedding_bag, - supports_out=False, - ), + # opinfo_core.OpInfo( + # "ops.aten.embedding_bag", + # aten_name="embedding_bag", + # dtypes=common_dtype.floating_types_and_half(), + # sample_inputs_func=sample_inputs_embedding_bag, + # supports_out=False, + # ), opinfo_core.OpInfo( "ops.aten.embedding_bag.padding_idx", aten_name="embedding_bag.padding_idx", diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index ffc370171a..c3fc616e46 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -235,6 +235,17 @@ def _dropout_input_wrangler( return args, kwargs +def _embedding_bag_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + if "mode" in kwargs: + mode_vals = ["sum", "mean", "max"] + value = kwargs["mode"] + idx = mode_vals.index(value) + kwargs["mode"] = idx + return args, kwargs + + def _embedding_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -1031,11 
+1042,17 @@ def _where_input_wrangler( ), TorchLibOpInfo("nn.functional.elu", nn_ops.aten_elu), TorchLibOpInfo( - "ops.aten.embedding_bag", + "nn.functional.embedding_bag", core_ops.aten_embedding_bag, + input_wrangler=_embedding_bag_input_wrangler, tolerance={torch.float16: (1e-2, 1e-2)}, trace_only=True, compare_shape_only_for_output=(1, 2, 3), + ).skip( + matcher=lambda sample: sample.kwargs.get("padding_idx") is not None + or sample.kwargs.get("max_norm") is not None + or sample.kwargs.get("norm_type") is not None, + reason="this overload only support none padding_idx, max_norm and norm_type in kwargs", ), TorchLibOpInfo( "ops.aten.embedding_bag.padding_idx",