Skip to content

Commit b5523cd

Browse files
authored
pyre-fix
Differential Revision: D82265586 Pull Request resolved: #14241
1 parent 09f5beb commit b5523cd

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

backends/xnnpack/test/ops/test_linear.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,9 @@ def _test_groupwise_dq_linear(
395395
quantize_(
396396
mod,
397397
Int8DynamicActivationIntxWeightConfig(
398-
weight_dtype=torch.int4, weight_granularity=PerGroup(group_size)
398+
# pyre-ignore[16]
399+
weight_dtype=torch.int4,
400+
weight_granularity=PerGroup(group_size),
399401
),
400402
)
401403
unwrap_tensor_subclass(mod)

examples/models/llama/source_transformation/quantize.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def quantize( # noqa C901
135135
PerAxis(0) if group_size == 0 else PerGroup(group_size)
136136
),
137137
weight_mapping_type=MappingType.SYMMETRIC,
138+
# pyre-ignore[6]
138139
intx_packing_format="opaque_torchao_auto",
139140
),
140141
)
@@ -154,12 +155,23 @@ def quantize( # noqa C901
154155
from torchao.quantization.granularity import PerGroup
155156
from torchao.utils import unwrap_tensor_subclass
156157

158+
def filter_fn(m, fqn):
159+
is_linear = isinstance(m, nn.Linear)
160+
has_shape_compatible_with_group_size = False
161+
if is_linear:
162+
has_shape_compatible_with_group_size = (
163+
m.weight.shape[1] % group_size == 0
164+
)
165+
return is_linear and has_shape_compatible_with_group_size
166+
157167
quantize_(
158168
model,
159169
Int8DynamicActivationIntxWeightConfig(
170+
# pyre-ignore[16]
160171
weight_dtype=torch.int4,
161172
weight_granularity=PerGroup(group_size),
162173
),
174+
filter_fn=filter_fn,
163175
)
164176

165177
model = unwrap_tensor_subclass(model)

0 commit comments

Comments
 (0)