Add Op(_scaled_dot_product_flash_attention) | feat(torchlib) #1043
Conversation
Codecov Report
| Coverage Diff | main | #1043 | +/- |
|---|---|---|---|
| Coverage | 77.68% | 77.73% | +0.04% |
| Files | 114 | 114 | |
| Lines | 14445 | 14473 | +28 |
| Branches | 1545 | 1546 | +1 |
| Hits | 11222 | 11250 | +28 |
| Misses | 2857 | 2857 | |
| Partials | 366 | 366 | |
    return (
        result,
        logsumexp,
        empty_tensor_int,
Should I create TInt for these guys?
INT64?
The one in embedding remains TFloat, though. But I can make it INT64 in this case. It depends on whether we should follow the native-function signature or what we actually return.
If the return types for the empty float values need to be TFloat, do we need a CastLike to self here? Otherwise it would be FLOAT, because the dtype is set and not dependent on the input.
LGTM with the return types fixed.
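For context, a minimal onnxscript-style sketch of the `CastLike` idea under discussion: a constant created without reference to the input defaults to FLOAT, while `CastLike` ties its dtype to the input. The function name and shapes here are illustrative, not the code merged in this PR.

```python
# Minimal sketch (not the merged implementation): why CastLike matters for the
# placeholder float outputs. A bare Constant is FLOAT; CastLike follows the input dtype.
from onnxscript import script
from onnxscript import opset18 as op


@script()
def placeholder_like(x):
    zero = op.Constant(value_float=0.0)  # dtype is FLOAT regardless of x
    return op.CastLike(zero, x)          # dtype now matches x (e.g. FLOAT16)
```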
@justinchuby I found that all CI runs fail except torch-nightly. I guess torch-nightly is needed to test this op?
Looks like it. We can skip the tests for older torch by using
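A hedged sketch of one way to do that, assuming a pytest-style suite; the actual mechanism (elided above) and the version bound are assumptions, not taken from the PR.

```python
# Hypothetical sketch: skip the test when the installed torch predates the op.
# The 2.1 version bound is an assumption for illustration only.
import pytest
import torch
from packaging import version

requires_recent_torch = pytest.mark.skipif(
    version.parse(torch.__version__.split("+")[0]) < version.parse("2.1"),
    reason="_scaled_dot_product_flash_attention requires a newer torch",
)


@requires_recent_torch
def test_scaled_dot_product_flash_attention_export():
    ...
```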
… inputs" Previous to this PR, if None is returned from intermediate nodes, it will crashes the export because None is not expected to be passed into `_fill_tensor_shape_type`, and raise beartype roar. The function fills in shape and type to TorchScriptTensor according to its info from FX graph. This is discovered after microsoft/onnxscript#1043 is supported. The op specifically generates None in one of its inputs, but the only output from it being consumed is the first one (not None). Reference test from a TorchBench model: ```python def test_nanogpt(self): import sys sys.path.append("/home/titaiwang") from nanoGPT.model import GPT, GPTConfig # Load the model kwargs = { "block_size": 256, "vocab_size": 8096, # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency "n_layer": 2, "n_head": 2, "n_embd": 128, "dropout": 0.0, "bias": False, # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster } config = GPTConfig(**kwargs) with torch.backends.cuda.sdp_kernel( enable_flash=True, enable_mem_efficient=True ): model = GPT(config) print("Done loading model") inputs = torch.arange(128).view(2, 64) targets = torch.arange(128).view(2, 64) self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( model, (inputs,), input_kwargs={ "targets": targets, }, verbose=True, ) ``` [ghstack-poisoned]
Previous to this PR, if None is returned from intermediate nodes, it will crashes the export because None is not expected to be passed into `_fill_tensor_shape_type`, and raise beartype roar. The function fills in shape and type to TorchScriptTensor according to its info from FX graph. This is discovered after microsoft/onnxscript#1043 is supported. The op specifically generates None in one of its inputs, but the only output from it being consumed is the first one (not None). Reference test from a TorchBench model: ```python def test_nanogpt(self): import sys sys.path.append("/home/titaiwang") from nanoGPT.model import GPT, GPTConfig # Load the model kwargs = { "block_size": 256, "vocab_size": 8096, # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency "n_layer": 2, "n_head": 2, "n_embd": 128, "dropout": 0.0, "bias": False, # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster } config = GPTConfig(**kwargs) with torch.backends.cuda.sdp_kernel( enable_flash=True, enable_mem_efficient=True ): model = GPT(config) print("Done loading model") inputs = torch.arange(128).view(2, 64) targets = torch.arange(128).view(2, 64) self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( model, (inputs,), input_kwargs={ "targets": targets, }, verbose=True, ) ``` [ghstack-poisoned]
… inputs" Prior to this PR, if None is returned from intermediate nodes, it will crashes the export because None is not expected to be passed into `_fill_tensor_shape_type`, and raise beartype roar. The function fills in shape and type to TorchScriptTensor according to its info from FX graph. This is discovered after microsoft/onnxscript#1043 is supported. The op specifically generates None in one of its inputs, but the only output from it being consumed is the first one (not None). Reference test from a TorchBench model: ```python def test_nanogpt(self): import sys sys.path.append("/home/titaiwang") from nanoGPT.model import GPT, GPTConfig # Load the model kwargs = { "block_size": 256, "vocab_size": 8096, # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency "n_layer": 2, "n_head": 2, "n_embd": 128, "dropout": 0.0, "bias": False, # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster } config = GPTConfig(**kwargs) with torch.backends.cuda.sdp_kernel( enable_flash=True, enable_mem_efficient=True ): model = GPT(config) print("Done loading model") inputs = torch.arange(128).view(2, 64) targets = torch.arange(128).view(2, 64) self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( model, (inputs,), input_kwargs={ "targets": targets, }, verbose=True, ) ``` [ghstack-poisoned]
Prior to this PR, if None is returned from intermediate nodes, it will crashes the export because None is not expected to be passed into `_fill_tensor_shape_type`, and raise beartype roar. The function fills in shape and type to TorchScriptTensor according to its info from FX graph. This is discovered after microsoft/onnxscript#1043 is supported. The op specifically generates None in one of its inputs, but the only output from it being consumed is the first one (not None). Reference test from a TorchBench model: ```python def test_nanogpt(self): import sys sys.path.append("/home/titaiwang") from nanoGPT.model import GPT, GPTConfig # Load the model kwargs = { "block_size": 256, "vocab_size": 8096, # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency "n_layer": 2, "n_head": 2, "n_embd": 128, "dropout": 0.0, "bias": False, # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster } config = GPTConfig(**kwargs) with torch.backends.cuda.sdp_kernel( enable_flash=True, enable_mem_efficient=True ): model = GPT(config) print("Done loading model") inputs = torch.arange(128).view(2, 64) targets = torch.arange(128).view(2, 64) self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime( model, (inputs,), input_kwargs={ "targets": targets, }, verbose=True, ) ``` [ghstack-poisoned]
Prior to this PR, if None is returned from intermediate nodes, the export crashes because None is not expected to be passed into `_fill_tensor_shape_type`, which raises a beartype error. That function fills in shape and type on a TorchScriptTensor according to its info from the FX graph. This was discovered after microsoft/onnxscript#1043 was supported: the op produces None as one of its outputs, but the only output that is actually consumed is the first one (not None). Reference test from a TorchBench model:

```python
def test_nanogpt(self):
    import sys

    sys.path.append("/home/titaiwang")
    from nanoGPT.model import GPT, GPTConfig

    # Load the model
    kwargs = {
        "block_size": 256,
        # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
        "vocab_size": 8096,
        "n_layer": 2,
        "n_head": 2,
        "n_embd": 128,
        "dropout": 0.0,
        # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
        "bias": False,
    }
    config = GPTConfig(**kwargs)
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_mem_efficient=True
    ):
        model = GPT(config)
    print("Done loading model")

    inputs = torch.arange(128).view(2, 64)
    targets = torch.arange(128).view(2, 64)

    self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
        model,
        (inputs,),
        input_kwargs={
            "targets": targets,
        },
        verbose=True,
    )
```

Pull Request resolved: #108708
Approved by: https://github.com/justinchuby, https://github.com/thiagocrepaldi
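A hedged sketch of the kind of guard the fix describes: skip None entries coming from the FX graph instead of passing them to `_fill_tensor_shape_type`. Only that function name comes from the text above; the wrapper name and call signature are assumptions for illustration.

```python
# Illustrative only (not the actual exporter code): guard against None values
# from the FX graph before calling _fill_tensor_shape_type, whose exact
# signature is assumed here for the sake of the example.
def _fill_outputs_safely(onnxscript_values, fx_name, fx_meta_values):
    if fx_meta_values is None:
        # Placeholder output that nothing downstream consumes; there is no
        # shape/type information to fill, so skip it instead of crashing.
        return
    _fill_tensor_shape_type(onnxscript_values, fx_name, fx_meta_values)
```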
(#1197) Fixes #1160. Follow-up to #1043. It's another `scaled_dot_product_XXX_attention`; it only supports CUDA. https://github.com/pytorch/pytorch/blob/38ae17d166a001ef6837553d1ddffa111624df27/torch/_meta_registrations.py#L5195-L5236 NOTE: This PR also enables CUDA tests.
`_scaled_dot_product_flash_attention` is one of three ATen implementations of `nn.functional.scaled_dot_product_attention`, according to https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html. Which of the three ATen operators represents `nn.functional.scaled_dot_product_attention` in a model is decided by a context manager: https://pytorch.org/docs/stable/backends.html. From the ONNX perspective, they have no difference except the function signature. Only the first result matters for the model prediction; the unrelated outputs are placeholders that follow the reference implementation (see the sketch below).

NOTE: The PyTorch converter should expect None to appear in `_fill_tensor_shape_type`; otherwise, the exporter crashes.
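A hedged, PyTorch-level sketch of that output structure: compute the real attention output, then return it alongside stand-ins for the outputs nothing downstream consumes. The function name, shapes, and placeholder construction are assumptions for illustration, not the torchlib code merged in this PR.

```python
# Hedged sketch (not torchlib code): the flash-attention ATen op returns a tuple,
# but only the first element (the attention output) feeds the rest of the graph;
# logsumexp and the bookkeeping outputs are effectively placeholders.
import torch
import torch.nn.functional as F


def sdpa_first_output_only(query, key, value):
    # What the exported graph ultimately needs:
    attn_output = F.scaled_dot_product_attention(query, key, value)
    # Stand-ins for the remaining outputs of _scaled_dot_product_flash_attention
    # (logsumexp, rng state, etc.); their values are irrelevant to the prediction.
    logsumexp = torch.empty(0, dtype=query.dtype)
    empty_int = torch.empty(0, dtype=torch.int64)
    return attn_output, logsumexp, empty_int


q = k = v = torch.randn(2, 2, 64, 32)  # (batch, heads, seq_len, head_dim)
out, *_ = sdpa_first_output_only(q, k, v)
```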