You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
torchao is a library to create and integrate high-performance custom data types, layouts and kernels into their PyTorch workflows with up to **2x speedups** with **65%** less VRAM for [inference](#inference) and support for [training](#training)
9
+
torchao is a library to create and integrate high-performance custom data types, layouts and kernels into your PyTorch workflows with up to **2x speedups** with **65% less VRAM** for [inference](#inference) and support for [training](#training)
9
10
10
11
All with no intrusive code changes and minimal accuracy degradation.
11
12
@@ -15,7 +16,7 @@ All with no intrusive code changes and minimal accuracy degradation.
15
16
16
17
#### Without intrusive code changes
17
18
18
-
Quantizing your models is a 1 liner that should work on any model with `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/) and a hugginface inference example [here](scripts/hf_eval.py)
19
+
Quantizing your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/) and a HuggingFace inference example [here](scripts/hf_eval.py)
19
20
20
21
```python
21
22
from torchao.quantization.quant_api import quantize
@@ -59,12 +60,10 @@ We've added support for semi-structured 2:4 sparsity with 6% end to end speedups
59
60
60
61
The code change is a 1 liner with the full example available [here](torchao/sparsity/training/)
*[MX](torchao/prototype/mx_formats) implementing training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
A key design principle for us is composability as in any new dtype or layout we provide needs to work with `torch.compile()` and needs to work with `FSDP`. It shouldn't matter if the kernels are written are pure PyTorch, CUDA, C++, or Triton - things should just work! And here is our current strategy
75
+
A key design principle for us is composability as in any new dtype or layout we provide needs to work with `torch.compile()` and needs to work with `FSDP`. It shouldn't matter if the kernels are written in pure PyTorch, CUDA, C++, or Triton - things should just work! And here is our current strategy
77
76
1. Write the dtype, layout or bit packing logic in pure PyTorch and code-generate efficient kernels with torch.compile. You can inspect those kernels with `TORCH_LOGS="output_code" python your_code.py` and check if a single kernel is being generated and if any unnecessary buffers are being created
78
-
2. However once you get a kernel, how do you know how good it is? The best way is to benchmark the code-generated code with the best kernel on the market. But packaging custom CPP/CUDA kernels that work on multiple devices is tedious but we've abstracted all the tedium from you with our [custom ops support](./torchao/csrc/) so if you love writing kernels but hate packaging, we'd love to accept contributions for your custom ops. One key benefit is a kernel written as a custom op will just work with no graph breaks with `torch.compile()`. Compilers are great at optimizations like fusions and overhead reduction but it's challenging for compilers to rewrite the math of an algorithm such that it's faster but also numerically stable so we are betting on both compilers and custom ops
79
-
3. Finally while historically most quantization has been done for inference there is now a thriving area of research combining lower dtypes and sharding. One popular example is [NF4](torchao/dtypes/nf4tensor.py) which is used to create the QLoRA algorithm and you can define the semantics for how custom tensors should be sharded over multiple devices. We gave an accessible talk on [how to do this](https://x.com/HamelHusain/status/1800315287574847701).
77
+
2. However once you get a kernel, how do you know how good it is? The best way is to benchmark the compiler generated code with the best kernel on the market. But packaging custom CPP/CUDA kernels that work on multiple devices is tedious but we've abstracted all the tedium from you with our [custom ops support](./torchao/csrc/) so if you love writing kernels but hate packaging, we'd love to accept contributions for your custom ops. One key benefit is a kernel written as a custom op will just work with no graph breaks with `torch.compile()`. Compilers are great at optimizations like fusions and overhead reduction but it's challenging for compilers to rewrite the math of an algorithm such that it's faster but also numerically stable so we are betting on both compilers and custom ops
78
+
3. Finally while historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization. One popular example is [NF4](torchao/dtypes/nf4tensor.py) which was used to implement the QLoRA algorithm. The NF4 tensor also contains semantics for how it should be sharded over multiple devices so it composes with FSDP. We gave an accessible talk on [how to do this](https://x.com/HamelHusain/status/1800315287574847701).
80
79
81
-
## Get Started
82
80
83
81
### Installation
84
82
`torchao` makes liberal use of several new features in Pytorch, it's recommended to use it with the current nightly or latest stable version of PyTorch.
@@ -93,6 +91,13 @@ Nightly Release
93
91
pip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
*[Fused int4/fp16 Quant Matmul](torchao/prototype/hqq) which is particularly useful for compute bound kernels showing 4x speedups over tinygemm for larger batch sizes such as 512
102
107
*[gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/fp6_llm](torchao/prototype/fp6_llm)
103
108
*[vayuda](https://github.com/vayuda) with generic bitpacking kernels that were code generated using pure PyTorch [prototype/common](torchao/prototype/common)
104
-
*[andreaskopf](https://github.com/andreaskoepf) and [melvinebenezer](https://github.com/melvinebenezer) with [1 bit LLMs](torchao/prototype/dtypes) Bitnet 1.58 bitpacked into uin2 and fully code-generated with torch.compile
109
+
*[andreaskopf](https://github.com/andreaskoepf) and [melvinebenezer](https://github.com/melvinebenezer) with [1 bit LLMs](torchao/prototype/dtypes) Bitnet 1.58 bitpacked into uint2 and fully code-generated with torch.compile
105
110
106
111
## How to contribute
107
112
108
113
This repository is currently under heavy development
109
114
* If you have suggestions on the API or use cases you'd like to be covered, please open an [issue](https://github.com/pytorch/ao/issues)
110
115
* If you'd like to co-develop the library with us please join us on #torchao on [discord.gg/cudamode](https://discord.gg/cudamode) - there are a lot of dtypes out there and we could use a lot more hands to make them go brrr
0 commit comments