Skip to content

Commit d267622

Browse files
authored
Add torchchat quantizer
Differential Revision: D62394341 Pull Request resolved: #897
1 parent b521c9b commit d267622

File tree

8 files changed

+432
-351
lines changed

8 files changed

+432
-351
lines changed

torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
add_library(
88
kernel_aarch64
9-
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
10-
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
11-
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
12-
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
9+
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
10+
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
11+
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
12+
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
1313
)

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ set(CMAKE_BUILD_TYPE Release)
1313
add_compile_options("-Wall" "-Werror")
1414

1515
include(CMakePrintHelpers)
16-
message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}")
17-
include_directories(${TORCHAO_LIBRARIES})
16+
message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
17+
include_directories(${TORCHAO_INCLUDE_DIRS})
1818

19-
add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)
19+
add_subdirectory(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)
2020

21-
include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake)
21+
include(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/Utils.cmake)
2222

2323
set(PLATFORM "ATEN" CACHE STRING "Choose platform surface: ATEN, EXECUTORCH")
2424
string(TOUPPER ${PLATFORM} PLATFORM_TO_UPPER)

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
# LICENSE file in the root directory of this source tree.
77

88
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
9-
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../../..
9+
export TORCHAO_INCLUDE_DIRS=${SCRIPT_DIR}/../../../../../../..
1010

1111
export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')"
1212
echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
13-
export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples/torch_custom_op
14-
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
13+
export CMAKE_OUT=/tmp/cmake-out/torchao
14+
cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \
1515
-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
1616
-DPLATFORM="ATEN" \
17-
-S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
17+
-S ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
1818
-B ${CMAKE_OUT}
1919
cmake --build ${CMAKE_OUT}

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py

Lines changed: 20 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,21 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import copy
8+
import glob
9+
import os
10+
11+
import sys
812

913
import torch
10-
from torch_custom_op import (
11-
linear_a8sz_w_lowbit_reference_impl,
12-
replace_linear_with_quantized_linear,
14+
15+
sys.path.insert(
16+
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
1317
)
18+
from quant_api import Int8DynActIntxWeightQuantizer
19+
20+
libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
21+
libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
22+
torch.ops.load_library(libs[0])
1423

1524
group_size = 256
1625
m = 1
@@ -27,15 +36,15 @@
2736

2837
print("Quantizing random model")
2938
quantized_model = copy.deepcopy(model)
30-
quantized_model = quantized_model.eval()
31-
replace_linear_with_quantized_linear(
32-
quantized_model,
33-
kwargs={
34-
"group_size": group_size,
35-
"nbit": nbit,
36-
"has_weight_zeros": has_weight_zeros,
37-
},
39+
quantizer = Int8DynActIntxWeightQuantizer(
40+
device="cpu",
41+
precision=torch.float32,
42+
bitwidth=nbit,
43+
groupsize=group_size,
44+
has_weight_zeros=has_weight_zeros,
3845
)
46+
quantized_model = quantizer.quantize(quantized_model)
47+
quantized_model = quantized_model.eval()
3948

4049
print("Creating random activations")
4150
activations = torch.randn(m, k, dtype=torch.float32)
@@ -58,44 +67,3 @@
5867
print("Running AOTI")
5968
fn = torch._export.aot_load("/tmp/torch_custom_op_example_model.so", "cpu")
6069
fn(activations)
61-
62-
63-
print("\nChecking correctness on layer 0")
64-
linear = model[0]
65-
quantized_linear = quantized_model[0]
66-
67-
with torch.no_grad():
68-
result = quantized_linear(activations)
69-
expected_result = linear_a8sz_w_lowbit_reference_impl(
70-
linear.weight, activations, group_size, nbit, has_weight_zeros
71-
)
72-
non_quantized_result = linear(activations)
73-
74-
75-
# Check that entries in result match entries in expected_result
76-
num_mismatch_at_low_tol = 0
77-
num_total = result.reshape(-1).shape[0]
78-
for i in range(num_total):
79-
actual_val = result.reshape(-1)[i]
80-
expected_val = expected_result.reshape(-1)[i]
81-
if not torch.allclose(actual_val, expected_val):
82-
num_mismatch_at_low_tol += 1
83-
84-
# If results are not close at a relaxed tolerance, exit with failure
85-
if not torch.allclose(actual_val, expected_val, atol=1e-6):
86-
assert False, "Correctness check failed"
87-
88-
# Assert at most 5% of entries are not close at a low tolerance
89-
assert num_mismatch_at_low_tol / num_total <= 0.05, "Correctness check failed"
90-
print(
91-
"Correctness check passed. All results are close, and ",
92-
(num_total - num_mismatch_at_low_tol),
93-
"/",
94-
num_total,
95-
" entries are close at a low tolerance.",
96-
)
97-
print("Quantization errors:")
98-
print("\tL1 error: ", torch.mean(torch.abs(result - non_quantized_result)).item())
99-
print("\tL2 error: ", torch.mean((result - non_quantized_result) ** 2).item())
100-
print("\tquantized_result[0:5]: ", result[0][0:5])
101-
print("\tnon_quantized_result[0:5]: ", non_quantized_result[0][0:5])

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py

Lines changed: 0 additions & 56 deletions
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import copy
8+
9+
import glob
10+
import os
11+
12+
import sys
13+
import unittest
14+
15+
import torch
16+
17+
sys.path.insert(
18+
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
19+
)
20+
from quant_api import (
21+
_Int8DynActIntxWeightQuantizedLinearFallback,
22+
Int8DynActIntxWeightQuantizer,
23+
)
24+
25+
libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
26+
libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
27+
if len(libs) == 0:
28+
print(
29+
"Could not find library lowbit_op_aten; please run `sh build_custom_op.sh` to build the library. A slow fallback kernel will be used instead."
30+
)
31+
else:
32+
torch.ops.load_library(libs[0])
33+
34+
35+
class TestInt8DynActIntxWeightQuantizer(unittest.TestCase):
36+
def test_accuracy(self):
37+
group_size = 128
38+
m = 1
39+
n = 1071
40+
k = 4096
41+
activations = torch.randn(m, k, dtype=torch.float32)
42+
model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
43+
44+
for nbit in [1, 2, 3, 4, 5, 6, 7]:
45+
for has_weight_zeros in [True, False]:
46+
print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
47+
quantized_model = copy.deepcopy(model)
48+
quantizer = Int8DynActIntxWeightQuantizer(
49+
device="cpu",
50+
precision=torch.float32,
51+
bitwidth=nbit,
52+
groupsize=group_size,
53+
has_weight_zeros=has_weight_zeros,
54+
)
55+
quantized_model = quantizer.quantize(quantized_model)
56+
57+
with torch.no_grad():
58+
result = quantized_model(activations)
59+
reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback()
60+
reference_impl.quantize_and_pack_weights(
61+
model[0].weight, nbit, group_size, has_weight_zeros
62+
)
63+
expected_result = reference_impl(activations)
64+
65+
num_mismatch_at_low_tol = 0
66+
num_total = result.reshape(-1).shape[0]
67+
for i in range(num_total):
68+
actual_val = result.reshape(-1)[i]
69+
expected_val = expected_result.reshape(-1)[i]
70+
self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6))
71+
if not torch.allclose(actual_val, expected_val):
72+
num_mismatch_at_low_tol += 1
73+
74+
# Assert at most 5% of entries are not close at a low tolerance
75+
self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)
76+
77+
78+
if __name__ == "__main__":
79+
unittest.main()

0 commit comments

Comments
 (0)