Commit ad8b42f

Merge branch 'main' into jcaip/sam
2 parents 5b3c08f + 5d22ad2

44 files changed (+1368, -1007 lines)

.github/scripts/trymerge.py

Lines changed: 30 additions & 52 deletions
@@ -1163,7 +1163,6 @@ def merge_into(
             # Finally, upload the record to Rockset. The list of pending and failed
             # checks are at the time of the merge
             save_merge_record(
-                collection=ROCKSET_MERGES_COLLECTION,
                 comment_id=comment_id,
                 pr_num=self.pr_num,
                 owner=self.org,
@@ -1179,10 +1178,8 @@ def merge_into(
                 merge_base_sha=self.get_merge_base(),
                 merge_commit_sha=merge_commit_sha,
                 is_failed=False,
-                dry_run=dry_run,
                 skip_mandatory_checks=skip_mandatory_checks,
                 ignore_current=bool(ignore_current_checks),
-                workspace=ROCKSET_MERGES_WORKSPACE,
             )
         else:
             print("Missing comment ID or PR number, couldn't upload to Rockset")
@@ -1489,7 +1486,6 @@ def checks_to_markdown_bullets(

 @retries_decorator()
 def save_merge_record(
-    collection: str,
     comment_id: int,
     pr_num: int,
     owner: str,
@@ -1505,59 +1501,44 @@ def save_merge_record(
     merge_base_sha: str,
     merge_commit_sha: str = "",
     is_failed: bool = False,
-    dry_run: bool = False,
     skip_mandatory_checks: bool = False,
     ignore_current: bool = False,
     error: str = "",
-    workspace: str = "commons",
 ) -> None:
     """
-    This saves the merge records into Rockset, so we can query them (for fun and profit)
+    This saves the merge records as a json, which can later be uploaded to s3
     """
-    if dry_run:
-        # Decide not to save the record to Rockset if dry-run is set to not pollute
-        # the collection
-        return
-
-    try:
-        import rockset  # type: ignore[import]
-
-        # Prepare the record to be written into Rockset
-        data = [
-            {
-                "comment_id": comment_id,
-                "pr_num": pr_num,
-                "owner": owner,
-                "project": project,
-                "author": author,
-                "pending_checks": pending_checks,
-                "failed_checks": failed_checks,
-                "ignore_current_checks": ignore_current_checks,
-                "broken_trunk_checks": broken_trunk_checks,
-                "flaky_checks": flaky_checks,
-                "unstable_checks": unstable_checks,
-                "last_commit_sha": last_commit_sha,
-                "merge_base_sha": merge_base_sha,
-                "merge_commit_sha": merge_commit_sha,
-                "is_failed": is_failed,
-                "skip_mandatory_checks": skip_mandatory_checks,
-                "ignore_current": ignore_current,
-                "error": error,
-            }
-        ]

-        client = rockset.RocksetClient(
-            host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
-        )
-        client.Documents.add_documents(
-            collection=collection,
-            data=data,
-            workspace=workspace,
-        )
+    # Prepare the record to be written into Rockset
+    data = [
+        {
+            "comment_id": comment_id,
+            "pr_num": pr_num,
+            "owner": owner,
+            "project": project,
+            "author": author,
+            "pending_checks": pending_checks,
+            "failed_checks": failed_checks,
+            "ignore_current_checks": ignore_current_checks,
+            "broken_trunk_checks": broken_trunk_checks,
+            "flaky_checks": flaky_checks,
+            "unstable_checks": unstable_checks,
+            "last_commit_sha": last_commit_sha,
+            "merge_base_sha": merge_base_sha,
+            "merge_commit_sha": merge_commit_sha,
+            "is_failed": is_failed,
+            "skip_mandatory_checks": skip_mandatory_checks,
+            "ignore_current": ignore_current,
+            "error": error,
+            # This is a unique identifier for the record for deduping purposes
+            # in rockset. Any unique string would work
+            "_id": f"{project}-{pr_num}-{comment_id}-{os.environ.get('GITHUB_RUN_ID')}",
+        }
+    ]
+    repo_root = Path(__file__).resolve().parent.parent.parent

-    except ModuleNotFoundError:
-        print("Rockset is missing, no record will be saved")
-        return
+    with open(repo_root / "merge_record.json", "w") as f:
+        json.dump(data, f)


 @retries_decorator(rc=[])
@@ -2374,7 +2355,6 @@ def handle_exception(e: Exception, title: str = "Merge failed") -> None:
         # list of pending and failed checks here, but they are not really
         # needed at the moment
         save_merge_record(
-            collection=ROCKSET_MERGES_COLLECTION,
            comment_id=args.comment_id,
            pr_num=args.pr_num,
            owner=org,
@@ -2389,11 +2369,9 @@ def handle_exception(e: Exception, title: str = "Merge failed") -> None:
            last_commit_sha=pr.last_commit().get("oid", ""),
            merge_base_sha=pr.get_merge_base(),
            is_failed=True,
-            dry_run=args.dry_run,
            skip_mandatory_checks=args.force,
            ignore_current=args.ignore_current,
            error=str(e),
-            workspace=ROCKSET_MERGES_WORKSPACE,
        )
    else:
        print("Missing comment ID or PR number, couldn't upload to Rockset")

.github/scripts/validate_binaries.sh

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+pip install ${PYTORCH_PIP_PREFIX} torchao --index-url ${PYTORCH_PIP_DOWNLOAD_URL}
+python ./test/smoke_tests/smoke_tests.py

.github/workflows/trymerge.yml

Lines changed: 19 additions & 0 deletions
@@ -9,6 +9,8 @@ jobs:
     name: try_merge_pr_${{ github.event.client_payload.pr_num }}
     runs-on: ubuntu-latest
     environment: pytorchbot-env
+    permissions:
+      id-token: write
     env:
       GH_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
     steps:
@@ -45,6 +47,7 @@ jobs:
          IGNORE_CURRENT: ${{ github.event.client_payload.ignore_current }}
          ROCKSET_API_KEY: ${{ secrets.ROCKSET_API_KEY }}
          DRCI_BOT_KEY: ${{ secrets.DRCI_BOT_KEY }}
+          GITHUB_RUN_ID: ${{ github.run_id }}
        run: |
          set -x
          if [ -n "${FORCE}" ]; then
@@ -65,6 +68,22 @@ jobs:
            python3 .github/scripts/trymerge.py "${PR_NUM}"
          fi

+      - name: configure aws credentials
+        uses: aws-actions/configure-aws-credentials@v3
+        continue-on-error: true
+        with:
+          role-to-assume: arn:aws:iam::308535385114:role/upload_to_ossci_raw_job_status
+          aws-region: us-east-1
+
+      - name: Upload merge record to s3
+        if: always()
+        continue-on-error: true
+        uses: seemethere/upload-artifact-s3@v5
+        with:
+          s3-bucket: ossci-raw-job-status
+          s3-prefix: merges/${{ github.repository }}/${{ github.event.client_payload.pr_num }}/${{ github.event.client_payload.comment_id }}/${{ github.run_id }}
+          path: merge_record.json
+
 # We want newer merge commands to supercede old ones
 concurrency:
   group: try-merge-${{ github.event.client_payload.pr_num }}
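The workflow uploads the record to the ossci-raw-job-status bucket under the merges/ prefix shown above. A hedged sketch of fetching it back, assuming boto3 and read access to the bucket; the object key layout (prefix plus file name) and the numeric values below are illustrative assumptions, not something this diff guarantees:

```python
import json

import boto3  # assumed to be available; not part of this PR

s3 = boto3.client("s3")
# Key layout is an assumption: the s3-prefix from the workflow step plus the uploaded file name.
# Repository, PR number, comment id, and run id here are placeholder values.
key = "merges/pytorch/ao/123/456/7890123456/merge_record.json"
obj = s3.get_object(Bucket="ossci-raw-job-status", Key=key)
record = json.loads(obj["Body"].read())[0]
print(record["pr_num"], record["is_failed"], record["error"])
```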
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+name: Validate binaries
+
+on:
+  workflow_call:
+    inputs:
+      channel:
+        description: "Channel to use (nightly, test, release, all)"
+        required: false
+        type: string
+        default: release
+      ref:
+        description: "Reference to checkout, defaults to empty"
+        default: ""
+        required: false
+        type: string
+  workflow_dispatch:
+    inputs:
+      channel:
+        description: "Channel to use (nightly, test, release, all)"
+        required: true
+        type: choice
+        options:
+          - release
+          - nightly
+          - test
+          - all
+      ref:
+        description: "Reference to checkout, defaults to empty"
+        default: ""
+        required: false
+        type: string
+      pytorch_version:
+        description: "PyTorch version to validate (ie. 2.0, 2.2.2, etc.) - optional"
+        default: ""
+        required: false
+        type: string
+jobs:
+  validate-binaries:
+    uses: pytorch/test-infra/.github/workflows/validate-domain-library.yml@main
+    with:
+      package_type: "wheel"
+      version: ${{ inputs.version }}
+      os: "linux"
+      channel: ${{ inputs.channel }}
+      repository: "pytorch/ao"
+      with_cuda: "enable"
+      with_rocm: "disable"
+      smoke_test: "source ./.github/scripts/validate_binaries.sh"
+      install_torch: true

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -127,6 +127,7 @@ env
 .circleci/scripts/COMMIT_MSG
 scripts/release_notes/*.json
 sccache-stats*.json
+merge_record.json

 # These files get copied over on invoking setup.py
 torchgen/packaged/*

README.md

Lines changed: 45 additions & 13 deletions
@@ -29,15 +29,17 @@ The models used were `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Meta-Llama-

 | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
 | ----------- | ------------------ | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
-| Llama-2-7B | Base (bfloat16) | 12.212 | 105.02 | 1387.78 | 13.21 | 13.90 |
-| | int8dq | 12.262 | 9.40 | 62.26 | 6.62 | 8.61 |
-| | int8wo | 12.204 | 147.03 | 973.54 | 6.62 | 8.95 |
-| | int4wo-64 | 12.843 | 199.81 | 746.45 | 3.74 | 4.75 |
-| | int4wo-64-GPTQ | 12.489 | 199.81 | 746.45 | 3.74 | 4.75 |
-| Llama-3-8B | Base (bfloat16) | | 94.91 | 1424.58 | 15.01 | 16.43 |
-| | int8dq | | 8.41 | 63.23 | 7.52 | 9.24 |
-| | int8wo | | 136.75 | 1028.38 | 7.52 | 10.42 |
-| | int4wo-64 | | 179.41 | 757.45 | 4.22 | 6.88 |
+| Llama-2-7B | Base (bfloat16) | 12.212 | 105.14 | 1389.35 | 13.88 | 13.21 |
+| | int8dq | 12.262 | 9.20 | 60.93 | 8.33 | 6.62 |
+| | int8wo | 12.204 | 150.18 | 994.40 | 8.95 | 6.62 |
+| | int4wo-64 | 12.843 | 199.86 | 746.66 | 4.50 | 3.74 |
+| | int4wo-64-GPTQ | 12.489 | 199.86 | 746.66 | 4.50 | 3.74 |
+| | autoquant | 12.204 | 159.22 | 1069.87 | 8.91 | 6.72 |
+| Llama-3-8B | Base (bfloat16) | N/A | 94.97 | 1425.55 | 16.43 | 15.01 |
+| | int8dq | N/A | 8.44 | 63.45 | 8.98 | 7.52 |
+| | int8wo | N/A | 139.76 | 1051.02 | 10.42 | 7.52 |
+| | int4wo-64 | N/A | 179.44 | 757.60 | 6.62 | 4.22 |
+| | autoquant | N/A | 137.71 | 1037.74 | 11.08 | 7.54 |

 note: Int8 dynamic quantization works best on compute bound as opposed to memory bound models. Some relatable examples might be [SAM](https://github.com/pytorch-labs/segment-anything-fast) which is compute bound vs Llama at batchsize=1 which is memory bound.

@@ -81,7 +83,7 @@ swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})

 * [MX](torchao/prototype/mx_formats) implementing training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
 * [nf4](torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) one of the most popular finetuning algorithms without writing custom Triton or CUDA code. Accessible talk [here](https://x.com/HamelHusain/status/1800315287574847701)
-* [fp6](torchao/prototype/fp6_llm/) for 2x faster inference over fp16 with an easy to use wrapper api `convert_fp6_llm(model)`
+* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize(model, fp6_llm_weight_only())`

 ## Composability

@@ -92,11 +94,34 @@ A key design principle for us is composability as in any new dtype or layout we


 ### Installation
+
 `torchao` makes liberal use of several new features in Pytorch, it's recommended to use it with the current nightly or latest stable version of PyTorch.

-Stable Release
+#### Install torch
+
+Install torch stable
+
+```
+pip install torch
+```
+
+Or torch nightlies
+
+```
+pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
+```
+
+#### Install torchao
+
+Stable release from Pypi which will default to CUDA 12.1
+
 ```Shell
-pip install torchao --extra-index-url https://download.pytorch.org/whl/test/cu121 # full options are cpu/cu118/cu121/cu124
+pip install torchao
+```
+
+Stable Release from the PyTorch index
+```Shell
+pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
 ```

 Nightly Release
@@ -117,10 +142,17 @@ python setup.py install
 * [GaLore](torchao/prototype/galore/) a drop for the Adam Optimizer that allows you to finetune llama 7b on a single 4090 card with up to 70% speedups relative to eager PyTorch
 * [DoRA](torchao/prototype/dora) a newer replacement for QLoRA with more promising convergence characteristics
 * [Fused int4/fp16 Quant Matmul](torchao/prototype/hqq) which is particularly useful for compute bound kernels showing 4x speedups over tinygemm for larger batch sizes such as 512
-* [gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/fp6_llm](torchao/prototype/fp6_llm)
+* [gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/quant_llm](torchao/prototype/quant_llm)
 * [vayuda](https://github.com/vayuda) with generic bitpacking kernels that were code generated using pure PyTorch [prototype/common](torchao/prototype/common)
 * [andreaskopf](https://github.com/andreaskoepf) and [melvinebenezer](https://github.com/melvinebenezer) with [1 bit LLMs](torchao/prototype/dtypes) Bitnet 1.58 bitpacked into uint2 and fully code-generated with torch.compile

+## Blogs and Videos
+* [Accelerating Neural Network Training with Semi-Structured (2:4) Sparsity](https://pytorch.org/blog/accelerating-neural-network-training/)
+* [https://mobiusml.github.io/whisper-static-cache-blog/](https://mobiusml.github.io/whisper-static-cache-blog/)
+* [Slaying OOMs at the Mastering LLM's course](https://x.com/HamelHusain/status/1800315287574847701)
+* [Advanced Quantization at CUDA MODE](https://youtu.be/1u9xUK3G4VM?si=4JcPlw2w8chPXW8J)
+* [Chip Huyen's GPU Optimization Workshop](https://www.youtube.com/live/v_q2JTIqE20?si=mf7HeZ63rS-uYpS6)
+
 ## How to contribute

 This repository is currently under heavy development
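The fp6 bullet in the README diff above now points at torchao/prototype/quant_llm and the `quantize(model, fp6_llm_weight_only())` entry point. A short usage sketch; the import paths are assumptions, since the diff only shows the call itself:

```python
import torch
from torch import nn

# Import locations are assumptions; the README diff above only shows the call.
from torchao.quantization import quantize  # assumed entry point
from torchao.prototype.quant_llm import fp6_llm_weight_only  # assumed location

model = nn.Sequential(nn.Linear(1024, 1024)).half().cuda()
quantize(model, fp6_llm_weight_only())  # convert linear weights to the fp6 format

x = torch.randn(8, 1024, dtype=torch.half, device="cuda")
with torch.no_grad():
    y = model(x)
```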

benchmarks/benchmark_aq.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
     Int8WeightOnlyQuantizedLinearWeight,
     Int4WeightOnlyQuantizedLinearWeight,
 )
-from torchao.quantization.utils import (
+from torchao.utils import (
     TORCH_VERSION_AFTER_2_4,
 )
 from torchao.quantization.quant_api import (

benchmarks/benchmark_fp6_llm.py

Lines changed: 14 additions & 15 deletions
@@ -1,25 +1,24 @@
 import torch
-from torch import nn
-from torchao.prototype.fp6_llm.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
-from torch.utils.benchmark import Timer
 import pandas as pd
+import torch.nn.functional as F
+from torchao.prototype.quant_llm import QuantLlmLinearWeight
+from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm


 def benchmark(m: int, k: int, n: int):
-    fp6_weight = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
-    scales = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
-    fp6_linear = Fp6LlmLinear(fp6_weight, scales)
+    fp6_data = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
+    scale = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
+    fp6_weight = QuantLlmLinearWeight(fp6_data, scale, 3, 2)

-    fp16_linear = nn.Linear(k, n, bias=True, dtype=torch.half, device="cuda")
-    fp16_linear.weight.data = from_tc_float6_e3m2(fp6_weight, dtype=torch.half) * scales[:, None]
+    fp16_weight = fp6_weight.dequantize(torch.half)

     fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
-    fp6_output = fp6_linear(fp16_act)
-    fp16_output = fp16_linear(fp16_act)
+    fp6_output = F.linear(fp16_act, fp6_weight)
+    fp16_output = F.linear(fp16_act, fp16_weight)

-    fp6_measurement = Timer(stmt="fp6_linear(fp16_act)", globals=locals()).blocked_autorange()
-    fp16_measurement = Timer(stmt="fp16_linear(fp16_act)", globals=locals()).blocked_autorange()
+    fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight)
+    fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight)

     # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
     # doesn't seem to be the right way to check for correctness
@@ -29,9 +28,9 @@ def benchmark(m: int, k: int, n: int):
         "m": m,
         "k": k,
         "n": n,
-        "fp6_latency (ms)": fp6_measurement.median * 1000,
-        "fp16_latency (ms)": fp16_measurement.median * 1000,
-        "speedup (d/s)": fp16_measurement.median / fp6_measurement.median,
+        "fp6_latency (ms)": fp6_time,
+        "fp16_latency (ms)": fp16_time,
+        "speedup (d/s)": fp16_time / fp6_time,
         "correct": correct,
     }
benchmarks/intmm.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 import pathlib

 import torch
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4, TORCH_VERSION_AFTER_2_2
+from torchao.utils import TORCH_VERSION_AFTER_2_4, TORCH_VERSION_AFTER_2_2


 # Check if CUDA is available, if not, exit the script
