Skip to content

Commit 8079e44

Browse files
committed
fix pcam
1 parent 3e23333 commit 8079e44

File tree

1 file changed

+4
-11
lines changed
  • torchvision/prototype/datasets/_builtin

1 file changed

+4
-11
lines changed

torchvision/prototype/datasets/_builtin/pcam.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
import io
22
from collections import namedtuple
3-
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator
3+
from typing import Any, Dict, List, Optional, Tuple, Iterator
44

5-
import torch
65
from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
76
from torchvision.prototype import features
87
from torchvision.prototype.datasets.utils import (
98
Dataset,
109
DatasetConfig,
1110
DatasetInfo,
1211
OnlineResource,
13-
DatasetType,
1412
GDriveResource,
1513
)
1614
from torchvision.prototype.datasets.utils._internal import (
@@ -46,7 +44,6 @@ class PCAM(Dataset):
4644
def _make_info(self) -> DatasetInfo:
4745
return DatasetInfo(
4846
"pcam",
49-
type=DatasetType.RAW,
5047
homepage="https://github.com/basveeling/pcam",
5148
categories=2,
5249
valid_options=dict(split=("train", "test", "val")),
@@ -98,7 +95,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
9895
for file_name, gdrive_id, sha256 in self._RESOURCES[config.split]
9996
]
10097

101-
def _collate_and_decode(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
98+
def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
10299
image, target = data # They're both numpy arrays at this point
103100

104101
return {
@@ -107,11 +104,7 @@ def _collate_and_decode(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
107104
}
108105

109106
def _make_datapipe(
110-
self,
111-
resource_dps: List[IterDataPipe],
112-
*,
113-
config: DatasetConfig,
114-
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
107+
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
115108
) -> IterDataPipe[Dict[str, Any]]:
116109

117110
images_dp, targets_dp = resource_dps
@@ -122,4 +115,4 @@ def _make_datapipe(
122115
dp = Zipper(images_dp, targets_dp)
123116
dp = hint_sharding(dp)
124117
dp = hint_shuffling(dp)
125-
return Mapper(dp, self._collate_and_decode)
118+
return Mapper(dp, self._prepare_sample)

0 commit comments

Comments
 (0)