Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,12 @@ def test_per_row_config_before_dim(self):
assert config_deser.granularity[0].dim == -1
assert config_deser.granularity[1].dim == -1

@common_utils.parametrize("dim", [-2, -1])
def test_chunk(self, dim):
    """Chunking a Float8Tensor along either of the last two dims must
    survive the vLLM-LLaMa4-style chunk/cat round trip."""
    hp = torch.randn(16, 5120, 16384, device="cuda", dtype=torch.bfloat16)
    self._test_chunk_similar_to_vllm_llama4(Float8Tensor.from_hp(hp), dim)


common_utils.instantiate_parametrized_tests(TestFloat8Tensor)

Expand Down
66 changes: 66 additions & 0 deletions torchao/quantization/quantize_/workflows/float8/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,72 @@ def _(func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, new_tensor)


@implements(aten.split.Tensor)
def _(func, types, args, kwargs):
    """`aten.split.Tensor` for Float8Tensor.

    Splits ``qdata`` normally, then splits or repeats ``scale`` /
    ``block_size`` depending on which axis the quantization blocks span.

    2D example:

    orig
      qdata.shape [M, K]
      scale.shape [M, 1]
      block_size  [1, K]

    split with size (K // 2) across dim -1:
      qdata.shape [M, K // 2], [M, K // 2]
      scale.shape [M, 1],      [M, 1]
      block_size  [1, K // 2], [1, K // 2]

    split with size (M // 2) across dim 0:
      qdata.shape [M // 2, K], [M // 2, K]
      scale.shape [M // 2, 1], [M // 2, 1]
      block_size  [1, K],      [1, K]
    """
    tensor, split_size_or_sections = args[0], args[1]
    # `dim` defaults to 0 in the aten schema and may be omitted or passed
    # as a keyword by callers
    dim = args[2] if len(args) > 2 else kwargs.get("dim", 0)
    assert isinstance(split_size_or_sections, int), (
        "`aten.split.Tensor` on Float8Tensor only supports an int split size, "
        f"got sections of type {type(split_size_or_sections).__name__}: "
        f"{split_size_or_sections}"
    )

    # split the qdata; torch.split allows a smaller final chunk when the
    # dim extent is not an exact multiple of the split size
    new_qdatas = func(tensor.qdata, split_size_or_sections, dim)
    num_chunks = len(new_qdatas)

    if tensor.scale.shape[dim] == 1 and tensor.block_size[dim] == tensor.shape[dim]:
        # the quantization block spans the whole split axis: every chunk
        # shares the original scale, and each chunk's block size along `dim`
        # shrinks to that chunk's own length (last chunk may be smaller)
        new_scales = [tensor.scale] * num_chunks
        new_block_sizes = []
        for qdata in new_qdatas:
            # copy before mutating -- `tensor.block_size` must stay intact
            # and the chunks must not share one list
            new_block_size = list(tensor.block_size)
            new_block_size[dim] = qdata.shape[dim]
            new_block_sizes.append(new_block_size)

    elif tensor.scale.shape[dim] == tensor.shape[dim] and tensor.block_size[dim] == 1:
        # per-element scales along the split axis: split the scale exactly
        # like the qdata and keep the block size unchanged
        new_scales = func(tensor.scale, split_size_or_sections, dim)
        new_block_sizes = [list(tensor.block_size) for _ in range(num_chunks)]

    else:
        raise AssertionError(
            f"`aten.split.Tensor` with {dim=} and {tensor.scale.shape=} is not yet implemented"
        )

    new_tensors_list = []
    for idx in range(num_chunks):
        new_tensor = tensor.__class__(
            new_qdatas[idx],
            new_scales[idx],
            new_block_sizes[idx],
            tensor.mm_config,
            tensor.act_quant_kwargs,
            tensor.kernel_preference,
            tensor.dtype,
        )
        new_tensors_list.append(
            return_and_correct_aliasing(func, args, kwargs, new_tensor)
        )

    return tuple(new_tensors_list)


Float8Tensor.__module__ = "torchao.quantization"

# Allow a model with Float8Tensor weights to be loaded with `weights_only=True`
Expand Down
9 changes: 9 additions & 0 deletions torchao/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,15 @@ def _test_quantize_3d_param_similar_to_vllm(self, config: AOBaseConfig):
quantize_(l, config)
_w_slice = l.weight[0]

def _test_chunk_similar_to_vllm_llama4(self, ao_tensor, dim):
    """Split *ao_tensor* in two along *dim*, re-concatenate the halves, and
    require the dequantized result to match the original bit-for-bit.

    Mirrors the chunk usage in vLLM's LLaMa-4 model code:
    https://github.com/vllm-project/vllm/blob/34553b9d2702dd2a27a578fec819e88e76dcbfb4/vllm/model_executor/models/llama4.py#L455
    """
    halves = ao_tensor.chunk(2, dim=dim)
    roundtripped = torch.cat(halves, dim=dim)
    # atol=0 / rtol=0: chunk + cat must not perturb the quantized payload
    torch.testing.assert_close(
        ao_tensor.dequantize(), roundtripped.dequantize(), atol=0, rtol=0
    )


common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
Expand Down
Loading