Skip to content

Commit 9f34cdc

Browse files
authored
simple update (#1531)
1 parent d9f48ee commit 9f34cdc

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

test/stateful_dataloader/test_dataloader.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2433,12 +2433,15 @@ def test_default_collate_shared_tensor(self):
24332433
self.assertEqual(dataloader.default_collate([t_in]).is_shared(), False)
24342434
self.assertEqual(dataloader.default_collate([n_in]).is_shared(), False)
24352435

2436-
# FIXME: fix the following hack that makes `default_collate` believe
2437-
# that it is in a worker process (since it tests
2438-
# `get_worker_info() != None`), even though it is not.
24392436
old = _utils.worker._worker_info
24402437
try:
2441-
_utils.worker._worker_info = "x"
2438+
_utils.worker._worker_info = _utils.worker.WorkerInfo(
2439+
id=0,
2440+
num_workers=1,
2441+
seed=0,
2442+
dataset=self.dataset,
2443+
worker_method="multiprocessing",
2444+
)
24422445
self.assertEqual(dataloader.default_collate([t_in]).is_shared(), True)
24432446
self.assertEqual(dataloader.default_collate([n_in]).is_shared(), True)
24442447
finally:

0 commit comments

Comments
 (0)