@@ -956,6 +956,72 @@ def _(func, types, args, kwargs):
956956 return return_and_correct_aliasing (func , args , kwargs , new_tensor )
957957
958958
@implements(aten.split.Tensor)
def _(func, types, args, kwargs):
    """Split a Float8Tensor into chunks of ``split_size_or_sections`` along ``dim``.

    Splits ``qdata`` directly, then derives a matching ``scale`` and
    ``block_size`` for every chunk. Only int ``split_size_or_sections`` and the
    two common 2D-style scale layouts are supported (see below).

    Returns:
        tuple of Float8Tensor chunks, mirroring ``aten.split.Tensor``.

    Raises:
        AssertionError: for list-valued split sizes or unsupported scale layouts.
    """
    tensor, split_size_or_sections = args[0], args[1]
    # `dim` is optional in the aten schema (defaults to 0) and may also be
    # passed as a keyword, so don't unpack it positionally.
    dim = args[2] if len(args) > 2 else kwargs.get("dim", 0)
    assert isinstance(split_size_or_sections, int), "unimplemented"

    # 2D case
    #
    # orig
    #   qdata.shape [M, K]
    #   scale.shape [M, 1]
    #   block_size  [1, K]
    #
    # split with size (K // 2) across dim -1:
    #   qdata.shape [M, K // 2], [M, K // 2]
    #   scale.shape [M, 1], [M, 1]
    #   block_size  [1, K // 2], [1, K // 2]
    #
    # split with size (M // 2) across dim 0:
    #   qdata.shape [M // 2, K], [M // 2, K]
    #   scale.shape [M // 2, 1], [M // 2, 1]
    #   block_size  [1, K], [1, K]

    # split the qdata
    new_qdatas = func(tensor.qdata, split_size_or_sections, dim)
    num_chunks = len(new_qdatas)

    # derive the per-chunk scale and block_size
    new_scales = []
    new_block_sizes = []
    if tensor.scale.shape[dim] == 1 and tensor.block_size[dim] == tensor.shape[dim]:
        # One scale block spans the whole of `dim`: every chunk reuses the
        # same scale, and each chunk's block shrinks to span that chunk.
        for chunk in new_qdatas:
            new_scales.append(tensor.scale)
            # copy before mutating: `tensor.block_size` must not be modified
            # in place, and each chunk needs its own list
            new_block_size = list(tensor.block_size)
            # the block covers the full chunk along `dim`; read the chunk's
            # actual size so a smaller trailing chunk is handled correctly
            new_block_size[dim] = chunk.shape[dim]
            new_block_sizes.append(new_block_size)

    elif tensor.scale.shape[dim] == tensor.shape[dim] and tensor.block_size[dim] == 1:
        # Per-element scale along `dim`: split the scale exactly like qdata,
        # block_size is unchanged (copied so chunks don't share one list).
        new_scales = func(tensor.scale, split_size_or_sections, dim)
        for _ in range(num_chunks):
            new_block_sizes.append(list(tensor.block_size))

    else:
        raise AssertionError(
            f"`aten.split.Tensor` with {dim=} and {tensor.scale.shape=} is not yet implemented"
        )

    # wrap each chunk back into the tensor subclass, preserving all metadata
    new_tensors_list = []
    for idx in range(num_chunks):
        new_tensor = tensor.__class__(
            new_qdatas[idx],
            new_scales[idx],
            new_block_sizes[idx],
            tensor.mm_config,
            tensor.act_quant_kwargs,
            tensor.kernel_preference,
            tensor.dtype,
        )
        new_tensor = return_and_correct_aliasing(func, args, kwargs, new_tensor)
        new_tensors_list.append(new_tensor)

    return tuple(new_tensors_list)
1023+
1024+
9591025Float8Tensor .__module__ = "torchao.quantization"
9601026
9611027# Allow a model with Float8Tensor weights to be loaded with `weights_only=True`
0 commit comments