 
 from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.qat import (
+    FakeQuantizeConfig,
+    FromIntXQuantizationAwareTrainingConfig,
+    IntXQuantizationAwareTrainingConfig,
+)
 from torchao.quantization.quant_api import (
+    Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationIntxWeightConfig,
     MappingType,
     quantize_,
@@ -418,6 +424,120 @@ def test_moved_error(self):
                 granularity=PerGroup(64),
             )
 
+    @parameterized.expand(
+        [
+            param(
+                group_size=group_size,
+                mapping_type=mapping_type,
+                act_mapping_type=act_mapping_type,
+            )
+            for group_size, mapping_type, act_mapping_type in zip(
+                [32, 64],
+                [MappingType.ASYMMETRIC, MappingType.SYMMETRIC],
+                [MappingType.ASYMMETRIC, MappingType.SYMMETRIC],
+            )
+        ],
+        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+    )
+    def test_identical_to_int8_dynamic_activation_int4_weight(
+        self, group_size, mapping_type, act_mapping_type
+    ):
+        """
+        Checks that Int8DynamicActivationIntxWeightConfig with weight_dtype=torch.int4 is identical to Int8DynamicActivationInt4WeightConfig
+        """
+        k0 = 512
+        k1 = 256
+        layers = [
+            torch.nn.Linear(k0, k1),
+        ]
+        model = torch.nn.Sequential(*layers)
+        activations = torch.randn(3, 1, k0)
+
+        model_copy = copy.deepcopy(model)
+
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(group_size),
+                weight_mapping_type=mapping_type,
+                weight_scale_dtype=None,
+                act_mapping_type=act_mapping_type,
+            ),
+        )
+        quantize_(
+            model_copy,
+            Int8DynamicActivationInt4WeightConfig(
+                group_size=group_size,
+                mapping_type=mapping_type,
+                act_mapping_type=act_mapping_type,
+            ),
+        )
+        with torch.no_grad():
+            self.assertTrue(torch.allclose(model(activations), model_copy(activations)))
+
+    @parameterized.expand(
+        [
+            param(
+                group_size=group_size,
+                mapping_type=mapping_type,
+                act_mapping_type=act_mapping_type,
+            )
+            for group_size, mapping_type, act_mapping_type in zip(
+                [64, 128],
+                [
+                    MappingType.SYMMETRIC,
+                ],
+                [
+                    MappingType.ASYMMETRIC,
+                ],
+            )
+        ],
+        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+    )
+    @unittest.skip("not working yet")
+    def test_identical_to_qat_configs(self, group_size, mapping_type, act_mapping_type):
+        k0 = 512
+        k1 = 256
+        layers = [
+            torch.nn.Linear(k0, k1),
+        ]
+        model = torch.nn.Sequential(*layers)
+        activations = torch.randn(3, 1, k0)
+
+        weight_dtype = torch.int4
+
+        activation_config = FakeQuantizeConfig(
+            torch.int8,
+            "per_token",
+            is_symmetric=(act_mapping_type == MappingType.SYMMETRIC),
+            is_dynamic=True,
+        )
+        weight_config = FakeQuantizeConfig(
+            weight_dtype,
+            group_size=group_size,
+            is_symmetric=(mapping_type == MappingType.SYMMETRIC),
+            is_dynamic=False,
+        )
+        quantize_(
+            model,
+            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        )
+
+        quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+        expected = model(activations)
+
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_granularity=PerGroup(group_size),
+                weight_dtype=weight_dtype,
+            ),
+        )
+        actual = model(activations)
+
+        self.assertTrue(torch.allclose(expected, actual))
+
 
 if __name__ == "__main__":
     unittest.main()