Skip to content

Commit 1c1d558

Browse files
author
Vincent Moens
authored
[BugFix] Fix Atari DQN ensembling (#1981)
1 parent aa7a690 commit 1c1d558

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

torchrl/data/datasets/atari_dqn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,9 @@ def get(self, index):
745745
if isinstance(index, torch.Tensor):
746746
if index.ndim <= 1:
747747
return self._read_from_splits(index)
748+
elif index.shape[1] == 1:
749+
index = index.squeeze(1)
750+
return self.get(index)
748751
else:
749752
raise RuntimeError("Only 1d tensors are accepted")
750753
# with ThreadPoolExecutor(16) as pool:

torchrl/data/replay_buffers/samplers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1565,7 +1565,7 @@ def sample(self, storage, batch_size):
15651565
]
15661566
)
15671567
samples = [
1568-
sample if isinstance(sample, torch.Tensor) else torch.tensor(sample)
1568+
sample if isinstance(sample, torch.Tensor) else torch.stack(sample, -1)
15691569
for sample in samples
15701570
]
15711571
if all(samples[0].shape == sample.shape for sample in samples[1:]):

0 commit comments

Comments
 (0)