Skip to content

Commit 5f05587

Browse files
authored
[fix] Allow None values in _pack_field_values and fallback to NonTensorStack (#75)
- Modify _pack_field_values to tolerate None placeholders in the values list, falling back to NonTensorStack instead of raising ValueError. - Pure tensor lists (no None) still use torch.stack or nested tensor. - Update docstring to reflect the new None-tolerant behavior. --------- Signed-off-by: 宁本哲 <ningbenzhe@xiaohongshu.com>
1 parent 46460a1 commit 5f05587

2 files changed

Lines changed: 56 additions & 18 deletions

File tree

tests/test_async_simple_storage_manager.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -542,21 +542,45 @@ class TestPackFieldValues:
542542
def test_uniform_tensors_to_stack(self):
543543
"""Same-shape tensors → torch.stack."""
544544
values = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
545-
result = AsyncSimpleStorageManager._pack_field_values(values)
545+
result = AsyncSimpleStorageManager._pack_field_values(values) # type: ignore[attr-defined]
546546
assert isinstance(result, torch.Tensor)
547547
assert not result.is_nested
548548
assert result.shape == (2, 2)
549549

550550
def test_variable_length_tensors_to_nested(self):
551551
"""Different-shape tensors → nested tensor."""
552552
values = [torch.tensor([1.0]), torch.tensor([2.0, 3.0])]
553-
result = AsyncSimpleStorageManager._pack_field_values(values)
553+
result = AsyncSimpleStorageManager._pack_field_values(values) # type: ignore[attr-defined]
554554
assert isinstance(result, torch.Tensor)
555555
assert result.is_nested
556556

557557
def test_non_tensors_to_nontensorstack(self):
558558
"""Non-tensor values → NonTensorStack."""
559559
values = ["hello", "world"]
560-
result = AsyncSimpleStorageManager._pack_field_values(values)
560+
result = AsyncSimpleStorageManager._pack_field_values(values) # type: ignore[attr-defined]
561561
assert isinstance(result, NonTensorStack)
562562
assert result.tolist() == ["hello", "world"]
563+
564+
def test_mixed_tensors_and_none_to_nontensorstack(self):
565+
"""Mixed tensor + None values should stay as NonTensorStack (no stacking)."""
566+
t0 = torch.tensor([1.0, 2.0])
567+
t2 = torch.tensor([3.0, 4.0])
568+
values = [t0, None, t2]
569+
570+
result = AsyncSimpleStorageManager._pack_field_values(values) # type: ignore[attr-defined]
571+
572+
assert isinstance(result, NonTensorStack)
573+
unpacked = result.tolist()
574+
assert len(unpacked) == 3
575+
assert torch.equal(unpacked[0], t0)
576+
assert unpacked[1] is None
577+
assert torch.equal(unpacked[2], t2)
578+
579+
def test_all_none_to_nontensorstack(self):
580+
"""All-None values should be preserved in NonTensorStack."""
581+
values = [None, None]
582+
583+
result = AsyncSimpleStorageManager._pack_field_values(values) # type: ignore[attr-defined]
584+
585+
assert isinstance(result, NonTensorStack)
586+
assert result.tolist() == [None, None]

transfer_queue/storage/managers/simple_backend_manager.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -387,24 +387,38 @@ def _pack_field_values(values: list) -> torch.Tensor | NonTensorStack:
387387
"""
388388
Pack a list of per-sample values into a batched container.
389389
390-
For tensor values, this performs a memory copy via stacking or nested tensor creation.
391-
Non-tensor values are grouped into a ``NonTensorStack`` without copying.
390+
For pure tensor lists (no None), this performs a memory copy via stacking
391+
or nested tensor creation. Mixed types, non-tensor values, or lists
392+
containing None placeholders are grouped into a ``NonTensorStack``.
393+
394+
Args:
395+
values: List of per-sample values to pack. May contain None for
396+
unfilled batch positions.
397+
398+
Returns:
399+
A stacked ``torch.Tensor`` (or nested tensor) when all values are
400+
tensors, otherwise a ``NonTensorStack``.
401+
402+
Raises:
403+
ValueError: If *values* is empty.
392404
"""
393405
if not values:
394406
raise ValueError("_pack_field_values received empty values list; caller should filter empty batches")
395-
if any(v is None for v in values):
396-
raise ValueError("_pack_field_values received None in values list; some batch positions were not filled")
397-
if all(isinstance(v, torch.Tensor) for v in values):
398-
if all(v.shape == values[0].shape for v in values):
399-
return torch.stack(values)
400-
try:
401-
return torch.nested.as_nested_tensor(values, layout=torch.jagged)
402-
except (RuntimeError, TypeError) as e:
403-
logger.warning(
404-
f"Failed to pack nested tensor with jagged layout. "
405-
f"Falling back to strided layout. Detailed error: {e}"
406-
)
407-
return torch.nested.as_nested_tensor(values, layout=torch.strided)
407+
non_none = [v for v in values if v is not None]
408+
if non_none and all(isinstance(v, torch.Tensor) for v in non_none):
409+
if len(non_none) == len(values):
410+
# Pure tensor list — try stacking / nested tensor
411+
if all(v.shape == values[0].shape for v in values):
412+
return torch.stack(values)
413+
try:
414+
return torch.nested.as_nested_tensor(values, layout=torch.jagged)
415+
except (RuntimeError, TypeError) as e:
416+
logger.warning(
417+
f"Failed to pack nested tensor with jagged layout. "
418+
f"Falling back to strided layout. Detailed error: {e}"
419+
)
420+
return torch.nested.as_nested_tensor(values, layout=torch.strided)
421+
# Mixed tensor + None — cannot stack, fall through to NonTensorStack
408422
return NonTensorStack(*values)
409423

410424
async def get_data(self, metadata: BatchMeta) -> TensorDict:

0 commit comments

Comments
 (0)