
Unsupported Scalar Type 5? -- Portable/optimized ops don't consistently support half/bfloat16 #7748


Closed
bluejack opened this issue Jan 17, 2025 · 35 comments
Assignees
swolchok
Labels
module: kernels (Issues related to kernel libraries and utilities, and code under kernels/)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@bluejack

bluejack commented Jan 17, 2025

🐛 Describe the bug

After exporting a model to .pte form and running it through executor_runner, I get:

E 00:00:02.220756 executorch:inputs_portable.cpp:45] Unsupported scalar type 5

I believe this is the "Half" type, i.e. float16.

Does that simply mean executor_runner does not support float16? Or does the whole framework not support float16?

Note that when I investigate the file using a Python script, I get as far as sending it my float16 tensors, but it still fails to execute with a similar error:

[op_native_layer_norm.cpp:169] In function operator()(), assert failed (false): Unhandled dtype Half for native_layer_norm.out

I'm including the versions below, but note that this is using executorch built from head rather than the last release. Should I expect the framework to support float16, and look to my own code for the error?
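
For reference, here is a minimal sketch of the kind of Python-side check mentioned above. It assumes the standard ExecuTorch pybindings are built and importable as executorch.extension.pybindings.portable_lib; the model path and input shape are placeholders, not taken from this report.

import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load the exported program (hypothetical path).
module = _load_for_executorch("model.pte")

# Feeding a float16 tensor reproduces the "Unhandled dtype Half" failure when a
# portable op (here native_layer_norm) lacks half support.
x = torch.randn(1, 16, 128, dtype=torch.float16)
outputs = module.forward([x])
print(outputs)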

Versions

PyTorch version: 2.6.0.dev20250104
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.6.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.4)
CMake version: version 3.31.4
Libc version: N/A

Python version: 3.12.7 | packaged by Anaconda, Inc. | (main, Oct 4 2024, 08:22:19) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Pro

Versions of relevant libraries:
[pip3] executorch==0.6.0a0+cd0e584
[pip3] flake8==7.0.0
[pip3] mypy==1.11.2
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.0.0
[pip3] numpydoc==1.7.0
[pip3] torch==2.6.0.dev20250104
[pip3] torchao==0.8.0+git2e032c6b
[pip3] torchaudio==2.6.0.dev20250104
[pip3] torchsr==1.0.4
[pip3] torchvision==0.22.0.dev20250104
[conda] executorch 0.6.0a0+cd0e584 pypi_0 pypi
[conda] numpy 2.0.0 pypi_0 pypi
[conda] numpydoc 1.7.0 py312hca03da5_0
[conda] torch 2.6.0.dev20250104 pypi_0 pypi
[conda] torchao 0.8.0+git2e032c6b pypi_0 pypi
[conda] torchaudio 2.6.0.dev20250104 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.22.0.dev20250104 pypi_0 pypi

cc @larryliu0820 @manuelcandales

swolchok added a commit that referenced this issue Jan 17, 2025
Partial fix for #7748.

ghstack-source-id: 0c7e0a5
ghstack-comment-id: 2599375147
Pull Request resolved: #7750
@swolchok
Contributor

Does that simply mean executor_runner does not support float16?

It looks like this particular function does not support float16. I've just sent #7750 to fix it.

Or does the whole framework not support float16?

We are capable of supporting it, but it looks like portable ops coverage is spotty. I'll send a fix for native_layer_norm and as many other places as I can find.

swolchok added a commit that referenced this issue Jan 18, 2025
Partial fix for #7748.

ghstack-source-id: 9f183dd
ghstack-comment-id: 2599398274
Pull Request resolved: #7758
@swolchok swolchok self-assigned this Jan 18, 2025
@swolchok swolchok changed the title Unsupported Scalar Type 5? Unsupported Scalar Type 5? -- Portable/optimized ops don't consistently support half/bfloat16 Jan 18, 2025
swolchok added a commit that referenced this issue Jan 18, 2025
Partial fix for #7748.

ghstack-source-id: a72e5e3
ghstack-comment-id: 2599413770
Pull Request resolved: #7760
@swolchok
Contributor

swolchok commented Jan 18, 2025

By the way, if you're running on your Mac, you might want to enable the XNNPACK delegate when exporting; there's a good chance you will get both better performance and a workaround for the remaining instance of this issue I haven't got PRs out for yet (though I don't know whether XNNPACK has layer norm off the top of my head).

swolchok added a commit that referenced this issue Jan 18, 2025
Partial fix for #7748.

ghstack-source-id: 02bfc58
ghstack-comment-id: 2599413770
Pull Request resolved: #7760
swolchok added a commit that referenced this issue Jan 18, 2025
Partial fix for #7748.

ghstack-source-id: b7b3380
ghstack-comment-id: 2599481711
Pull Request resolved: #7767
swolchok added a commit that referenced this issue Jan 18, 2025
Partial fix for #7748.

ghstack-source-id: 02a1dc7
ghstack-comment-id: 2599483099
Pull Request resolved: #7769
@bluejack
Author

By the way, if you're running on your Mac, you might want to enable the XNNPACK delegate when exporting; there's a good chance you will get both better performance and a workaround for the remaining instance of this issue I haven't got PRs out for yet (though I don't know whether XNNPACK has layer norm off the top of my head).

Ok, I will look at this option, thanks for the tip.

@bluejack
Author

To enable the XNNPACK delegate on export, is it anything more than this?

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower

edge_program = to_edge_transform_and_lower(
    exported_text,
    partitioner=[XnnpackPartitioner()],
    compile_config=EdgeCompileConfig(
        _check_ir_validity=True,
        _skip_dim_order=True,
    ),
)

I replaced my basic to_edge call with this one, which I pulled from the xnnpack examples... but it does not seem to make any difference. I'm not sure whether that means I'm not actually doing the delegation properly, or whether it genuinely doesn't make a difference.
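
One way to check whether the partitioner actually claimed anything (a minimal sketch; the executorch_call_delegate node-name check is an assumption about how lowered graphs are printed, not something confirmed in this thread) is to inspect the lowered graph for delegate calls:

# If this prints 0, the XnnpackPartitioner did not delegate any part of the graph.
lowered = edge_program.exported_program().graph_module
delegate_calls = [
    node
    for node in lowered.graph.nodes
    if node.op == "call_function" and "executorch_call_delegate" in str(node.target)
]
print(f"delegate calls: {len(delegate_calls)}")
print(lowered.graph)  # full graph dump for closer inspection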

@kimishpatel
Contributor

I think you will probably want to apply a recipe similar to the llama stuff here.
For quantization it first has to do this: https://github.com/pytorch/executorch/blob/main/examples/models/llama/export_llama_lib.py#L1041 (only the 4-bit quant option, 8da4w, works right now).
And then do the XNNPACK "lowering" following code similar to https://github.com/pytorch/executorch/blob/main/examples/models/llama/export_llama_lib.py#L685.

So in the above there are really two steps (a rough sketch of both follows below):

  1. 4-bit quantization, which requires some code from https://github.com/pytorch/executorch/blob/main/examples/models/llama/source_transformation/quantize.py

  2. "Lowering", which identifies appropriate portions of the graph and leverages XNNPACK to execute them. That is the second link.

If you run into issues, which I expect you may, please post here, and if you can paste the appropriate graph snippets/model from each stage, that would help. @mcr229 and @digantdesai know a ton about this.
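
For concreteness, here is a rough sketch of those two steps. The torchao quantize_/int8_dynamic_activation_int4_weight calls, the group size, and the placeholder model/example_inputs are assumptions for illustration, not lifted from the llama scripts; the linked export code does the same thing with more configuration.

import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

# Step 1: 8-bit dynamic-activation / 4-bit weight (8da4w) quantization on the eager model.
quantize_(model, int8_dynamic_activation_int4_weight(group_size=128))

# Step 2: export, then lower, letting XNNPACK claim the parts of the graph it supports.
exported = torch.export.export(model, example_inputs)
edge_program = to_edge_transform_and_lower(
    exported,
    partitioner=[XnnpackPartitioner()],
)
executorch_program = edge_program.to_executorch()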

@kimishpatel
Contributor

Also note that the code pointers above have largely been validated with llama3+ models, so if your model is similar, you should be able to use those utils.

Another thing worth mentioning: for language models, 4-bit quantization works best for now and is well supported, as linked above. But if that's not the case for your model and some variant of 8-bit quantization works better, then I'm going to ask @mcr229 to point you to some code snippets to enable quantization and lowering for you.

@bluejack
Author

Thanks for these tips. Ours is a vision model, so eventual quality might call for more than 4 bits, but at the moment we are just trying to get a proof of concept going. I'll dig into these resources, thanks.

swolchok added a commit that referenced this issue Jan 21, 2025
Partial fix for #7748.

ghstack-source-id: c7d2a59
ghstack-comment-id: 2605368953
Pull Request resolved: #7791
swolchok added a commit that referenced this issue Jan 21, 2025
Partial fix for #7748.

ghstack-source-id: e25fec3
ghstack-comment-id: 2605391184
Pull Request resolved: #7792
@swolchok
Contributor

For anyone else who wants to jump in on op support, here is how I'm identifying ops to look at:

$ cd kernels/portable/cpu
$ rg --files-without-match 'Half|HALF|ALL|HBF16|HBBF16|ufunc' -g '*.cpp' -g !test | sort

Still have to spot-check, but this gives an initial list.

swolchok added a commit that referenced this issue Jan 21, 2025
Partial fix for #7748.

ghstack-source-id: 9c6f758
ghstack-comment-id: 2605521459
Pull Request resolved: #7794
zonglinpeng pushed a commit to zonglinpeng/executorch that referenced this issue Jan 30, 2025
* Coerce to true_ctype in tensor_factory (pytorch#7856)

This should fix the problem where attempts to test bool are often wonky in OSS and fail UBSAN internally; it is undefined behavior to store a value other than 0 or 1 for type bool.

* Support Half/BFloat16 in prod operator (pytorch#7857)

Partial fix for pytorch#7748.
zonglinpeng pushed additional commits to zonglinpeng/executorch that referenced this issue Jan 30, 2025
@digantdesai digantdesai added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Feb 4, 2025
@swolchok
Contributor

swolchok commented Feb 7, 2025

Closing because most things should support half/bfloat16 now. (Norm ops are the exception per #7846; hoping to get to code sharing with PyTorch and then solve accuracy issues that way.)

@swolchok swolchok closed this as completed Feb 7, 2025