Commit fd4463f (parent: 67ab86b)

Remove old imports to unblock code cleanup in torchao
Specifically, this unblocks: pytorch/ao#3308 pytorch/ao#3146

File tree: 2 files changed, +33 −20 lines

tests/recipes/test_configs.py

Lines changed: 2 additions & 2 deletions
@@ -9,8 +9,8 @@
 import torchtune

 from omegaconf import OmegaConf
-from torchao.utils import TORCH_VERSION_AFTER_2_4
 from torchtune import config
+from torchtune.utils import torch_version_ge

 CONFIG_DIR = Path(torchtune.__file__).parent.parent / "recipes" / "configs"

@@ -24,7 +24,7 @@ def test_instantiate(self) -> None:
         ]
         for config_path in all_configs:
             # QAT config is only compatible with PyTorch 2.4+
-            if config_path.endswith("qat_full.yaml") and not TORCH_VERSION_AFTER_2_4:
+            if config_path.endswith("qat_full.yaml") and not torch_version_ge("2.4.0"):
                 continue
             cfg = OmegaConf.load(config_path)
             config.validate(cfg)

torchtune/training/quantization.py

Lines changed: 31 additions & 18 deletions
@@ -19,20 +19,15 @@
     Float8RowwiseParallel,
 )
 from torchao.quantization import (
-    int4_weight_only,
-    int8_dynamic_activation_int4_weight,
+    Int4WeightOnlyConfig,
+    Int8DynamicActivationInt4WeightConfig,
     quantize_,
 )
 from torchao.quantization.qat import (
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
 )
-from torchao.quantization.qat.linear import (
-    disable_4w_fake_quant,
-    disable_8da4w_fake_quant,
-    enable_4w_fake_quant,
-    enable_8da4w_fake_quant,
-)
+

 from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear

@@ -58,6 +53,24 @@
 _quantizer_mode_to_enable_fake_quant = {}


+def _enable_linear_fake_quant(
+    mod: torch.nn.Module,
+    enabled: bool = True,
+):
+    """
+    Helper function to enable fake quantization in `FakeQuantizedLinear`.
+    """
+    if isinstance(mod, FakeQuantizedLinear):
+        if mod.activation_fake_quantizer is not None:
+            mod.activation_fake_quantizer.enabled = enabled
+        if mod.weight_fake_quantizer is not None:
+            mod.weight_fake_quantizer.enabled = enabled
+
+
+def _disable_linear_fake_quant(mod: torch.nn.Module):
+    _enable_linear_fake_quant(mod, enabled=False)
+
+
 # ========================================
 # int8 dynamic activations + int4 weight |
 # ========================================
@@ -73,15 +86,15 @@ def __init__(self, groupsize: int = 256):
         self.groupsize = groupsize

     def quantize(self, model):
-        quantize_fn = int8_dynamic_activation_int4_weight(self.groupsize)
+        quantize_fn = Int8DynamicActivationInt4WeightConfig(self.groupsize)
         quantize_(model, quantize_fn)
         return model


 _quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizer] = "8da4w-qat"
-_quantizer_mode_to_disable_fake_quant["8da4w-qat"] = disable_8da4w_fake_quant
-_quantizer_mode_to_enable_fake_quant["8da4w-qat"] = enable_8da4w_fake_quant
+_quantizer_mode_to_disable_fake_quant["8da4w-qat"] = _disable_linear_fake_quant
+_quantizer_mode_to_enable_fake_quant["8da4w-qat"] = _enable_linear_fake_quant


 # ==================
@@ -101,15 +114,15 @@ def __init__(self, groupsize: int = 128, inner_k_tiles: int = 8):

     def quantize(self, model):
         layout_type = TensorCoreTiledLayout(self.inner_k_tiles)
-        quantize_fn = int4_weight_only(self.groupsize, layout_type)
+        quantize_fn = Int4WeightOnlyConfig(self.groupsize, layout_type)
         quantize_(model, quantize_fn)
         return model


 _quantizer_to_mode[Int4WeightOnlyQuantizer] = "4w"
 _quantizer_to_mode[Int4WeightOnlyQATQuantizer] = "4w-qat"
-_quantizer_mode_to_disable_fake_quant["4w-qat"] = disable_4w_fake_quant
-_quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant
+_quantizer_mode_to_disable_fake_quant["4w-qat"] = _disable_linear_fake_quant
+_quantizer_mode_to_enable_fake_quant["4w-qat"] = _enable_linear_fake_quant


 # ====================== #
@@ -122,8 +135,8 @@ class Int4WeightOnlyQATQuantizerModuleSwap(Int4WeightOnlyQATQuantizer):
     pass


-disable_4w_fake_quant_module_swap = disable_4w_fake_quant
-enable_4w_fake_quant_module_swap = enable_4w_fake_quant
+disable_4w_fake_quant_module_swap = _disable_linear_fake_quant
+enable_4w_fake_quant_module_swap = _enable_linear_fake_quant
 _quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "4w-qat-module-swap"
@@ -138,8 +151,8 @@ class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantize
     pass


-disable_8da4w_fake_quant_module_swap = disable_8da4w_fake_quant
-enable_8da4w_fake_quant_module_swap = enable_8da4w_fake_quant
+disable_8da4w_fake_quant_module_swap = _disable_linear_fake_quant
+enable_8da4w_fake_quant_module_swap = _enable_linear_fake_quant
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "8da4w-qat-module-swap"

0 commit comments