
Commit 46ce434

extra sampler + docs (#1262)
* extra sampler + docs
* and now with the docs :)
* codestyle
* tests
* tests
* tests
1 parent 2ffaab9 commit 46ce434

9 files changed: +414 -16 lines changed


.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 - [ ] Did you check that your code passes the unit tests `pytest .` ?
 - [ ] Did you add your new functionality to the docs?
 - [ ] Did you update the [CHANGELOG](https://github.com/catalyst-team/catalyst/blob/master/CHANGELOG.md)?
-- [ ] Did you run [colab minimal CI/CD](https://colab.research.google.com/drive/1JCGTVvWlrIsLXMPRRRSWiAstSLic4nbA) with `latest` and `minimal` requirements?
+- [ ] Did you run [colab minimal CI/CD](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/colab_ci_cd.ipynb) with `latest` and `minimal` requirements?

 <!-- For CHANGELOG separate each item in unreleased section by blank line to reduce collisions -->

CHANGELOG.md

Lines changed: 6 additions & 2 deletions
@@ -9,8 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added

 - added `pre-commit` hook to run codestyle checker on commit ([#1257](https://github.com/catalyst-team/catalyst/pull/1257))
-- `on publish` github action for docker and docs added [#1260](https://github.com/catalyst-team/catalyst/pull/1260)
+- `on publish` github action for docker and docs added ([#1260](https://github.com/catalyst-team/catalyst/pull/1260))
+- MixupCallback and `utils.mixup_batch` ([#1241](https://github.com/catalyst-team/catalyst/pull/1241))
 - Barlow twins loss ([#1259](https://github.com/catalyst-team/catalyst/pull/1259))
+- BatchBalanceClassSampler ([#1262](https://github.com/catalyst-team/catalyst/pull/1262))

 ### Changed

@@ -25,7 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - make `expdir` in `catalyst-dl run` optional ([#1249](https://github.com/catalyst-team/catalyst/pull/1249))
 - Bump neptune-client from 0.9.5 to 0.9.8 in `requirements-neptune.txt` ([#1251](https://github.com/catalyst-team/catalyst/pull/1251))
 - automatic merge for master (with [Mergify](https://mergify.io/)) fixed ([#1250](https://github.com/catalyst-team/catalyst/pull/1250))
-- Evaluate loader custom model bug was fixed [#1254](https://github.com/catalyst-team/catalyst/pull/1254)
+- Evaluate loader custom model bug was fixed ([#1254](https://github.com/catalyst-team/catalyst/pull/1254))
+- `BatchPrefetchLoaderWrapper` issue with batch-based PyTorch samplers ([#1262](https://github.com/catalyst-team/catalyst/pull/1262))
+

 ## [21.06] - 2021-06-29

catalyst/core/runner.py

Lines changed: 16 additions & 2 deletions
@@ -42,6 +42,20 @@ def _has_str_intersections(origin_string: str, strings: Tuple):
     return any(x in origin_string for x in strings)


+def _get_batch_size(loader: DataLoader):
+    batch_size = loader.batch_size
+    if batch_size is not None:
+        return batch_size
+
+    batch_size = loader.batch_sampler.batch_size
+    if batch_size is not None:
+        return batch_size
+    raise NotImplementedError(
+        "No `batch_size` found, "
+        "please specify it through `loader.batch_size` or `loader.batch_sampler.batch_size`"
+    )
+
+
 class RunnerException(Exception):
     """Exception class for all runner errors."""

@@ -209,7 +223,7 @@ def __init__(
         self.loggers: Dict[str, ILogger] = {}

         # the dataflow - model input/output and other batch tensors
-        self.batch: [Dict, torch.Tensor] = None
+        self.batch: Dict[str, torch.Tensor] = None

         # metrics flow - batch, loader and epoch metrics
         self.batch_metrics: BATCH_METRICS = defaultdict(None)
@@ -660,7 +674,7 @@ def on_loader_start(self, runner: "IRunner"):
         self.is_valid_loader: bool = self.loader_key.startswith("valid")
         self.is_infer_loader: bool = self.loader_key.startswith("infer")
         assert self.is_train_loader or self.is_valid_loader or self.is_infer_loader
-        self.loader_batch_size: int = self.loader.batch_size
+        self.loader_batch_size: int = _get_batch_size(self.loader)
         self.loader_batch_len: int = len(self.loader)
         self.loader_sample_len: int = len(self.loader.dataset)
         self.loader_batch_step: int = 0
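
The new `_get_batch_size` helper falls back to `loader.batch_sampler.batch_size` because PyTorch sets `DataLoader.batch_size` to `None` whenever a loader is built with `batch_sampler=` instead of `batch_size=` — which is exactly how the new `BatchBalanceClassSampler` below is meant to be used. A minimal sketch of that behaviour (the toy dataset and sizes are hypothetical, not part of this commit):

    import torch
    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

    dataset = TensorDataset(torch.randn(100, 3))

    # Plain loader: the batch size lives on the loader itself.
    plain = DataLoader(dataset, batch_size=8)
    assert plain.batch_size == 8

    # Batch-sampler loader: PyTorch forces `batch_size` to None here,
    # so the size has to be read from the sampler instead.
    sampler = BatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False)
    batched = DataLoader(dataset, batch_sampler=sampler)
    assert batched.batch_size is None
    assert batched.batch_sampler.batch_size == 8  # the fallback used by the helper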

catalyst/data/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 from catalyst.data.sampler import (
     BalanceClassSampler,
     BalanceBatchSampler,
+    BatchBalanceClassSampler,
     DistributedSamplerWrapper,
     DynamicLenBatchSampler,
     DynamicBalanceClassSampler,

catalyst/data/loader.py

Lines changed: 8 additions & 6 deletions
@@ -34,14 +34,16 @@ def __getattr__(self, key):
             attribute value

         Raises:
-            NotImplementedError: if could not find attribute in ``origin``
-                or ``wrapper``
+            NotImplementedError: if could not find attribute in ``origin`` or ``wrapper``
         """
-        value = getattr(self.origin, key, None)
-        if value is not None:
+        some_default_value = "_no_attr_found_"
+        value = self.origin.__dict__.get(key, some_default_value)
+        # value = getattr(self.origin, key, None)
+        if value != some_default_value:
             return value
-        value = getattr(self, key, None)
-        if value is not None:
+        value = self.__dict__.get(key, some_default_value)
+        # value = getattr(self, key, None)
+        if value != some_default_value:
             return value
         raise NotImplementedError()
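
The rewrite above replaces `getattr(..., None)` with a `__dict__` lookup against a string sentinel, so that attributes whose real value happens to be `None` are no longer mistaken for missing ones. A standalone sketch of the same pattern, with hypothetical `Origin`/`Wrapper` classes rather than the actual catalyst wrapper:

    class Origin:
        def __init__(self):
            self.batch_size = None  # a real attribute that legitimately holds None


    class Wrapper:
        def __init__(self, origin):
            self.origin = origin

        def __getattr__(self, key):
            # called only when normal attribute lookup fails on the wrapper
            sentinel = "_no_attr_found_"
            # getattr(self.origin, key, None) cannot distinguish "missing"
            # from "present but None"; a dict lookup with a sentinel can
            value = self.origin.__dict__.get(key, sentinel)
            if value != sentinel:
                return value
            value = self.__dict__.get(key, sentinel)
            if value != sentinel:
                return value
            raise NotImplementedError()


    wrapper = Wrapper(Origin())
    assert wrapper.batch_size is None  # resolved from origin, even though it is None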

catalyst/data/sampler.py

Lines changed: 186 additions & 4 deletions
@@ -2,7 +2,7 @@
 from collections import Counter
 import logging
 from operator import itemgetter
-from random import choices, sample
+import random

 import numpy as np
 import torch
@@ -20,6 +20,46 @@ class BalanceClassSampler(Sampler):
         labels: list of class label for each elem in the dataset
         mode: Strategy to balance classes.
             Must be one of [downsampling, upsampling]
+
+    Python API examples:
+
+    .. code-block:: python
+
+        import os
+        from torch import nn, optim
+        from torch.utils.data import DataLoader
+        from catalyst import dl
+        from catalyst.data import ToTensor, BalanceClassSampler
+        from catalyst.contrib.datasets import MNIST
+
+        train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
+        train_labels = train_data.targets.cpu().numpy().tolist()
+        train_sampler = BalanceClassSampler(train_labels, mode=5000)
+        valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())
+
+        loaders = {
+            "train": DataLoader(train_data, sampler=train_sampler, batch_size=32),
+            "valid": DataLoader(valid_data, batch_size=32),
+        }
+
+        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
+        criterion = nn.CrossEntropyLoss()
+        optimizer = optim.Adam(model.parameters(), lr=0.02)
+
+        runner = dl.SupervisedRunner()
+        # model training
+        runner.train(
+            model=model,
+            criterion=criterion,
+            optimizer=optimizer,
+            loaders=loaders,
+            num_epochs=1,
+            logdir="./logs",
+            valid_loader="valid",
+            valid_metric="loss",
+            minimize_valid_metric=True,
+            verbose=True,
+        )
     """

     def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"):
@@ -165,24 +205,165 @@ def __iter__(self) -> Iterator[int]:
         """
         inds = []

-        for cls_id in sample(self._classes, self._num_epoch_classes):
+        for cls_id in random.sample(self._classes, self._num_epoch_classes):
             all_cls_inds = find_value_ids(self._labels, cls_id)

             # we've checked in __init__ that this value must be > 1
             num_samples_exists = len(all_cls_inds)

             if num_samples_exists < self._k:
-                selected_inds = sample(all_cls_inds, k=num_samples_exists) + choices(
+                selected_inds = random.sample(all_cls_inds, k=num_samples_exists) + random.choices(
                     all_cls_inds, k=self._k - num_samples_exists
                 )
             else:
-                selected_inds = sample(all_cls_inds, k=self._k)
+                selected_inds = random.sample(all_cls_inds, k=self._k)

             inds.extend(selected_inds)

         return iter(inds)


+class BatchBalanceClassSampler(Sampler):
+    """
+    BatchSampler version of BalanceBatchSampler.
+    This kind of sampler can be used for both metric learning and classification tasks.
+
+    BatchSampler with the given strategy for the C unique classes dataset:
+    - Selection of `num_classes` of the C classes for each batch
+    - Selection of `num_samples` instances for each class in the batch
+    The epoch ends after `num_batches`.
+    So, the batch size is `num_classes` * `num_samples`.
+
+    One of the purposes of this sampler is to be used for
+    forming triplets and pos/neg pairs inside the batch.
+    To guarantee the existence of these pairs in the batch,
+    `num_classes` and `num_samples` should be > 1. (1)
+
+    This type of sampling can be found in the classical paper of Person Re-Id,
+    where P (`num_classes`) equals 32 and K (`num_samples`) equals 4:
+    `In Defense of the Triplet Loss for Person Re-Identification`_.
+
+    Args:
+        labels: list of class labels for each elem in the dataset
+        num_classes: number of classes in a batch, should be > 1
+        num_samples: number of instances of each class in a batch, should be > 1
+        num_batches: number of batches in epoch
+            (default = len(labels) // (num_classes * num_samples))
+
+    .. _In Defense of the Triplet Loss for Person Re-Identification:
+        https://arxiv.org/abs/1703.07737
+
+    Python API examples:
+
+    .. code-block:: python
+
+        import os
+        from torch import nn, optim
+        from torch.utils.data import DataLoader
+        from catalyst import dl
+        from catalyst.data import ToTensor, BatchBalanceClassSampler
+        from catalyst.contrib.datasets import MNIST
+
+        train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
+        train_labels = train_data.targets.cpu().numpy().tolist()
+        train_sampler = BatchBalanceClassSampler(train_labels, num_classes=10, num_samples=4)
+        valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())
+
+        loaders = {
+            "train": DataLoader(train_data, batch_sampler=train_sampler),
+            "valid": DataLoader(valid_data, batch_size=32),
+        }
+
+        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
+        criterion = nn.CrossEntropyLoss()
+        optimizer = optim.Adam(model.parameters(), lr=0.02)
+
+        runner = dl.SupervisedRunner()
+        # model training
+        runner.train(
+            model=model,
+            criterion=criterion,
+            optimizer=optimizer,
+            loaders=loaders,
+            num_epochs=1,
+            logdir="./logs",
+            valid_loader="valid",
+            valid_metric="loss",
+            minimize_valid_metric=True,
+            verbose=True,
+        )
+    """
+
+    def __init__(
+        self,
+        labels: Union[List[int], np.ndarray],
+        num_classes: int,
+        num_samples: int,
+        num_batches: int = None,
+    ):
+        """Sampler initialisation."""
+        super().__init__(labels)
+        classes = set(labels)
+
+        assert isinstance(num_classes, int) and isinstance(num_samples, int)
+        assert (1 < num_classes <= len(classes)) and (1 < num_samples)
+        assert all(
+            n > 1 for n in Counter(labels).values()
+        ), "Each class should contain at least 2 instances to fit (1)"
+
+        labels = np.array(labels)
+        self._labels = list(set(labels.tolist()))
+        self._num_classes = num_classes
+        self._num_samples = num_samples
+        self._batch_size = self._num_classes * self._num_samples
+        self._num_batches = num_batches or len(labels) // self._batch_size
+        self.lbl2idx = {
+            label: np.arange(len(labels))[labels == label].tolist() for label in set(labels)
+        }
+
+    @property
+    def batch_size(self) -> int:
+        """
+        Returns:
+            this value should be used in DataLoader as batch size
+        """
+        return self._batch_size
+
+    @property
+    def batches_in_epoch(self) -> int:
+        """
+        Returns:
+            number of batches in an epoch
+        """
+        return self._num_batches
+
+    def __len__(self) -> int:
+        """
+        Returns:
+            number of samples in an epoch
+        """
+        return self._num_batches  # * self._batch_size
+
+    def __iter__(self) -> Iterator[int]:
+        """
+        Returns:
+            indices for sampling dataset elems during an epoch
+        """
+        indices = []
+        for _ in range(self._num_batches):
+            batch_indices = []
+            classes_for_batch = random.sample(self._labels, self._num_classes)
+            while self._num_classes != len(set(classes_for_batch)):
+                classes_for_batch = random.sample(self._labels, self._num_classes)
+            for cls_id in classes_for_batch:
+                replace_flag = self._num_samples > len(self.lbl2idx[cls_id])
+                batch_indices += np.random.choice(
+                    self.lbl2idx[cls_id], self._num_samples, replace=replace_flag
+                ).tolist()
+            indices.append(batch_indices)
+        return iter(indices)
+
+
 class DynamicBalanceClassSampler(Sampler):
     """
     This kind of sampler can be used for classification tasks with significant
@@ -552,6 +733,7 @@ def __iter__(self) -> Iterator[int]:
 __all__ = [
     "BalanceClassSampler",
     "BalanceBatchSampler",
+    "BatchBalanceClassSampler",
     "DistributedSamplerWrapper",
     "DynamicBalanceClassSampler",
     "DynamicLenBatchSampler",

docs/api/data.rst

Lines changed: 7 additions & 0 deletions
@@ -157,6 +157,13 @@ BalanceClassSampler
     :undoc-members:
     :special-members:

+BatchBalanceClassSampler
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: catalyst.data.sampler.BatchBalanceClassSampler
+    :members: __init__
+    :undoc-members:
+    :special-members:
+
 DistributedSamplerWrapper
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: catalyst.data.sampler.DistributedSamplerWrapper
