From 9d46a5af732cc5e331ecfa59cd32c7138bcec5b4 Mon Sep 17 00:00:00 2001
From: AllenTiTaiWang
Date: Thu, 31 Aug 2023 18:02:18 +0000
Subject: [PATCH 1/6] Add Op(meta__scaled_dot_product_flash) | feat(torchlib)

---
 onnxscript/function_libs/torch_lib/ops/nn.py  | 41 +++++++++++++
 .../function_libs/torch_lib/extra_opinfo.py   | 59 +++++++++++++++++++
 .../function_libs/torch_lib/ops_test_data.py  |  6 ++
 3 files changed, 106 insertions(+)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 1c35a55a04..6cbd8abfa5 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1673,6 +1673,47 @@ def aten_scaled_dot_product_attention(
     )
 
 
+@torch_op("aten::_scaled_dot_product_flash_attention", trace_only=True)
+def aten_scaled_dot_product_flash_attention(
+    query: TFloat,
+    key: TFloat,
+    value: TFloat,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    return_debug_mask: bool = False,  # pylint: disable=unused-argument
+    scale: Optional[float] = None,
+) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat, TFloat, INT64, INT64, TFloat]:
+    """_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor ouput, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+
+    One of the implementations of scaled_dot_product_attention.
+    Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+
+    NOTE: Currently, there are three implementations of nn.scaled_dot_product_attention in PyTorch due to optimization.
+    However, it is the same implementation from the ONNX perspective.
+
+    """
+    result = aten_scaled_dot_product_attention_bool_mask(
+        query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale
+    )
+
+    # The following outputs are not consumed by the graph.
+    logsumexp = op.Expand(0, op.Shape(query))
+    empty_tensor = op.ConstantOfShape([])
+    empty_int = op.Constant(value_int=0)
+
+    return (
+        result,
+        logsumexp,
+        empty_tensor,
+        empty_tensor,
+        empty_int,
+        empty_int,
+        empty_tensor,
+        empty_tensor,
+        empty_tensor,
+    )
+
+
 @torch_op("aten::scaled_dot_product_attention", trace_only=True)
 def aten_scaled_dot_product_attention_bool_mask(
     query: TFloat,
diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index 561834ede7..ca3b50f060 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -930,6 +930,58 @@ def sample_inputs__softmax(
     yield opinfo_core.SampleInput(make_arg(shape), args=dim, kwargs=kwargs)
 
 
+def sample_inputs_scaled_dot_product_flash_attention(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    del op_info
+    del kwargs
+
+    make = opinfo_core.partial(
+        opinfo_core.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8
+
+    dim_3_q_shape = (batch, seq_q, head_dim)
+    dim_3_kv_shape = (batch, seq_kv, head_dim)
+    dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
+    dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
+
+    broadcast_tuple = ((num_heads, seq_q, head_dim), (batch, num_heads, seq_kv, head_dim))
+
+    qkv_shapes = [
+        (dim_3_q_shape, dim_3_kv_shape),
+        (dim_4_q_shape, dim_4_kv_shape),
+        broadcast_tuple,
+    ]
+    samples = []
+    for qkv_shape, is_causal, dropout_p in opinfo_core.product(
+        qkv_shapes, [True, False], [0.0, 0.5]
+    ):
+        shape_q, shape_kv = qkv_shape
+        samples.append(
+            opinfo_core.SampleInput(
+                make(shape_q),
+                make(shape_kv),
+                make(shape_kv),
+                is_causal=is_causal,
+                dropout_p=dropout_p,
+            )
+        )
+
+    # Add a 4-D sample without an attn_mask (this op does not take a mask input).
+    samples.append(
+        opinfo_core.SampleInput(
+            make((batch, num_heads, seq_q, head_dim)),
+            make((batch, num_heads, seq_kv, head_dim)),
+            make((batch, num_heads, seq_kv, head_dim)),
+            is_causal=False,
+            dropout_p=0.0,
+        )
+    )
+
+    yield from samples
+
+
 # NOTE: How to create an OpInfo:
 # 1. Create a function that generates sample inputs for the op.
 #    This function should yield SampleInputs.
@@ -1130,4 +1182,11 @@ def sample_inputs__softmax(
         sample_inputs_func=sample_inputs__softmax,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten._scaled_dot_product_flash_attention",
+        aten_name="_scaled_dot_product_flash_attention",
+        dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        sample_inputs_func=sample_inputs_scaled_dot_product_flash_attention,
+        supports_out=False,
+    ),
 ]
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 694af9a66b..49db830c4d 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1749,6 +1749,12 @@ def _where_input_wrangler(
         matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
         reason="dropout is random so the results do not match",
     ),
+    TorchLibOpInfo(
+        "nn.functional.scaled_dot_product_attention",
+        nn_ops.aten_scaled_dot_product_flash_attention,
+        trace_only=True,
+        tolerance={torch.float32: (3e-4, 1.5e-5)},
+    ),
     TorchLibOpInfo(
         "nn.functional.scaled_dot_product_attention_bool_mask",
         nn_ops.aten_scaled_dot_product_attention_bool_mask,

From a90d78c8d5be719a59713a1d2a2179dec9bf062e Mon Sep 17 00:00:00 2001
From: AllenTiTaiWang
Date: Thu, 31 Aug 2023 20:47:11 +0000
Subject: [PATCH 2/6] Refactor tests

---
 onnxscript/function_libs/torch_lib/ops/nn.py  | 23 ++++++++++++-------
 .../function_libs/torch_lib/extra_opinfo.py   | 17 ++++++--------
 .../function_libs/torch_lib/ops_test_data.py  |  4 +++-
 3 files changed, 25 insertions(+), 19 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 6cbd8abfa5..74d2ed4443 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1692,25 +1692,32 @@ def aten_scaled_dot_product_flash_attention(
     However, it is the same implementation from the ONNX perspective.
 
     """
-    result = aten_scaled_dot_product_attention_bool_mask(
+    result = aten_scaled_dot_product_attention(
         query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale
     )
 
     # The following outputs are not consumed by the graph.
-    logsumexp = op.Expand(0, op.Shape(query))
-    empty_tensor = op.ConstantOfShape([])
+    query_first_three_dims = op.Slice(
+        op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3])
+    )
+    logsumexp = op.Expand(0.0, query_first_three_dims)
+    # TODO: breaking checker
+    empty_tensor_int = op.Cast(op.ConstantOfShape(op.Constant(value_ints=[])), to=INT64.dtype)
+    empty_tensor_float = op.Cast(
+        op.ConstantOfShape(op.Constant(value_ints=[])), to=FLOAT.dtype
+    )
     empty_int = op.Constant(value_int=0)
 
     return (
         result,
         logsumexp,
-        empty_tensor,
-        empty_tensor,
+        empty_tensor_int,
+        empty_tensor_int,
         empty_int,
         empty_int,
-        empty_tensor,
-        empty_tensor,
-        empty_tensor,
+        empty_tensor_int,
+        empty_tensor_int,
+        empty_tensor_float,
     )
 
 
diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index ca3b50f060..c35405b5b1 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -941,21 +941,13 @@ def sample_inputs_scaled_dot_product_flash_attention(
     )
     batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8
 
-    dim_3_q_shape = (batch, seq_q, head_dim)
-    dim_3_kv_shape = (batch, seq_kv, head_dim)
     dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
     dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
 
-    broadcast_tuple = ((num_heads, seq_q, head_dim), (batch, num_heads, seq_kv, head_dim))
-
-    qkv_shapes = [
-        (dim_3_q_shape, dim_3_kv_shape),
-        (dim_4_q_shape, dim_4_kv_shape),
-        broadcast_tuple,
-    ]
+    qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
     samples = []
     for qkv_shape, is_causal, dropout_p in opinfo_core.product(
-        qkv_shapes, [True, False], [0.0, 0.5]
+        qkv_shapes, [True, False], [0.0]
     ):
         shape_q, shape_kv = qkv_shape
         samples.append(
@@ -1186,7 +1178,12 @@ def sample_inputs_scaled_dot_product_flash_attention(
         "ops.aten._scaled_dot_product_flash_attention",
         aten_name="_scaled_dot_product_flash_attention",
         dtypes=common_dtype.floating_types_and(torch.bfloat16),
+        # NOTE: Different from aten::scaled_dot_product_attention, this op doesn't support
+        # dim<=3 input.
         sample_inputs_func=sample_inputs_scaled_dot_product_flash_attention,
         supports_out=False,
+        supports_forward_ad=False,
+        supports_fwgrad_bwgrad=True,
+        check_batched_forward_grad=False,
     ),
 ]
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 49db830c4d..7f39de67f7 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1750,10 +1750,12 @@ def _where_input_wrangler(
         reason="dropout is random so the results do not match",
     ),
     TorchLibOpInfo(
-        "nn.functional.scaled_dot_product_attention",
+        "ops.aten._scaled_dot_product_flash_attention",
         nn_ops.aten_scaled_dot_product_flash_attention,
         trace_only=True,
         tolerance={torch.float32: (3e-4, 1.5e-5)},
+        # Output[0] is OK, but other outputs just have the same shape with zero values
+        nondeterministic=True,
     ),
     TorchLibOpInfo(
         "nn.functional.scaled_dot_product_attention_bool_mask",
         nn_ops.aten_scaled_dot_product_attention_bool_mask,

From 3521e2d28195472a99d63cb2024829486441b5e4 Mon Sep 17 00:00:00 2001
From: AllenTiTaiWang
Date: Tue, 5 Sep 2023 17:27:23 +0000
Subject: [PATCH 3/6] Use make_tensor

---
 onnxscript/function_libs/torch_lib/ops/nn.py | 41 +++++++++++++++-----
 1 file changed, 31 insertions(+), 10 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 74d2ed4443..c682d5dd11 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -17,6 +17,8 @@
 import math
 from typing import Optional, Sequence, Tuple
 
+import onnx
+
 from onnxscript import FLOAT, INT64
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.function_libs.torch_lib.tensor_typing import (
@@ -1673,6 +1675,29 @@ def aten_scaled_dot_product_attention(
     )
 
 
+@torch_op("aten::_scaled_dot_product_flash_attention", private=True)
+def _aten_scaled_dot_product_flash_attention_fillin_empty_outputs(
+    query: TFloat,
+) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
+    # The following outputs are not consumed by the graph.
+    query_first_three_dims = op.Slice(
+        op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3])
+    )
+    logsumexp = op.Expand(0.0, query_first_three_dims)
+    empty_tensor_int = op.Cast(
+        op.ConstantOfShape(
+            op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
+        ),
+        to=INT64.dtype,
+    )
+    empty_tensor_float = op.ConstantOfShape(
+        op.Constant(value=onnx.helper.make_tensor("Empty_FLOATS", INT64.dtype, [0], []))
+    )
+    empty_int = op.Constant(value_int=0)
+
+    return logsumexp, empty_tensor_int, empty_int, empty_tensor_float
+
+
 @torch_op("aten::_scaled_dot_product_flash_attention", trace_only=True)
 def aten_scaled_dot_product_flash_attention(
     query: TFloat,
@@ -1697,16 +1722,12 @@ def aten_scaled_dot_product_flash_attention(
     )
 
     # The following outputs are not consumed by the graph.
-    query_first_three_dims = op.Slice(
-        op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3])
-    )
-    logsumexp = op.Expand(0.0, query_first_three_dims)
-    # TODO: breaking checker
-    empty_tensor_int = op.Cast(op.ConstantOfShape(op.Constant(value_ints=[])), to=INT64.dtype)
-    empty_tensor_float = op.Cast(
-        op.ConstantOfShape(op.Constant(value_ints=[])), to=FLOAT.dtype
-    )
-    empty_int = op.Constant(value_int=0)
+    (
+        logsumexp,
+        empty_tensor_int,
+        empty_int,
+        empty_tensor_float,
+    ) = _aten_scaled_dot_product_flash_attention_fillin_empty_outputs(query)
 
     return (
         result,

From 0021a52d1ded8faf46c908fdf755fe3aed57d136 Mon Sep 17 00:00:00 2001
From: AllenTiTaiWang
Date: Tue, 5 Sep 2023 17:35:10 +0000
Subject: [PATCH 4/6] Add TODO

---
 onnxscript/function_libs/torch_lib/ops/nn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index c682d5dd11..a5ad7bdd8a 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1679,11 +1679,11 @@ def aten_scaled_dot_product_attention(
 def _aten_scaled_dot_product_flash_attention_fillin_empty_outputs(
     query: TFloat,
 ) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
-    # The following outputs are not consumed by the graph.
     query_first_three_dims = op.Slice(
         op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3])
     )
     logsumexp = op.Expand(0.0, query_first_three_dims)
+    # TODO: Eliminate `make_tensor` usage when ORT supports empty tensor.
     empty_tensor_int = op.Cast(
         op.ConstantOfShape(
             op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))

From 2a9af9abb0fe32c029cb5ecde94a44e3b93da3af Mon Sep 17 00:00:00 2001
From: AllenTiTaiWang
Date: Tue, 5 Sep 2023 21:23:56 +0000
Subject: [PATCH 5/6] Update return types

---
 onnxscript/function_libs/torch_lib/ops/nn.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index a5ad7bdd8a..74e29492ae 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1678,7 +1678,7 @@ def aten_scaled_dot_product_attention(
 @torch_op("aten::_scaled_dot_product_flash_attention", private=True)
 def _aten_scaled_dot_product_flash_attention_fillin_empty_outputs(
     query: TFloat,
-) -> Tuple[TFloat, TFloat, TFloat, TFloat]:
+) -> Tuple[FLOAT, INT64, INT64, FLOAT]:
     query_first_three_dims = op.Slice(
         op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3])
     )
     logsumexp = op.Expand(0.0, query_first_three_dims)
@@ -1707,8 +1707,8 @@ def aten_scaled_dot_product_flash_attention(
     is_causal: bool = False,
     return_debug_mask: bool = False,  # pylint: disable=unused-argument
     scale: Optional[float] = None,
-) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat, TFloat, INT64, INT64, TFloat]:
-    """_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor ouput, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+) -> Tuple[TFloat, FLOAT, INT64, INT64, INT64, INT64, INT64, INT64, FLOAT]:
+    """_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
 
     One of the implementations of scaled_dot_product_attention.
     Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

From 376765ba019849ebdf7e80fb5411c5c81f752079 Mon Sep 17 00:00:00 2001
From: AllenTiTaiWang
Date: Tue, 5 Sep 2023 23:39:17 +0000
Subject: [PATCH 6/6] Skip older torch

---
 onnxscript/tests/function_libs/torch_lib/ops_test_data.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 7f39de67f7..ed2643093b 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1756,6 +1756,9 @@ def _where_input_wrangler(
         tolerance={torch.float32: (3e-4, 1.5e-5)},
         # Output[0] is OK, but other outputs just have the same shape with zero values
         nondeterministic=True,
+    ).skip(
+        enabled_if=version_utils.torch_older_than("2.1"),
+        reason="The operator is not supported in older versions.",
     ),
     TorchLibOpInfo(
         "nn.functional.scaled_dot_product_attention_bool_mask",
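
The sketch below is not part of the patch series; it is a minimal way a reviewer might exercise the new torchlib function end to end through the TorchDynamo-based ONNX exporter. The module name, tensor shapes, and output file name are illustrative; torch.onnx.dynamo_export is assumed from PyTorch 2.1, and whether PyTorch actually lowers the call to aten::_scaled_dot_product_flash_attention depends on the device, dtype, and the selected scaled-dot-product backend.

# Hypothetical repro sketch (assumes PyTorch >= 2.1 with the dynamo-based ONNX
# exporter and this torchlib change installed); not part of the patches above.
import torch


class Attention(torch.nn.Module):
    def forward(self, query, key, value):
        # Depending on the active SDP backend, this call may be dispatched to
        # aten::_scaled_dot_product_flash_attention, which the new torchlib
        # function above handles during export.
        return torch.nn.functional.scaled_dot_product_attention(query, key, value)


# (batch, num_heads, seq, head_dim): mirrors the 4-D shapes used in the new OpInfo samples.
query = torch.rand(4, 4, 3, 8)
key = torch.rand(4, 4, 6, 8)
value = torch.rand(4, 4, 6, 8)

export_output = torch.onnx.dynamo_export(Attention(), query, key, value)
export_output.save("sdpa_flash.onnx")  # illustrative file name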