Commit b184584

Update
[ghstack-poisoned]
1 parent b4ec4cb commit b184584


3 files changed: +81 −0 lines changed


test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 6 additions & 0 deletions
@@ -1157,6 +1157,12 @@ def test_per_row_config_before_dim(self):
         assert config_deser.granularity[0].dim == -1
         assert config_deser.granularity[1].dim == -1
 
+    @common_utils.parametrize("dim", [-2, -1])
+    def test_chunk(self, dim):
+        x = torch.randn(16, 5120, 16384, device="cuda", dtype=torch.bfloat16)
+        x_fp8 = Float8Tensor.from_hp(x)
+        self._test_chunk_similar_to_vllm_llama4(x_fp8, dim)
+
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
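As a side note on the test machinery: `common_utils.parametrize` generates one test per listed `dim` value and `instantiate_parametrized_tests` registers them on the class. Below is a minimal, self-contained sketch of the same chunk-then-cat round trip on plain tensors (the `ChunkRoundTrip` class name, the CPU-sized shapes, and the test name are illustrative, not part of this commit):

import torch
from torch.testing._internal import common_utils


class ChunkRoundTrip(common_utils.TestCase):
    @common_utils.parametrize("dim", [-2, -1])
    def test_chunk_cat_roundtrip(self, dim):
        # plain-tensor analogue of test_chunk above: chunk in two along `dim`,
        # concatenate back, and expect a bit-exact round trip
        x = torch.randn(4, 8, 16)
        chunks = x.chunk(2, dim=dim)
        torch.testing.assert_close(torch.cat(chunks, dim=dim), x, atol=0, rtol=0)


common_utils.instantiate_parametrized_tests(ChunkRoundTrip)

if __name__ == "__main__":
    common_utils.run_tests()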

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 66 additions & 0 deletions
@@ -956,6 +956,72 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, new_tensor)
 
 
+@implements(aten.split.Tensor)
+def _(func, types, args, kwargs):
+    tensor, split_size_or_sections, dim = args
+    assert isinstance(split_size_or_sections, int), "unimplemented"
+
+    # 2D case
+    #
+    # orig
+    # qdata.shape [M, K]
+    # scale.shape [M, 1]
+    # block_size [1, K]
+    #
+    # split with size (K // 2) across dim -1:
+    # qdata.shape [M, K // 2], [M, K // 2]
+    # scale.shape [M, 1], [M, 1]
+    # block_size [1, K // 2], [1, K // 2]
+    #
+    # split with size (M // 2) across dim 0:
+    # qdata.shape [M // 2, K], [M // 2, K]
+    # scale.shape [M // 2, 1], [M // 2, 1]
+    # block_size [1, K], [1, K]
+
+    # split the qdata
+    new_qdatas = func(tensor.qdata, split_size_or_sections, dim)
+    num_chunks = len(new_qdatas)
+
+    # split the scale
+    new_scales = []
+    new_block_sizes = []
+    if tensor.scale.shape[dim] == 1 and tensor.block_size[dim] == tensor.shape[dim]:
+        # repeat the scale, split block_size
+        for new_qdata in new_qdatas:
+            new_scales.append(tensor.scale)
+            new_block_size = list(tensor.block_size)  # copy, don't mutate tensor.block_size
+            new_block_size[dim] = new_qdata.shape[dim]  # block shrinks to the chunk's width
+            new_block_sizes.append(new_block_size)
+
+    elif tensor.scale.shape[dim] == tensor.shape[dim] and tensor.block_size[dim] == 1:
+        # repeat the block size, split scale
+        new_scales = func(tensor.scale, split_size_or_sections, dim)
+        for _ in range(num_chunks):
+            new_block_sizes.append(tensor.block_size)
+
+    else:
+        raise AssertionError(
+            f"`aten.split.Tensor` with {dim=} and {tensor.scale.shape=} is not yet implemented"
+        )
+
+    new_tensors_list = []
+    for idx in range(num_chunks):
+        new_tensor = tensor.__class__(
+            new_qdatas[idx],
+            new_scales[idx],
+            new_block_sizes[idx],
+            tensor.mm_config,
+            tensor.act_quant_kwargs,
+            tensor.kernel_preference,
+            tensor.dtype,
+        )
+        new_tensor = return_and_correct_aliasing(func, args, kwargs, new_tensor)
+        new_tensors_list.append(new_tensor)
+
+    new_tensors_tuple = tuple(new_tensors_list)
+    return new_tensors_tuple
+
+
 Float8Tensor.__module__ = "torchao.quantization"
 
 # Allow a model with Float8Tensor weights to be loaded with `weights_only=True`
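To make the scale/block_size bookkeeping in the per-row branch concrete, here is a small sketch using plain tensors only (no torchao types; the shapes and the 448.0 constant, the float8_e4m3fn maximum, are illustrative assumptions): splitting across the quantized dim reuses the row-wise scale for every chunk, each chunk's block_size shrinks to that chunk's width, and dequantizing the chunks then concatenating them matches dequantizing the unsplit tensor.

import torch

M, K, split_size = 4, 16, 8
x = torch.randn(M, K)
scale = x.abs().amax(dim=-1, keepdim=True) / 448.0             # per-row scale, shape [M, 1]
qdata = (x / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)  # quantized data, block_size [1, K]

q_chunks = torch.split(qdata, split_size, dim=-1)              # two chunks of shape [M, K // 2]
block_sizes = [[1, q.shape[-1]] for q in q_chunks]             # block_size [1, K // 2] per chunk
assert block_sizes == [[1, split_size], [1, split_size]]

# every chunk reuses the same [M, 1] scale, so dequantizing chunk-by-chunk and
# concatenating reproduces the dequantization of the unsplit tensor exactly
deq_full = qdata.to(torch.float32) * scale
deq_cat = torch.cat([q.to(torch.float32) * scale for q in q_chunks], dim=-1)
torch.testing.assert_close(deq_cat, deq_full, atol=0, rtol=0)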

torchao/testing/utils.py

Lines changed: 9 additions & 0 deletions
@@ -641,6 +641,15 @@ def _test_quantize_3d_param_similar_to_vllm(self, config: AOBaseConfig):
         quantize_(l, config)
         _w_slice = l.weight[0]
 
+    def _test_chunk_similar_to_vllm_llama4(self, ao_tensor, dim):
+        # source code in vLLM LLaMa 4:
+        # https://github.com/vllm-project/vllm/blob/34553b9d2702dd2a27a578fec819e88e76dcbfb4/vllm/model_executor/models/llama4.py#L455
+        ao_tensor_chunked = ao_tensor.chunk(2, dim=dim)
+        ao_tensor_unchunked = torch.cat(ao_tensor_chunked, dim=dim)
+        torch.testing.assert_close(
+            ao_tensor.dequantize(), ao_tensor_unchunked.dequantize(), atol=0, rtol=0
+        )
+
 
 common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
 common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
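For context on the vLLM Llama 4 pattern the helper mirrors: the referenced llama4.py line chunks a fused projection weight in two before the halves are consumed separately. A hedged plain-tensor sketch of that access pattern follows (the `gate_up_weight` name and shapes are illustrative, not vLLM's actual code); the helper above asserts that a quantized tensor survives the same chunk-then-cat round trip with unchanged dequantized values.

import torch

# illustrative stand-in for a fused gate/up projection weight
gate_up_weight = torch.randn(128, 512, dtype=torch.bfloat16)

# what the referenced vLLM code does: split the fused weight into two halves
gate_w, up_w = gate_up_weight.chunk(2, dim=-1)

# concatenating the halves back reproduces the original weight exactly
assert torch.equal(torch.cat([gate_w, up_w], dim=-1), gate_up_weight)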
