
Commit bd00a18

init
1 parent 7fa9c69 commit bd00a18

File tree

1 file changed: +120 -0 lines changed


torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py

Lines changed: 120 additions & 0 deletions
@@ -14,7 +14,13 @@
 
 from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.qat import (
+    FakeQuantizeConfig,
+    FromIntXQuantizationAwareTrainingConfig,
+    IntXQuantizationAwareTrainingConfig,
+)
 from torchao.quantization.quant_api import (
+    Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationIntxWeightConfig,
     MappingType,
     quantize_,
@@ -418,6 +424,120 @@ def test_moved_error(self):
             granularity=PerGroup(64),
         )
 
+    @parameterized.expand(
+        [
+            param(
+                group_size=group_size,
+                mapping_type=mapping_type,
+                act_mapping_type=act_mapping_type,
+            )
+            for group_size, mapping_type, act_mapping_type in zip(
+                [32, 64],
+                [MappingType.ASYMMETRIC, MappingType.SYMMETRIC],
+                [MappingType.ASYMMETRIC, MappingType.SYMMETRIC],
+            )
+        ],
+        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+    )
+    def test_identical_to_int8_dynamic_activation_int4_weight(
+        self, group_size, mapping_type, act_mapping_type
+    ):
+        """
+        Checks that Int8DynamicActivationIntxWeightConfig with weight_dtype=torch.int4 is identical to Int8DynamicActivationInt4WeightConfig
+        """
+        k0 = 512
+        k1 = 256
+        layers = [
+            torch.nn.Linear(k0, k1),
+        ]
+        model = torch.nn.Sequential(*layers)
+        activations = torch.randn(3, 1, k0)
+
+        model_copy = copy.deepcopy(model)
+
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(group_size),
+                weight_mapping_type=mapping_type,
+                weight_scale_dtype=None,
+                act_mapping_type=act_mapping_type,
+            ),
+        )
+        quantize_(
+            model_copy,
+            Int8DynamicActivationInt4WeightConfig(
+                group_size=group_size,
+                mapping_type=mapping_type,
+                act_mapping_type=act_mapping_type,
+            ),
+        )
+        with torch.no_grad():
+            self.assertTrue(torch.allclose(model(activations), model_copy(activations)))
+
+    @parameterized.expand(
+        [
+            param(
+                group_size=group_size,
+                mapping_type=mapping_type,
+                act_mapping_type=act_mapping_type,
+            )
+            for group_size, mapping_type, act_mapping_type in zip(
+                [64, 128],
+                [
+                    MappingType.SYMMETRIC,
+                ],
+                [
+                    MappingType.ASYMMETRIC,
+                ],
+            )
+        ],
+        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+    )
+    @unittest.skip("not working yet")
+    def test_identical_to_qat_configs(self, group_size, mapping_type, act_mapping_type):
+        k0 = 512
+        k1 = 256
+        layers = [
+            torch.nn.Linear(k0, k1),
+        ]
+        model = torch.nn.Sequential(*layers)
+        activations = torch.randn(3, 1, k0)
+
+        weight_dtype = torch.int4
+
+        activation_config = FakeQuantizeConfig(
+            torch.int8,
+            "per_token",
+            is_symmetric=(act_mapping_type == MappingType.SYMMETRIC),
+            is_dynamic=True,
+        )
+        weight_config = FakeQuantizeConfig(
+            weight_dtype,
+            group_size=group_size,
+            is_symmetric=(mapping_type == MappingType.SYMMETRIC),
+            is_dynamic=False,
+        )
+        quantize_(
+            model,
+            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        )
+
+        quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+        expected = model(activations)
+
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_granularity=PerGroup(group_size),
+                weight_dtype=weight_dtype,
+            ),
+        )
+        actual = model(activations)
+
+        self.assertTrue(torch.allclose(expected, actual))
+
 
 if __name__ == "__main__":
     unittest.main()
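
For reference, the equivalence checked by the first new test can also be reproduced as a standalone script outside the unittest harness. The sketch below is not part of the commit; it reuses only the imports, configs, and keyword arguments that appear in this diff, and assumes a torchao build in which those names resolve as shown.

    import copy

    import torch

    from torchao.quantization.granularity import PerGroup
    from torchao.quantization.quant_api import (
        Int8DynamicActivationInt4WeightConfig,
        Int8DynamicActivationIntxWeightConfig,
        MappingType,
        quantize_,
    )

    # Small single-layer model and a batch of activations, mirroring the test setup.
    model = torch.nn.Sequential(torch.nn.Linear(512, 256))
    model_copy = copy.deepcopy(model)
    activations = torch.randn(3, 1, 512)

    # Quantize one copy with the generic intx config restricted to int4 weights ...
    quantize_(
        model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
            weight_mapping_type=MappingType.ASYMMETRIC,
            weight_scale_dtype=None,
            act_mapping_type=MappingType.ASYMMETRIC,
        ),
    )
    # ... and the other with the dedicated int4 config.
    quantize_(
        model_copy,
        Int8DynamicActivationInt4WeightConfig(
            group_size=32,
            mapping_type=MappingType.ASYMMETRIC,
            act_mapping_type=MappingType.ASYMMETRIC,
        ),
    )

    # If the two configs are equivalent, the quantized models should agree.
    with torch.no_grad():
        print(torch.allclose(model(activations), model_copy(activations)))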
