Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions test/nodes/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase
from torchdata.nodes import IterableWrapper
from torchdata.nodes.batch import Batcher, Unbatcher

from .utils import MockSource, run_test_save_load_state
Expand All @@ -28,6 +29,11 @@ def test_batcher(self) -> None:
self.assertEqual(results[i][j]["test_tensor"], torch.tensor([i * batch_size + j]))
self.assertEqual(results[i][j]["test_str"], f"str_{i * batch_size + j}")

def test_batcher_batch_size_zero_raises(self):
source = IterableWrapper(range(10))
with self.assertRaises(ValueError):
Batcher(source, batch_size=0)

def test_batcher_drop_last_false(self) -> None:
batch_size = 6
src = MockSource(num_samples=20)
Expand Down
2 changes: 2 additions & 0 deletions torchdata/nodes/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class Batcher(BaseNode[List[T]]):

def __init__(self, source: BaseNode[T], batch_size: int, drop_last: bool = True):
super().__init__()
if batch_size <= 0:
raise ValueError("batch_size must be a positive integer")
self.source = source
self.batch_size = batch_size
self.drop_last = drop_last
Expand Down