@@ -741,6 +741,55 @@ def test_concat_tensordict():
741741 assert output ["temp" ] == 1.0
742742
743743
def test_chunk_tensordict():
    """Check that ``tu.chunk_tensordict`` splits a batch of 4 into two chunks of 2.

    The TensorDict under test mixes jagged nested tensors (``input_ids`` and
    ``position_ids``) with a stack of ``NonTensorData`` entries, one of which
    wraps ``None``. Every chunk must contain the same rows, in the same order,
    as the matching slice of the source TensorDict.
    """
    batch_size = 4
    num_chunks = 2
    chunk_size = batch_size // num_chunks

    # Qwen-VL 3d position_ids: each component is a (4, L) tensor, L = 4..7.
    position_ids = torch.nested.as_nested_tensor(
        [
            torch.arange(4).expand(4, 4),
            torch.arange(5).expand(4, 5),
            torch.arange(6).expand(4, 6),
            torch.arange(7).expand(4, 7),
        ],
        layout=torch.jagged,
    )
    # 1-D token sequences with the same per-sample lengths (4..7).
    input_ids = torch.nested.as_nested_tensor(
        [torch.arange(4), torch.arange(5), torch.arange(6), torch.arange(7)],
        layout=torch.jagged,
    )

    # One sample deliberately carries no multi-modal payload (None).
    multi_modal_inputs = torch.stack(
        [
            NonTensorData({"pixel_values": torch.randn(3, 224, 224)}),
            NonTensorData(None),
            NonTensorData({"pixel_values": torch.randn(3, 128, 128)}),
            NonTensorData({"pixel_values": torch.randn(3, 128, 128)}),
        ]
    )
    td = tu.get_tensordict(
        {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "multi_modal_inputs": multi_modal_inputs,
        },
    )
    assert len(td) == batch_size

    chunks = list(tu.chunk_tensordict(td, chunks=num_chunks))
    # Guard against a vacuous pass: without this, an empty result would skip
    # the verification loop below entirely.
    assert len(chunks) == num_chunks

    for i, chunk in enumerate(chunks):
        assert len(chunk) == chunk_size
        rows = slice(i * chunk_size, (i + 1) * chunk_size)
        for key, val in chunk.items():
            if isinstance(val, torch.Tensor) and val.is_nested:
                # Rebuild the expected jagged tensor from the source rows and
                # compare the packed values. torch.equal also checks shapes,
                # so a broadcastable-but-wrong result cannot slip through.
                components = td[key].unbind(dim=0)
                expected = torch.nested.as_nested_tensor(components[rows], layout=torch.jagged)
                assert torch.equal(val.values(), expected.values())
            else:
                # Remaining entries are iterated item by item; in this test
                # that is only "multi_modal_inputs". Items expose their payload
                # through ``.data``, while indexing ``td`` is compared as the
                # raw payload (None or a dict) -- presumably unwrapped by the
                # TensorDict accessor; confirm if that accessor changes.
                expected = td[key][rows]
                # strict=True: a length mismatch must fail instead of being
                # silently truncated by zip.
                for item, expect in zip(val, expected, strict=True):
                    if item.data is None:
                        assert expect is None
                    else:
                        assert torch.equal(item.data["pixel_values"], expect["pixel_values"])
791+
792+
744793def test_assign_non_tensor_stack_with_nested_lists ():
745794 """Test assign_non_tensor_stack with lists of lists."""
746795 td = tu .get_tensordict ({"obs" : torch .randn (3 , 4 )}, non_tensor_dict = {})
0 commit comments