Add Op(_scaled_dot_product_flash_attention) | feat(torchlib) #1043

Merged: 8 commits, Sep 6, 2023
69 changes: 69 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -17,6 +17,8 @@
import math
from typing import Optional, Sequence, Tuple

import onnx

from onnxscript import FLOAT, INT64
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import (
@@ -1673,6 +1675,73 @@ def aten_scaled_dot_product_attention(
)


@torch_op("aten::_scaled_dot_product_flash_attention", private=True)
def _aten_scaled_dot_product_flash_attention_fillin_empty_outputs(
query: TFloat,
) -> Tuple[FLOAT, INT64, INT64, FLOAT]:
query_first_three_dims = op.Slice(
op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3])
)
logsumexp = op.Expand(0.0, query_first_three_dims)
# TODO: Eliminate `make_tensor` usage when ORT supports empty tensor.
empty_tensor_int = op.Cast(
op.ConstantOfShape(
op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
),
to=INT64.dtype,
)
empty_tensor_float = op.ConstantOfShape(
op.Constant(value=onnx.helper.make_tensor("Empty_FLOATS", INT64.dtype, [0], []))
)
empty_int = op.Constant(value_int=0)

return logsumexp, empty_tensor_int, empty_int, empty_tensor_float
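
Note (editorial illustration, not part of the diff): the helper above builds a zero-filled logsumexp placeholder shaped like the first three dimensions of `query`. The NumPy sketch below mirrors the `Slice(Shape(query), [0], [3])` + `Expand(0.0, ...)` pattern; the `(batch, num_heads, seq_q, head_dim)` layout is an assumption taken from the sample inputs added later in this PR.

```python
import numpy as np

# Assumed layout: (batch, num_heads, seq_q, head_dim), as in the new sample inputs.
query = np.zeros((4, 4, 3, 8), dtype=np.float32)

# Mirrors op.Slice(op.Shape(query), [0], [3]): keep the first three dims.
first_three_dims = query.shape[:3]

# Mirrors op.Expand(0.0, first_three_dims): a zero-filled logsumexp placeholder.
logsumexp_placeholder = np.zeros(first_three_dims, dtype=np.float32)
assert logsumexp_placeholder.shape == (4, 4, 3)  # (batch, num_heads, seq_q)
```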


@torch_op("aten::_scaled_dot_product_flash_attention", trace_only=True)
def aten_scaled_dot_product_flash_attention(
query: TFloat,
key: TFloat,
value: TFloat,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False, # pylint: disable=unused-argument
scale: Optional[float] = None,
) -> Tuple[TFloat, FLOAT, INT64, INT64, INT64, INT64, INT64, INT64, FLOAT]:
"""_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)

One of the implementations of scaled_dot_product_attention.
Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

    NOTE: PyTorch currently provides three implementations of nn.scaled_dot_product_attention for optimization purposes.
    From the ONNX perspective, they are all the same computation.

"""
result = aten_scaled_dot_product_attention(
query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale
)

    # The following outputs are not consumed by the graph.
(
logsumexp,
empty_tensor_int,
empty_int,
empty_tensor_float,
) = _aten_scaled_dot_product_flash_attention_fillin_empty_outputs(query)

return (
result,
logsumexp,
empty_tensor_int,
[Review thread on the line above]

titaiwangms (Contributor, Author): Should I create TInt for these guys?

Collaborator: INT64?

titaiwangms (Contributor, Author, Sep 5, 2023): The one in embedding remains TFloat though. But I can do INT64 in this case. It depends on whether we should follow the native-func signature or what we really return.

Collaborator: If the return types for the empty float values need to be TFloat, do we need a CastLike on self here? Otherwise it would be FLOAT, because the dtype is set and not dependent on the input?

empty_tensor_int,
empty_int,
empty_int,
empty_tensor_int,
empty_tensor_int,
empty_tensor_float,
)
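
Note (editorial illustration, not part of the diff): a quick sanity check of the docstring's claim that, for export purposes, only the first output of the flash variant matters and that it matches the generic scaled_dot_product_attention. It assumes a PyTorch >= 2.1 build with a flash-attention kernel available for the chosen device, and reuses the tolerance from the test data added below.

```python
import torch

# (batch, num_heads, seq, head_dim); seq_q=3 for the query, seq_kv=6 for key/value.
query = torch.randn(4, 4, 3, 8)
key = torch.randn(4, 4, 6, 8)
value = torch.randn(4, 4, 6, 8)

expected = torch.nn.functional.scaled_dot_product_attention(query, key, value)

# The flash variant returns a 9-tuple; only outputs[0] (the attention result)
# is meaningful for the exported graph.
outputs = torch.ops.aten._scaled_dot_product_flash_attention(
    query, key, value, dropout_p=0.0, is_causal=False
)
torch.testing.assert_close(outputs[0], expected, rtol=3e-4, atol=1.5e-5)
```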


@torch_op("aten::scaled_dot_product_attention", trace_only=True)
def aten_scaled_dot_product_attention_bool_mask(
query: TFloat,
56 changes: 56 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -930,6 +930,50 @@ def sample_inputs__softmax(
yield opinfo_core.SampleInput(make_arg(shape), args=dim, kwargs=kwargs)


def sample_inputs_scaled_dot_product_flash_attention(
op_info, device, dtype, requires_grad, **kwargs
):
del op_info
del kwargs

make = opinfo_core.partial(
opinfo_core.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)
batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8

dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)

qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
samples = []
for qkv_shape, is_causal, dropout_p in opinfo_core.product(
qkv_shapes, [True, False], [0.0]
):
shape_q, shape_kv = qkv_shape
samples.append(
opinfo_core.SampleInput(
make(shape_q),
make(shape_kv),
make(shape_kv),
is_causal=is_causal,
dropout_p=dropout_p,
)
)

    # Add a non-causal sample (this op takes no attn_mask argument)
samples.append(
opinfo_core.SampleInput(
make((batch, num_heads, seq_q, head_dim)),
make((batch, num_heads, seq_kv, head_dim)),
make((batch, num_heads, seq_kv, head_dim)),
is_causal=False,
dropout_p=0.0,
)
)

yield from samples
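
Note (editorial illustration, not part of the diff): a small sketch of how the generator above can be consumed directly, e.g. when debugging; it assumes `torch` is importable and passes `None` for the unused `op_info` argument, which the function deletes anyway.

```python
import torch

for sample in sample_inputs_scaled_dot_product_flash_attention(
    None, device="cpu", dtype=torch.float32, requires_grad=False
):
    query = sample.input      # shape (4, 4, 3, 8)
    key, value = sample.args  # shapes (4, 4, 6, 8)
    print(query.shape, key.shape, sample.kwargs)
```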


# NOTE: How to create an OpInfo:
# 1. Create a function that generates sample inputs for the op.
# This function should yield SampleInputs.
@@ -1130,4 +1174,16 @@ def sample_inputs__softmax(
sample_inputs_func=sample_inputs__softmax,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten._scaled_dot_product_flash_attention",
aten_name="_scaled_dot_product_flash_attention",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
# NOTE: Different from aten::scaled_dot_product_attention, this op doesn't support
# dim<=3 input.
sample_inputs_func=sample_inputs_scaled_dot_product_flash_attention,
supports_out=False,
supports_forward_ad=False,
supports_fwgrad_bwgrad=True,
check_batched_forward_grad=False,
),
]
11 changes: 11 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1749,6 +1749,17 @@ def _where_input_wrangler(
matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
reason="dropout is random so the results do not match",
),
TorchLibOpInfo(
"ops.aten._scaled_dot_product_flash_attention",
nn_ops.aten_scaled_dot_product_flash_attention,
trace_only=True,
tolerance={torch.float32: (3e-4, 1.5e-5)},
        # Output[0] is OK, but the other outputs are zero-filled placeholders that only match in shape
nondeterministic=True,
).skip(
enabled_if=version_utils.torch_older_than("2.1"),
        reason="The operator is not supported in older versions.",
),
TorchLibOpInfo(
"nn.functional.scaled_dot_product_attention_bool_mask",
nn_ops.aten_scaled_dot_product_attention_bool_mask,