diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index b03710c5eb..26b8af9d63 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -9,7 +9,7 @@ - [ ] Did you check that your code passes the unit tests `pytest .` ? - [ ] Did you add your new functionality to the docs? - [ ] Did you update the [CHANGELOG](https://github.com/catalyst-team/catalyst/blob/master/CHANGELOG.md)? -- [ ] Did you run [colab minimal CI/CD](https://colab.research.google.com/drive/1JCGTVvWlrIsLXMPRRRSWiAstSLic4nbA) with `latest` and `minimal` requirements? +- [ ] Did you run [colab minimal CI/CD](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/colab_ci_cd.ipynb) with `latest` and `minimal` requirements? diff --git a/CHANGELOG.md b/CHANGELOG.md index ce7ef7aff2..e39b1acd93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,8 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - added `pre-commit` hook to run codestyle checker on commit ([#1257](https://github.com/catalyst-team/catalyst/pull/1257)) -- `on publish` github action for docker and docs added [#1260](https://github.com/catalyst-team/catalyst/pull/1260) +- `on publish` github action for docker and docs added ([#1260](https://github.com/catalyst-team/catalyst/pull/1260)) +- MixupCallback and `utils.mixup_batch` ([#1241](https://github.com/catalyst-team/catalyst/pull/1241)) - Barlow twins loss ([#1259](https://github.com/catalyst-team/catalyst/pull/1259)) +- BatchBalanceClassSampler ([#1262](https://github.com/catalyst-team/catalyst/pull/1262)) ### Changed @@ -25,7 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- make `expdir` in `catalyst-dl run` optional ([#1249](https://github.com/catalyst-team/catalyst/pull/1249))
- Bump neptune-client from 0.9.5 to 0.9.8 in `requirements-neptune.txt` ([#1251](https://github.com/catalyst-team/catalyst/pull/1251))
- automatic merge for master (with [Mergify](https://mergify.io/)) fixed ([#1250](https://github.com/catalyst-team/catalyst/pull/1250))
-- Evaluate loader custom model bug was fixed [#1254](https://github.com/catalyst-team/catalyst/pull/1254)
+- Evaluate loader custom model bug was fixed ([#1254](https://github.com/catalyst-team/catalyst/pull/1254))
+- `BatchPrefetchLoaderWrapper` issue with batch-based PyTorch samplers ([#1262](https://github.com/catalyst-team/catalyst/pull/1262))
+

## [21.06] - 2021-06-29
diff --git a/catalyst/core/runner.py b/catalyst/core/runner.py
index 384a506822..268b07a4e3 100644
--- a/catalyst/core/runner.py
+++ b/catalyst/core/runner.py
@@ -42,6 +42,20 @@ def _has_str_intersections(origin_string: str, strings: Tuple):
    return any(x in origin_string for x in strings)


+def _get_batch_size(loader: DataLoader):
+    batch_size = loader.batch_size
+    if batch_size is not None:
+        return batch_size
+
+    batch_size = loader.batch_sampler.batch_size
+    if batch_size is not None:
+        return batch_size
+    raise NotImplementedError(
+        "No `batch_size` found, "
+        "please specify it through `loader.batch_size` or `loader.batch_sampler.batch_size`"
+    )
+
+
class RunnerException(Exception):
    """Exception class for all runner errors."""

@@ -209,7 +223,7 @@ def __init__(
        self.loggers: Dict[str, ILogger] = {}

        # the dataflow - model input/output and other batch tensors
-        self.batch: [Dict, torch.Tensor] = None
+        self.batch: Dict[str, torch.Tensor] = None

        # metrics flow - batch, loader and epoch metrics
        self.batch_metrics: BATCH_METRICS = defaultdict(None)
@@ -660,7 +674,7 @@ def on_loader_start(self, runner: "IRunner"):
        self.is_valid_loader: bool = self.loader_key.startswith("valid")
        self.is_infer_loader: bool = self.loader_key.startswith("infer")
        assert self.is_train_loader or self.is_valid_loader or self.is_infer_loader
-        self.loader_batch_size: int = self.loader.batch_size
+        self.loader_batch_size: int = _get_batch_size(self.loader)
        self.loader_batch_len: int = len(self.loader)
        self.loader_sample_len: int = len(self.loader.dataset)
        self.loader_batch_step: int = 0
diff --git a/catalyst/data/__init__.py b/catalyst/data/__init__.py
index eb7312ac9a..060dde1170 100644
--- a/catalyst/data/__init__.py
+++ b/catalyst/data/__init__.py
@@ -17,6 +17,7 @@ from catalyst.data.sampler import (
    BalanceClassSampler,
    BalanceBatchSampler,
+    BatchBalanceClassSampler,
    DistributedSamplerWrapper,
    DynamicLenBatchSampler,
    DynamicBalanceClassSampler,
diff --git a/catalyst/data/loader.py b/catalyst/data/loader.py
index 7fee033a27..2dbf2f2c96 100644
--- a/catalyst/data/loader.py
+++ b/catalyst/data/loader.py
@@ -34,14 +34,14 @@ def __getattr__(self, key):
            attribute value

        Raises:
-            NotImplementedError: if could not find attribute in ``origin``
-                or ``wrapper``
+            NotImplementedError: if the attribute could not be found in ``origin`` or ``wrapper``
        """
-        value = getattr(self.origin, key, None)
-        if value is not None:
+        some_default_value = "_no_attr_found_"
+        value = self.origin.__dict__.get(key, some_default_value)
+        if value != some_default_value:
            return value
-        value = getattr(self, key, None)
-        if value is not None:
+        value = self.__dict__.get(key, some_default_value)
+        if value != some_default_value:
            return value
        raise NotImplementedError()
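Note: a minimal sketch (not part of the patch) of the `DataLoader` behavior the new `_get_batch_size` helper handles — when a loader is constructed with `batch_sampler=...`, PyTorch forces `loader.batch_size` to `None`, so the batch size has to be read from the batch sampler instead; everything below is standard PyTorch API:

.. code-block:: python

    import torch
    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

    dataset = TensorDataset(torch.arange(10))
    batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
    loader = DataLoader(dataset, batch_sampler=batch_sampler)

    assert loader.batch_size is None  # always None when `batch_sampler` is given
    assert loader.batch_sampler.batch_size == 4  # the fallback `_get_batch_size` relies on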
diff --git a/catalyst/data/sampler.py b/catalyst/data/sampler.py
index ccbb77b3a0..80cd61bc64 100644
--- a/catalyst/data/sampler.py
+++ b/catalyst/data/sampler.py
@@ -2,7 +2,7 @@
from collections import Counter
import logging
from operator import itemgetter
-from random import choices, sample
+import random

import numpy as np
import torch
@@ -20,6 +20,46 @@ class BalanceClassSampler(Sampler):
        labels: list of class label for each elem in the dataset
        mode: Strategy to balance classes.
            Must be one of [downsampling, upsampling]
+
+    Python API examples:
+
+    .. code-block:: python
+
+        import os
+        from torch import nn, optim
+        from torch.utils.data import DataLoader
+        from catalyst import dl
+        from catalyst.data import ToTensor, BalanceClassSampler
+        from catalyst.contrib.datasets import MNIST
+
+        train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
+        train_labels = train_data.targets.cpu().numpy().tolist()
+        train_sampler = BalanceClassSampler(train_labels, mode=5000)
+        valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())
+
+        loaders = {
+            "train": DataLoader(train_data, sampler=train_sampler, batch_size=32),
+            "valid": DataLoader(valid_data, batch_size=32),
+        }
+
+        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
+        criterion = nn.CrossEntropyLoss()
+        optimizer = optim.Adam(model.parameters(), lr=0.02)
+
+        runner = dl.SupervisedRunner()
+        # model training
+        runner.train(
+            model=model,
+            criterion=criterion,
+            optimizer=optimizer,
+            loaders=loaders,
+            num_epochs=1,
+            logdir="./logs",
+            valid_loader="valid",
+            valid_metric="loss",
+            minimize_valid_metric=True,
+            verbose=True,
+        )
    """

    def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"):
@@ -165,24 +205,163 @@ def __iter__(self) -> Iterator[int]:
        """
        inds = []

-        for cls_id in sample(self._classes, self._num_epoch_classes):
+        for cls_id in random.sample(self._classes, self._num_epoch_classes):
            all_cls_inds = find_value_ids(self._labels, cls_id)

            # we've checked in __init__ that this value must be > 1
            num_samples_exists = len(all_cls_inds)

            if num_samples_exists < self._k:
-                selected_inds = sample(all_cls_inds, k=num_samples_exists) + choices(
+                selected_inds = random.sample(all_cls_inds, k=num_samples_exists) + random.choices(
                    all_cls_inds, k=self._k - num_samples_exists
                )
            else:
-                selected_inds = sample(all_cls_inds, k=self._k)
+                selected_inds = random.sample(all_cls_inds, k=self._k)

            inds.extend(selected_inds)

        return iter(inds)


+class BatchBalanceClassSampler(Sampler):
+    """
+    BatchSampler version of BalanceBatchSampler.
+    This kind of sampler can be used for both metric learning and classification tasks.
+
+    BatchSampler with the given strategy for a dataset with C unique classes:
+    - Selection of `num_classes` of the C classes for each batch
+    - Selection of `num_samples` instances for each class in the batch
+    The epoch ends after `num_batches`.
+    So, the batch size is `num_classes` * `num_samples`.
+
+    One of the purposes of this sampler is to be used for
+    forming triplets and pos/neg pairs inside the batch.
+    To guarantee the existence of these pairs in the batch,
+    `num_classes` and `num_samples` should be > 1. (1)
+
+    This type of sampling can be found in the classic paper on Person Re-Id,
+    where P (`num_classes`) equals 32 and K (`num_samples`) equals 4:
+    `In Defense of the Triplet Loss for Person Re-Identification`_.
+
+    Args:
+        labels: list of class labels for each elem in the dataset
+        num_classes: number of classes in a batch, should be > 1
+        num_samples: number of instances of each class in a batch, should be > 1
+        num_batches: number of batches in an epoch
+            (default = len(labels) // (num_classes * num_samples))
+
+    .. _In Defense of the Triplet Loss for Person Re-Identification:
+        https://arxiv.org/abs/1703.07737
+
+    Python API examples:
+
+    .. code-block:: python
+
+        import os
+        from torch import nn, optim
+        from torch.utils.data import DataLoader
+        from catalyst import dl
+        from catalyst.data import ToTensor, BatchBalanceClassSampler
+        from catalyst.contrib.datasets import MNIST
+
+        train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
+        train_labels = train_data.targets.cpu().numpy().tolist()
+        train_sampler = BatchBalanceClassSampler(train_labels, num_classes=10, num_samples=4)
+        valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())
+
+        loaders = {
+            "train": DataLoader(train_data, batch_sampler=train_sampler),
+            "valid": DataLoader(valid_data, batch_size=32),
+        }
+
+        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
+        criterion = nn.CrossEntropyLoss()
+        optimizer = optim.Adam(model.parameters(), lr=0.02)
+
+        runner = dl.SupervisedRunner()
+        # model training
+        runner.train(
+            model=model,
+            criterion=criterion,
+            optimizer=optimizer,
+            loaders=loaders,
+            num_epochs=1,
+            logdir="./logs",
+            valid_loader="valid",
+            valid_metric="loss",
+            minimize_valid_metric=True,
+            verbose=True,
+        )
+    """
+
+    def __init__(
+        self,
+        labels: Union[List[int], np.ndarray],
+        num_classes: int,
+        num_samples: int,
+        num_batches: int = None,
+    ):
+        """Sampler initialisation."""
+        super().__init__(labels)
+        classes = set(labels)
+
+        assert isinstance(num_classes, int) and isinstance(num_samples, int)
+        assert (1 < num_classes <= len(classes)) and (1 < num_samples)
+        assert all(
+            n > 1 for n in Counter(labels).values()
+        ), "Each class should contain at least 2 instances to fit (1)"
+
+        labels = np.array(labels)
+        self._labels = list(set(labels.tolist()))
+        self._num_classes = num_classes
+        self._num_samples = num_samples
+        self._batch_size = self._num_classes * self._num_samples
+        self._num_batches = num_batches or len(labels) // self._batch_size
+        self.lbl2idx = {
+            label: np.arange(len(labels))[labels == label].tolist() for label in set(labels)
+        }
+
+    @property
+    def batch_size(self) -> int:
+        """
+        Returns:
+            this value should be used in DataLoader as batch size
+        """
+        return self._batch_size
+
+    @property
+    def batches_in_epoch(self) -> int:
+        """
+        Returns:
+            number of batches in an epoch
+        """
+        return self._num_batches
+
+    def __len__(self) -> int:
+        """
+        Returns:
+            number of batches in an epoch
+        """
+        return self._num_batches
+
+    def __iter__(self) -> Iterator[List[int]]:
+        """
+        Returns:
+            batches of indices for sampling dataset elems during an epoch
+        """
+        indices = []
+        for _ in range(self._num_batches):
+            batch_indices = []
+            classes_for_batch = random.sample(self._labels, self._num_classes)
+            for cls_id in classes_for_batch:
+                replace_flag = self._num_samples > len(self.lbl2idx[cls_id])
+                batch_indices += np.random.choice(
+                    self.lbl2idx[cls_id], self._num_samples, replace=replace_flag
+                ).tolist()
+            indices.append(batch_indices)
+        return iter(indices)
+
+
class DynamicBalanceClassSampler(Sampler):
    """
    This kind of sampler can be used for classification tasks with significant
@@ -552,6 +731,7 @@ def __iter__(self) -> Iterator[int]:
__all__ = [
    "BalanceClassSampler",
    "BalanceBatchSampler",
+    "BatchBalanceClassSampler",
    "DistributedSamplerWrapper",
    "DynamicBalanceClassSampler",
    "DynamicLenBatchSampler",
diff --git a/docs/api/data.rst b/docs/api/data.rst
index b672c99794..b9f6dd60f4 100644
--- a/docs/api/data.rst
+++ b/docs/api/data.rst
@@ -157,6 +157,13 @@ BalanceClassSampler
    :undoc-members:
    :special-members:

+BatchBalanceClassSampler
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: catalyst.data.sampler.BatchBalanceClassSampler
+    :members: __init__
+    :undoc-members:
+    :special-members:
+
DistributedSamplerWrapper
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: catalyst.data.sampler.DistributedSamplerWrapper
diff --git a/tests/catalyst/data/test_sampler.py b/tests/catalyst/data/test_sampler.py
index fc293cea8e..824ab39862 100644
--- a/tests/catalyst/data/test_sampler.py
+++ b/tests/catalyst/data/test_sampler.py
@@ -1,12 +1,20 @@
from typing import List, Tuple
from collections import Counter
from operator import itemgetter
+import os
from random import randint, shuffle

import numpy as np
import pytest
+from torch.utils.data import DataLoader

-from catalyst.data.sampler import BalanceBatchSampler, DynamicBalanceClassSampler
+from catalyst.contrib.datasets import MNIST
+from catalyst.data.sampler import (
+    BalanceBatchSampler,
+    BalanceClassSampler,
+    BatchBalanceClassSampler,
+    DynamicBalanceClassSampler,
+)

TLabelsPK = List[Tuple[List[int], int, int]]

@@ -35,6 +43,46 @@ def generate_valid_labels(num: int) -> TLabelsPK:
    return labels_pk


+def test_balance_class_sampler():
+    """Test for BalanceClassSampler."""
+    bs = 32
+    data = MNIST(os.getcwd(), train=False, download=True)
+    for mode in ["downsampling", "upsampling", 100, 200, 500]:
+        sampler = BalanceClassSampler(data.targets.cpu().numpy().tolist(), mode=mode)
+        loader = DataLoader(data, sampler=sampler, batch_size=bs)
+        y_list = []
+        for _, y in loader:
+            # assert len(x) == len(y) == bs
+            y_list.extend(y.cpu().numpy().tolist())
+        # expected per-class counts from the MNIST test-set class distribution
+        if mode == "downsampling":
+            mode = 892
+        if mode == "upsampling":
+            mode = 1135
+        assert all(
+            n == mode for n in Counter(y_list).values()
+        ), f"Each class should contain {mode} instances"
+
+
+def test_batch_balance_class_sampler():
+    """Test for BatchBalanceClassSampler."""
+    data = MNIST(os.getcwd(), train=False, download=True)
+    for num_classes in [2, 3, 5, 10]:
+        for num_samples in [2, 5, 10, 50]:
+            sampler = BatchBalanceClassSampler(
+                data.targets.cpu().numpy().tolist(),
+                num_classes=num_classes,
+                num_samples=num_samples,
+            )
+            loader = DataLoader(data, batch_sampler=sampler)
+            for _, y in loader:
+                stats = Counter(y.cpu().numpy().tolist())
+                assert len(stats) == num_classes, f"Each batch should contain {num_classes} classes"
+                assert all(
+                    n == num_samples for n in stats.values()
+                ), f"Each class should contain {num_samples} instances"
+
+
@pytest.fixture()
def input_for_balance_batch_sampler() -> TLabelsPK:
    """
diff --git a/tests/pipelines/test_data.py b/tests/pipelines/test_data.py
new file mode 100644
index 0000000000..d8edc99a94
--- /dev/null
+++ b/tests/pipelines/test_data.py
@@ -0,0 +1,140 @@
+# flake8: noqa
+import os
+
+from pytest import mark
+from torch import nn, optim
+from torch.utils.data import DataLoader + +from catalyst import dl +from catalyst.contrib.datasets import MNIST +from catalyst.data import ( + BalanceClassSampler, + BatchBalanceClassSampler, + BatchPrefetchLoaderWrapper, + ToTensor, +) +from catalyst.settings import IS_CUDA_AVAILABLE + + +def test_balance_class_sampler(): + train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()) + train_labels = train_data.targets.cpu().numpy().tolist() + train_sampler = BalanceClassSampler(train_labels, mode=5000) + valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()) + + loaders = { + "train": DataLoader(train_data, sampler=train_sampler, batch_size=32), + "valid": DataLoader(valid_data, batch_size=32), + } + + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=0.02) + + runner = dl.SupervisedRunner() + runner.train( + model=model, + criterion=criterion, + optimizer=optimizer, + loaders=loaders, + num_epochs=1, + logdir="./logs", + valid_loader="valid", + valid_metric="loss", + minimize_valid_metric=True, + verbose=True, + ) + + +def test_batch_balance_class_sampler(): + train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()) + train_labels = train_data.targets.cpu().numpy().tolist() + train_sampler = BatchBalanceClassSampler(train_labels, num_classes=10, num_samples=4) + valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()) + + loaders = { + "train": DataLoader(train_data, batch_sampler=train_sampler), + "valid": DataLoader(valid_data, batch_size=32), + } + + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=0.02) + + runner = dl.SupervisedRunner() + runner.train( + model=model, + criterion=criterion, + optimizer=optimizer, + loaders=loaders, + num_epochs=1, + logdir="./logs", + valid_loader="valid", + valid_metric="loss", + minimize_valid_metric=True, + verbose=True, + ) + + +@mark.skipif(not IS_CUDA_AVAILABLE, reason="CUDA device is not available") +def test_balance_class_sampler_with_prefetch(): + train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()) + train_labels = train_data.targets.cpu().numpy().tolist() + train_sampler = BalanceClassSampler(train_labels, mode=5000) + valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()) + + loaders = { + "train": DataLoader(train_data, sampler=train_sampler, batch_size=32), + "valid": DataLoader(valid_data, batch_size=32), + } + loaders = {k: BatchPrefetchLoaderWrapper(v) for k, v in loaders.items()} + + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=0.02) + + runner = dl.SupervisedRunner() + runner.train( + model=model, + criterion=criterion, + optimizer=optimizer, + loaders=loaders, + num_epochs=1, + logdir="./logs", + valid_loader="valid", + valid_metric="loss", + minimize_valid_metric=True, + verbose=True, + ) + + +@mark.skipif(not IS_CUDA_AVAILABLE, reason="CUDA device is not available") +def test_batch_balance_class_sampler_with_prefetch(): + train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()) + train_labels = train_data.targets.cpu().numpy().tolist() + train_sampler = BatchBalanceClassSampler(train_labels, num_classes=10, num_samples=4) + valid_data = MNIST(os.getcwd(), train=False, 
download=True, transform=ToTensor()) + + loaders = { + "train": DataLoader(train_data, batch_sampler=train_sampler), + "valid": DataLoader(valid_data, batch_size=32), + } + loaders = {k: BatchPrefetchLoaderWrapper(v) for k, v in loaders.items()} + + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=0.02) + + runner = dl.SupervisedRunner() + runner.train( + model=model, + criterion=criterion, + optimizer=optimizer, + loaders=loaders, + num_epochs=1, + logdir="./logs", + valid_loader="valid", + valid_metric="loss", + minimize_valid_metric=True, + verbose=True, + )
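
Note: a short, self-contained sketch (not part of the patch, assuming the `BatchBalanceClassSampler` added above) of what the sampler yields — each element of its iterator is a complete batch of indices, `num_classes * num_samples` long, which is why it is passed to `DataLoader` via `batch_sampler=` rather than `sampler=`:

.. code-block:: python

    from collections import Counter

    from catalyst.data import BatchBalanceClassSampler

    labels = [0] * 10 + [1] * 10 + [2] * 10  # toy dataset: 3 classes, 10 items each
    sampler = BatchBalanceClassSampler(labels, num_classes=2, num_samples=4)

    assert sampler.batch_size == 2 * 4  # num_classes * num_samples
    for batch in sampler:  # len(sampler) == 30 // 8 == 3 batches per epoch
        class_counts = Counter(labels[i] for i in batch)
        assert len(class_counts) == 2  # num_classes distinct classes per batch
        assert all(n == 4 for n in class_counts.values())  # num_samples per class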