Skip to content

Add Op(aten::_scaled_dot_product_efficient_attention) | feat(torchlib) #1197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 70 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1700,7 +1700,7 @@ def aten_scaled_dot_product_attention(


@torch_op("aten::_scaled_dot_product_flash_attention", private=True)
def _aten_scaled_dot_product_flash_attention_fillin_empty_outputs(
def _aten__scaled_dot_product_flash_attention_fillin_empty_outputs(
query: TFloat,
) -> Tuple[FLOAT, INT64, INT64, FLOAT]:
query_first_three_dims = op.Slice(
Expand All @@ -1723,7 +1723,7 @@ def _aten_scaled_dot_product_flash_attention_fillin_empty_outputs(


@torch_op("aten::_scaled_dot_product_flash_attention", trace_only=True)
def aten_scaled_dot_product_flash_attention(
def aten__scaled_dot_product_flash_attention(
query: TFloat,
key: TFloat,
value: TFloat,
Expand Down Expand Up @@ -1751,7 +1751,7 @@ def aten_scaled_dot_product_flash_attention(
empty_tensor_int,
empty_int,
empty_tensor_float,
) = _aten_scaled_dot_product_flash_attention_fillin_empty_outputs(query)
) = _aten__scaled_dot_product_flash_attention_fillin_empty_outputs(query)

return (
result,
Expand All @@ -1766,6 +1766,73 @@ def aten_scaled_dot_product_flash_attention(
)


@torch_op("aten::_scaled_dot_product_efficient_attention", private=True)
def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs(
query: TFloat,
compute_log_sumexp: bool,
) -> Tuple[FLOAT, INT64]:
"""_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)"""

query = op.Transpose(query, perm=[0, 2, 1, 3])
query_shape = op.Shape(query)
query_first_dims = query_shape[:1]
query_second_dims = query_shape[1:2]
num_heads = query_shape[-2:-1]

if compute_log_sumexp:
logsumexp_dim = op.Cast(
op.Ceil(op.Cast(query_second_dims, to=FLOAT.dtype) / 32.0) * 32.0, to=INT64.dtype
)
logsum_exp = op.Expand(
0.0, op.Concat(query_first_dims, num_heads, logsumexp_dim, axis=0)
)
else:
logsum_exp = op.Expand(0.0, op.Concat(query_first_dims, num_heads, [0], axis=0))

# See Note [Seed and Offset]:
empty_tensor_int = op.Cast(
op.ConstantOfShape(
op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
),
to=INT64.dtype,
)

return logsum_exp, empty_tensor_int


@torch_op("aten::_scaled_dot_product_efficient_attention", trace_only=True)
def aten__scaled_dot_product_efficient_attention(
query: TFloat,
key: TFloat,
value: TFloat,
attn_bias: Optional[TFloat], # pylint: disable=unused-argument
compute_log_sumexp: bool,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
) -> Tuple[TFloat, FLOAT, INT64, INT64]:
"""_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)"""

result = aten_scaled_dot_product_attention(
query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale
)

# The followings are not comsumed by the graph.
(
logsumexp,
empty_tensor_int,
) = _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs(
query, compute_log_sumexp
)

return (
result,
logsumexp,
empty_tensor_int,
empty_tensor_int,
)


@torch_op("aten::scaled_dot_product_attention", trace_only=True)
def aten_scaled_dot_product_attention_bool_mask(
query: TFloat,
Expand Down
60 changes: 57 additions & 3 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

import torch
from torch import testing as torch_testing
from torch.testing._internal import common_dtype, common_methods_invocations
from torch.testing._internal import (
common_device_type,
common_dtype,
common_methods_invocations,
)
from torch.testing._internal.opinfo import core as opinfo_core

S = 5
Expand Down Expand Up @@ -1298,7 +1302,7 @@ def sample_inputs__softmax(
yield opinfo_core.SampleInput(make_arg(shape), args=dim, kwargs=kwargs)


def sample_inputs_scaled_dot_product_flash_attention(
def sample_inputs__scaled_dot_product_flash_attention(
op_info, device, dtype, requires_grad, **kwargs
):
del op_info
Expand Down Expand Up @@ -1342,6 +1346,41 @@ def sample_inputs_scaled_dot_product_flash_attention(
yield from samples


def sample_inputs__scaled_dot_product_efficient_attention(
op_info, device, dtype, requires_grad, **kwargs
):
del op_info
del kwargs

make = opinfo_core.partial(
opinfo_core.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may also control device here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I enabled the whole CUDA tests if it's needed by the test.

)
batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8

dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)

qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
samples = []
for qkv_shape, is_causal, dropout_p, compute_log_sumexp in opinfo_core.product(
qkv_shapes, [True, False], [0.0], [True, False]
):
shape_q, shape_kv = qkv_shape
samples.append(
opinfo_core.SampleInput(
make(shape_q),
make(shape_kv),
make(shape_kv),
attn_bias=None,
is_causal=is_causal,
dropout_p=dropout_p,
compute_log_sumexp=compute_log_sumexp,
)
)

yield from samples


# NOTE: In `_native_batch_norm_legit` tests, it generates two kinds of args:
# 1. (input, weight, bias, running_mean, running_var, training, momentum, eps)
# 2. (input, weight, bias, training, momentum, eps)
Expand Down Expand Up @@ -1765,11 +1804,26 @@ def sample_inputs_reflection_pad1d(op_info, device, dtype, requires_grad, **kwar
dtypes=common_dtype.floating_types_and(torch.bfloat16),
# NOTE: Different from aten::scaled_dot_product_attention, this op doesn't support
# dim<=3 input.
sample_inputs_func=sample_inputs_scaled_dot_product_flash_attention,
sample_inputs_func=sample_inputs__scaled_dot_product_flash_attention,
supports_out=False,
supports_forward_ad=False,
supports_fwgrad_bwgrad=True,
check_batched_forward_grad=False,
),
opinfo_core.OpInfo(
"ops.aten._scaled_dot_product_efficient_attention",
aten_name="_scaled_dot_product_efficient_attention",
# only support CUDA
dtypes=common_dtype.empty_types(),
dtypesIfCUDA=common_dtype.floating_types_and(torch.bfloat16),
# NOTE: Different from aten::scaled_dot_product_attention, this op doesn't support
# dim<=3 input.
sample_inputs_func=sample_inputs__scaled_dot_product_efficient_attention,
supports_out=False,
supports_forward_ad=False,
supports_fwgrad_bwgrad=True,
check_batched_forward_grad=False,
decorators=[common_device_type.onlyCUDA],
),
opinfo_core.OpInfo(
"ops.aten._native_batch_norm_legit",
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/tests/function_libs/torch_lib/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,11 @@ def test_complex_output_match_opinfo_(


common_device_type.instantiate_device_type_tests(
TestOutputConsistencyEager, globals(), only_for="cpu"
TestOutputConsistencyEager, globals(), only_for=["cpu", "cuda"]
)

common_device_type.instantiate_device_type_tests(
TestOutputConsistencyFullGraph, globals(), only_for="cpu"
TestOutputConsistencyFullGraph, globals(), only_for=["cpu", "cuda"]
)

if __name__ == "__main__":
Expand Down
20 changes: 19 additions & 1 deletion onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2018,15 +2018,33 @@ def _where_input_wrangler(
),
TorchLibOpInfo(
"ops.aten._scaled_dot_product_flash_attention",
nn_ops.aten_scaled_dot_product_flash_attention,
nn_ops.aten__scaled_dot_product_flash_attention,
trace_only=True,
tolerance={torch.float32: (3e-4, 1.5e-5)},
# Output[0] is OK, but other outputs just have the same shape with zero values
nondeterministic=True,
compare_shape_only_for_output=(1, 2, 3, 4, 5, 6, 7, 8),
).skip(
enabled_if=version_utils.torch_older_than("2.1"),
reason="The operator is not supported in older version.",
),
TorchLibOpInfo(
"ops.aten._scaled_dot_product_efficient_attention",
nn_ops.aten__scaled_dot_product_efficient_attention,
trace_only=True,
tolerance={torch.float32: (3e-4, 1.5e-5)},
# Output[0] is OK, but other outputs just have the same shape with zero values
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the other compare option instead

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compare_shape_only_for_output

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

nondeterministic=True,
compare_shape_only_for_output=(1, 2, 3),
)
.skip(
enabled_if=version_utils.torch_older_than("2.1"),
reason="The operator is not supported in older version.",
)
.skip(
enabled_if=not torch.cuda.is_available(),
reason="_scaled_dot_product_efficient_attention only supports CUDA",
),
TorchLibOpInfo(
"nn.functional.scaled_dot_product_attention_bool_mask",
nn_ops.aten_scaled_dot_product_attention_bool_mask,
Expand Down