@@ -323,7 +323,6 @@ def test_tensors_inferred_from_map(
323323 ray_start_regular_shared , restore_data_context , tensor_format
324324):
325325 DataContext .get_current ().use_arrow_tensor_v2 = tensor_format == "v2"
326- class_name = "ArrowTensorTypeV2" if tensor_format == "v2" else "ArrowTensorType"
327326 # Test map.
328327 ds = ray .data .range (10 , override_num_blocks = 10 ).map (
329328 lambda _ : {"data" : np .ones ((4 , 4 ))}
@@ -332,7 +331,11 @@ def test_tensors_inferred_from_map(
332331 assert ds .count () == 10
333332 schema = ds .schema ()
334333 assert schema .names == ["data" ]
335- assert str (schema .types [0 ]) == f"{ class_name } (shape=(4, 4), dtype=double)"
334+ dtype = schema .types [0 ]
335+ expected_type = ArrowTensorTypeV2 if tensor_format == "v2" else ArrowTensorType
336+ assert isinstance (dtype , expected_type )
337+ assert dtype .shape == (4 , 4 )
338+ assert dtype .scalar_type == pa .float64 ()
336339
337340 # Test map_batches.
338341 ds = ray .data .range (16 , override_num_blocks = 4 ).map_batches (
@@ -342,7 +345,11 @@ def test_tensors_inferred_from_map(
342345 assert ds .count () == 24
343346 schema = ds .schema ()
344347 assert schema .names == ["data" ]
345- assert str (schema .types [0 ]) == f"{ class_name } (shape=(4, 4), dtype=double)"
348+ dtype = schema .types [0 ]
349+ expected_type = ArrowTensorTypeV2 if tensor_format == "v2" else ArrowTensorType
350+ assert isinstance (dtype , expected_type )
351+ assert dtype .shape == (4 , 4 )
352+ assert dtype .scalar_type == pa .float64 ()
346353
347354 # Test flat_map.
348355 ds = ray .data .range (10 , override_num_blocks = 10 ).flat_map (
@@ -352,7 +359,11 @@ def test_tensors_inferred_from_map(
352359 assert ds .count () == 20
353360 schema = ds .schema ()
354361 assert schema .names == ["data" ]
355- assert str (schema .types [0 ]) == f"{ class_name } (shape=(4, 4), dtype=double)"
362+ dtype = schema .types [0 ]
363+ expected_type = ArrowTensorTypeV2 if tensor_format == "v2" else ArrowTensorType
364+ assert isinstance (dtype , expected_type )
365+ assert dtype .shape == (4 , 4 )
366+ assert dtype .scalar_type == pa .float64 ()
356367
357368 # Test map_batches ndarray column.
358369 ds = ray .data .range (16 , override_num_blocks = 4 ).map_batches (
@@ -362,7 +373,11 @@ def test_tensors_inferred_from_map(
362373 assert ds .count () == 24
363374 schema = ds .schema ()
364375 assert schema .names == ["a" ]
365- assert str (schema .types [0 ]) == f"{ class_name } (shape=(4, 4), dtype=double)"
376+ dtype = schema .types [0 ]
377+ expected_type = ArrowTensorTypeV2 if tensor_format == "v2" else ArrowTensorType
378+ assert isinstance (dtype , expected_type )
379+ assert dtype .shape == (4 , 4 )
380+ assert dtype .scalar_type == pa .float64 ()
366381
367382 ds = ray .data .range (16 , override_num_blocks = 4 ).map_batches (
368383 lambda _ : pd .DataFrame ({"a" : [np .ones ((2 , 2 )), np .ones ((3 , 3 ))]}),
@@ -372,7 +387,11 @@ def test_tensors_inferred_from_map(
372387 assert ds .count () == 16
373388 schema = ds .schema ()
374389 assert schema .names == ["a" ]
375- assert str (schema .types [0 ]) == f"{ class_name } (shape=(None, None), dtype=double)"
390+ dtype = schema .types [0 ]
391+ expected_type = ArrowTensorTypeV2 if tensor_format == "v2" else ArrowTensorType
392+ assert isinstance (dtype , expected_type )
393+ assert dtype .shape == (None , None )
394+ assert dtype .scalar_type == pa .float64 ()
376395
377396
378397@pytest .mark .parametrize ("tensor_format" , ["v1" , "v2" ])
0 commit comments