Commit bd747a2

Adding numerical tests for Int8DynamicActivationIntxWeightConfig (#2065)
* init
* up
* up
* up
* up
1 parent bda5305 commit bd747a2
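
For orientation, a minimal sketch of how the config named in the commit title is applied. This is not code from the commit: quantize_, PerGroup, and MappingType come from this diff's imports, while the constructor kwargs are an assumption, mirrored from the IntxWeightOnlyConfig calls in the tests below.

import torch

from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)

# Toy model; quantize_ applies the config to nn.Linear modules by default.
model = torch.nn.Sequential(torch.nn.Linear(256, 256))
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,             # assumed kwarg, mirroring IntxWeightOnlyConfig
        granularity=PerGroup(32),            # assumed kwarg
        mapping_type=MappingType.SYMMETRIC,  # assumed kwarg
    ),
)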

2 files changed: +403 −0 lines changed

torchao/experimental/tests/test_embedding_xbit_quantizer.py

Lines changed: 186 additions & 0 deletions
@@ -9,6 +9,7 @@
 import unittest
 
 import torch
+from parameterized import param, parameterized
 from torch.testing import FileCheck
 
 from torchao.dtypes import (
@@ -19,8 +20,15 @@
     SharedEmbeddingQuantizer,
 )
 from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.qat import (
+    FakeQuantizeConfig,
+    FromIntXQuantizationAwareTrainingConfig,
+    Int4WeightOnlyEmbeddingQATQuantizer,
+    IntXQuantizationAwareTrainingConfig,
+)
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
+    IntxWeightOnlyConfig,
     MappingType,
     quantize_,
 )
@@ -184,6 +192,184 @@ def test_shared_embedding(self):
             exported_program.graph_module.code
         )
 
+    @parameterized.expand(
+        [
+            param(
+                weight_dtype=weight_dtype,
+                granularity=granularity,
+                mapping_type=mapping_type,
+                model_dtype=model_dtype,
+            )
+            for weight_dtype in [getattr(torch, f"int{x}") for x in range(1, 9)]
+            for granularity in [PerGroup(32), PerGroup(128), PerAxis(0)]
+            for mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
+            for model_dtype in [torch.float32]
+        ],
+        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+    )
+    def test_identical_to_IntxWeightOnlyConfig(
+        self, weight_dtype, granularity, mapping_type, model_dtype
+    ):
+        embedding_dim = 4096
+        num_embeddings = 131
+        model = torch.nn.Sequential(
+            *[torch.nn.Embedding(num_embeddings, embedding_dim)]
+        )
+        model = model.to(model_dtype)
+        indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)
+
+        quantized_model = copy.deepcopy(model)
+        quantizer = EmbeddingQuantizer(
+            weight_dtype=weight_dtype,
+            granularity=granularity,
+            mapping_type=mapping_type,
+        )
+        quantized_model = quantizer.quantize(quantized_model)
+        actual_result = quantized_model(indices)
+
+        reference_model = copy.deepcopy(model)
+        quantize_(
+            reference_model,
+            IntxWeightOnlyConfig(
+                weight_dtype=weight_dtype,
+                granularity=granularity,
+                mapping_type=mapping_type,
+                scale_dtype=None,
+            ),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        expected_result = reference_model(indices)
+        self.assertTrue(torch.allclose(actual_result, expected_result))
+
+    @parameterized.expand(
+        [
+            param(
+                weight_dtype=weight_dtype,
+                granularity=granularity,
+                mapping_type=mapping_type,
+                scale_dtype=scale_dtype,
+                model_dtype=model_dtype,
+            )
+            for weight_dtype in [getattr(torch, f"int{x}") for x in range(1, 9)]
+            for granularity in [PerGroup(32), PerGroup(128), PerAxis(0)]
+            for mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]
+            for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
+            for model_dtype in [torch.float32, torch.bfloat16]
+        ],
+        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+    )
+    def test_identical_to_IntXQuantizationAwareTrainingConfig(
+        self, weight_dtype, granularity, mapping_type, scale_dtype, model_dtype
+    ):
+        # ASYMMETRIC in QAT is very different from the PTQ configs
+        if mapping_type == MappingType.ASYMMETRIC:
+            return
+
+        embedding_dim = 4096
+        num_embeddings = 131
+        model = torch.nn.Sequential(
+            *[torch.nn.Embedding(num_embeddings, embedding_dim)]
+        )
+        model = model.to(model_dtype)
+        indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)
+
+        is_symmetric = mapping_type == MappingType.SYMMETRIC
+        group_size = (
+            granularity.group_size
+            if isinstance(granularity, PerGroup)
+            else embedding_dim
+        )
+
+        embedding_filter = lambda m, fqn: isinstance(m, torch.nn.Embedding)
+        weight_config = FakeQuantizeConfig(
+            weight_dtype,
+            group_size=group_size,
+            is_symmetric=is_symmetric,
+            scale_precision=scale_dtype,
+        )
+        quantize_(
+            model,
+            IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
+            embedding_filter,
+        )
+        expected_out = model(indices)
+
+        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=weight_dtype,
+                granularity=granularity,
+                mapping_type=mapping_type,
+                scale_dtype=scale_dtype,
+            ),
+            embedding_filter,
+        )
+        actual_out = model(indices)
+        self.assertTrue(torch.allclose(expected_out, actual_out))
+
+    @parameterized.expand(
+        [
+            param(
+                granularity=granularity,
+                scale_dtype=scale_dtype,
+                model_dtype=model_dtype,
+            )
+            for granularity in [PerGroup(32), PerGroup(128), PerAxis(0)]
+            for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
+            for model_dtype in [torch.float32, torch.bfloat16]
+        ],
+        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+    )
+    def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer(
+        self, granularity, scale_dtype, model_dtype
+    ):
+        embedding_dim = 4096
+        num_embeddings = 131
+        model = torch.nn.Sequential(
+            *[torch.nn.Embedding(num_embeddings, embedding_dim)]
+        )
+        model = model.to(model_dtype)
+        indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)
+
+        group_size = (
+            granularity.group_size
+            if isinstance(granularity, PerGroup)
+            else embedding_dim
+        )
+
+        embedding_filter = lambda m, fqn: isinstance(m, torch.nn.Embedding)
+
+        qat_quantizer = Int4WeightOnlyEmbeddingQATQuantizer(
+            group_size=group_size,
+            scale_precision=scale_dtype,
+            zero_point_precision=torch.int32,
+        )
+        model = qat_quantizer.prepare(model)
+        expected_out = model(indices)
+
+        # Convert model method 1
+        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=granularity,
+                mapping_type=MappingType.SYMMETRIC,
+                scale_dtype=scale_dtype,
+            ),
+            embedding_filter,
+        )
+        actual_out1 = model(indices)
+        self.assertTrue(torch.allclose(expected_out, actual_out1))
+
+        # TODO: method 2 does not work because the converted embedding op
+        # incorrectly casts output to indices.dtype
+        # Convert model method 2
+        # qat_quantizer.convert(prepared_model_copy)
+        # actual_out2 = prepared_model_copy(indices)
+        # self.assertTrue(torch.allclose(expected_out, actual_out2))
 
 if __name__ == "__main__":
     unittest.main()
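
The two QAT tests above exercise the same two-step lowering: strip the fake-quant wrappers with FromIntXQuantizationAwareTrainingConfig, then apply the matching PTQ config. A condensed sketch of that flow, built only from APIs imported in this diff; the toy sizes and group size are illustrative, not taken from the tests.

import torch

from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.quant_api import IntxWeightOnlyConfig, MappingType, quantize_

model = torch.nn.Sequential(torch.nn.Embedding(131, 256))
embedding_filter = lambda m, fqn: isinstance(m, torch.nn.Embedding)

# 1) Prepare: wrap the embedding weights with fake-quant observers for training.
weight_config = FakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True)
quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
    embedding_filter,
)

# ... train ...

# 2) Convert: remove the QAT wrappers, then quantize with the matching PTQ config.
quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
quantize_(
    model,
    IntxWeightOnlyConfig(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        mapping_type=MappingType.SYMMETRIC,
    ),
    embedding_filter,
)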
