18 changes: 16 additions & 2 deletions catalyst/core/runner.py
@@ -42,6 +42,20 @@ def _has_str_intersections(origin_string: str, strings: Tuple):
return any(x in origin_string for x in strings)


def _get_batch_size(loader: DataLoader):
batch_size = loader.batch_size
if batch_size is not None:
return batch_size

batch_size = loader.batch_sampler.batch_size
if batch_size is not None:
return batch_size
raise NotImplementedError(
"No `batch_size` found,"
"please specity it throught `loader.batch_size`, or `loader.batch_sampler.batch_size`"
)
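
For context: when a `DataLoader` is constructed with `batch_sampler=...`, PyTorch sets the loader's own `batch_size` attribute to `None` and the batch size lives on the sampler, which is the case this helper's fallback handles. A minimal sketch (toy tensors, not part of the diff; uses the `BatchBalanceClassSampler` added below and the module-private `_get_batch_size` above):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst.data import BatchBalanceClassSampler

# toy dataset: 20 samples, 2 classes
labels = [0] * 10 + [1] * 10
dataset = TensorDataset(torch.randn(20, 3), torch.tensor(labels))
sampler = BatchBalanceClassSampler(labels, num_classes=2, num_samples=4)
loader = DataLoader(dataset, batch_sampler=sampler)

assert loader.batch_size is None             # PyTorch stores it on the sampler
assert loader.batch_sampler.batch_size == 8  # num_classes * num_samples
assert _get_batch_size(loader) == 8          # the fallback resolves it
```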


class RunnerException(Exception):
"""Exception class for all runner errors."""

@@ -209,7 +223,7 @@ def __init__(
self.loggers: Dict[str, ILogger] = {}

# the dataflow - model input/output and other batch tensors
-        self.batch: [Dict, torch.Tensor] = None
+        self.batch: Dict[str, torch.Tensor] = None

# metrics flow - batch, loader and epoch metrics
self.batch_metrics: BATCH_METRICS = defaultdict(None)
@@ -660,7 +674,7 @@ def on_loader_start(self, runner: "IRunner"):
self.is_valid_loader: bool = self.loader_key.startswith("valid")
self.is_infer_loader: bool = self.loader_key.startswith("infer")
assert self.is_train_loader or self.is_valid_loader or self.is_infer_loader
-        self.loader_batch_size: int = self.loader.batch_size
+        self.loader_batch_size: int = _get_batch_size(self.loader)
self.loader_batch_len: int = len(self.loader)
self.loader_sample_len: int = len(self.loader.dataset)
self.loader_batch_step: int = 0
1 change: 1 addition & 0 deletions catalyst/data/__init__.py
@@ -17,6 +17,7 @@
from catalyst.data.sampler import (
BalanceClassSampler,
BalanceBatchSampler,
BatchBalanceClassSampler,
DistributedSamplerWrapper,
DynamicLenBatchSampler,
DynamicBalanceClassSampler,
14 changes: 8 additions & 6 deletions catalyst/data/loader.py
@@ -34,14 +34,16 @@ def __getattr__(self, key):
attribute value

        Raises:
-            NotImplementedError: if could not find attribute in ``origin``
-                or ``wrapper``
+            NotImplementedError: if could not find attribute in ``origin`` or ``wrapper``
        """
-        value = getattr(self.origin, key, None)
-        if value is not None:
+        some_default_value = "_no_attr_found_"
+        value = self.origin.__dict__.get(key, some_default_value)
+        # value = getattr(self.origin, key, None)
+        if value != some_default_value:
            return value
-        value = getattr(self, key, None)
-        if value is not None:
+        value = self.__dict__.get(key, some_default_value)
+        # value = getattr(self, key, None)
+        if value != some_default_value:
            return value
        raise NotImplementedError()

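For context: the switch from `getattr(..., None)` to a sentinel lookup in `__dict__` matters for two reasons: `getattr` cannot distinguish a missing attribute from one whose value is legitimately `None`, and calling `getattr(self, key, ...)` inside `__getattr__` re-enters `__getattr__` and recurses, since `__getattr__` only runs after normal lookup has failed. A minimal sketch of the failure mode (hypothetical `_Origin`/`_Wrapper` names, not part of the diff):

```python
class _Origin:
    pass


class _Wrapper:
    def __init__(self, origin):
        self.origin = origin

    def __getattr__(self, key):
        # plain `__dict__` lookups can never re-enter `__getattr__`
        some_default_value = "_no_attr_found_"
        value = self.origin.__dict__.get(key, some_default_value)
        if value != some_default_value:
            return value
        value = self.__dict__.get(key, some_default_value)
        if value != some_default_value:
            return value
        raise NotImplementedError()


origin = _Origin()
origin.batch_size = None  # e.g. a DataLoader built with a batch_sampler
wrapper = _Wrapper(origin)

# the old `is not None` check treated this None as "not found" and fell
# through to `getattr(self, key, None)`, recursing; the sentinel returns it
assert wrapper.batch_size is None
```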
190 changes: 186 additions & 4 deletions catalyst/data/sampler.py
@@ -2,7 +2,7 @@
from collections import Counter
import logging
from operator import itemgetter
-from random import choices, sample
+import random

import numpy as np
import torch
@@ -20,6 +20,46 @@ class BalanceClassSampler(Sampler):
        labels: list of class labels for each elem in the dataset
        mode: Strategy to balance classes.
            Must be one of [downsampling, upsampling]
            or an ``int`` with the number of samples to draw for each class

Python API examples:

.. code-block:: python

import os
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl
from catalyst.data import ToTensor, BalanceClassSampler
from catalyst.contrib.datasets import MNIST

train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
train_labels = train_data.targets.cpu().numpy().tolist()
train_sampler = BalanceClassSampler(train_labels, mode=5000)
valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())

loaders = {
"train": DataLoader(train_data, sampler=train_sampler, batch_size=32),
"valid": DataLoader(valid_data, batch_size=32),
}

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

runner = dl.SupervisedRunner()
# model training
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
loaders=loaders,
num_epochs=1,
logdir="./logs",
valid_loader="valid",
valid_metric="loss",
minimize_valid_metric=True,
verbose=True,
)
"""

def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"):
@@ -165,24 +205,165 @@ def __iter__(self) -> Iterator[int]:
"""
inds = []

-        for cls_id in sample(self._classes, self._num_epoch_classes):
+        for cls_id in random.sample(self._classes, self._num_epoch_classes):
all_cls_inds = find_value_ids(self._labels, cls_id)

# we've checked in __init__ that this value must be > 1
num_samples_exists = len(all_cls_inds)

if num_samples_exists < self._k:
-                selected_inds = sample(all_cls_inds, k=num_samples_exists) + choices(
+                selected_inds = random.sample(all_cls_inds, k=num_samples_exists) + random.choices(
                    all_cls_inds, k=self._k - num_samples_exists
                )
            else:
-                selected_inds = sample(all_cls_inds, k=self._k)
+                selected_inds = random.sample(all_cls_inds, k=self._k)

inds.extend(selected_inds)

return iter(inds)


class BatchBalanceClassSampler(Sampler):
"""
    BatchSampler version of BalanceBatchSampler.
    This kind of sampler can be used for both metric learning and classification tasks.

    BatchSampler with the given strategy for a dataset with C unique classes:
    - selection of `num_classes` out of the C classes for each batch
    - selection of `num_samples` instances for each class in the batch
    The epoch ends after `num_batches`.
    So, the batch size is `num_classes` * `num_samples`.

    One of the purposes of this sampler is to be used for
    forming triplets and pos/neg pairs inside the batch.
    To guarantee the existence of these pairs in the batch,
    `num_classes` and `num_samples` should be > 1. (1)

This type of sampling can be found in the classical paper of Person Re-Id,
where P (`num_classes`) equals 32 and K (`num_samples`) equals 4:
`In Defense of the Triplet Loss for Person Re-Identification`_.

Args:
        labels: list of class labels for each elem in the dataset
num_classes: number of classes in a batch, should be > 1
num_samples: number of instances of each class in a batch, should be > 1
num_batches: number of batches in epoch
(default = len(labels) // (num_classes * num_samples))

.. _In Defense of the Triplet Loss for Person Re-Identification:
https://arxiv.org/abs/1703.07737

Python API examples:

.. code-block:: python

import os
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl
from catalyst.data import ToTensor, BatchBalanceClassSampler
from catalyst.contrib.datasets import MNIST

train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
train_labels = train_data.targets.cpu().numpy().tolist()
train_sampler = BatchBalanceClassSampler(train_labels, num_classes=10, num_samples=4)
valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())

loaders = {
"train": DataLoader(train_data, batch_sampler=train_sampler),
"valid": DataLoader(valid_data, batch_size=32),
}

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

runner = dl.SupervisedRunner()
# model training
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
loaders=loaders,
num_epochs=1,
logdir="./logs",
valid_loader="valid",
valid_metric="loss",
minimize_valid_metric=True,
verbose=True,
)
"""

def __init__(
self,
labels: Union[List[int], np.ndarray],
num_classes: int,
num_samples: int,
num_batches: int = None,
):
"""Sampler initialisation."""
super().__init__(labels)
classes = set(labels)

assert isinstance(num_classes, int) and isinstance(num_samples, int)
assert (1 < num_classes <= len(classes)) and (1 < num_samples)
        assert all(
            n > 1 for n in Counter(labels).values()
        ), "Each class should contain at least 2 instances to fit (1)"

labels = np.array(labels)
self._labels = list(set(labels.tolist()))
self._num_classes = num_classes
self._num_samples = num_samples
self._batch_size = self._num_classes * self._num_samples
self._num_batches = num_batches or len(labels) // self._batch_size
self.lbl2idx = {
label: np.arange(len(labels))[labels == label].tolist() for label in set(labels)
}

@property
def batch_size(self) -> int:
"""
Returns:
this value should be used in DataLoader as batch size
"""
return self._batch_size

@property
def batches_in_epoch(self) -> int:
"""
Returns:
number of batches in an epoch
"""
return self._num_batches

    def __len__(self) -> int:
        """
        Returns:
            number of batches in an epoch
        """
        return self._num_batches  # * self._batch_size

def __iter__(self) -> Iterator[int]:
"""
Returns:
            indices for sampling dataset elems during an epoch
"""
indices = []
for _ in range(self._num_batches):
batch_indices = []
classes_for_batch = random.sample(self._labels, self._num_classes)
while self._num_classes != len(set(classes_for_batch)):
classes_for_batch = random.sample(self._labels, self._num_classes)
for cls_id in classes_for_batch:
replace_flag = self._num_samples > len(self.lbl2idx[cls_id])
batch_indices += np.random.choice(
self.lbl2idx[cls_id], self._num_samples, replace=replace_flag
).tolist()
indices.append(batch_indices)
return iter(indices)
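
A quick sanity check of the batch contract described in the docstring (toy labels, not part of the diff; it mirrors the assertions in the new test at the bottom of this PR):

```python
from collections import Counter

from catalyst.data import BatchBalanceClassSampler

labels = [0] * 10 + [1] * 10 + [2] * 10  # 3 classes, 10 instances each
sampler = BatchBalanceClassSampler(labels, num_classes=2, num_samples=4)

# default num_batches = len(labels) // (num_classes * num_samples) = 30 // 8
assert sampler.batch_size == 8
assert sampler.batches_in_epoch == 3

for batch in sampler:  # each element is one batch of dataset indices
    stats = Counter(labels[i] for i in batch)
    assert len(stats) == 2                      # num_classes distinct classes
    assert all(n == 4 for n in stats.values())  # num_samples of each class
```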


class DynamicBalanceClassSampler(Sampler):
"""
This kind of sampler can be used for classification tasks with significant
@@ -552,6 +733,7 @@ def __iter__(self) -> Iterator[int]:
__all__ = [
"BalanceClassSampler",
"BalanceBatchSampler",
"BatchBalanceClassSampler",
"DistributedSamplerWrapper",
"DynamicBalanceClassSampler",
"DynamicLenBatchSampler",
7 changes: 7 additions & 0 deletions docs/api/data.rst
@@ -157,6 +157,13 @@ BalanceClassSampler
:undoc-members:
:special-members:

BatchBalanceClassSampler
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: catalyst.data.sampler.BatchBalanceClassSampler
:members: __init__
:undoc-members:
:special-members:

DistributedSamplerWrapper
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: catalyst.data.sampler.DistributedSamplerWrapper
48 changes: 47 additions & 1 deletion tests/catalyst/data/test_sampler.py
@@ -1,12 +1,20 @@
from typing import List, Tuple
from collections import Counter
from operator import itemgetter
+import os
from random import randint, shuffle

import numpy as np
import pytest
+from torch.utils.data import DataLoader

-from catalyst.data.sampler import BalanceBatchSampler, DynamicBalanceClassSampler
+from catalyst.contrib.datasets import MNIST
+from catalyst.data.sampler import (
+    BalanceBatchSampler,
+    BalanceClassSampler,
+    BatchBalanceClassSampler,
+    DynamicBalanceClassSampler,
+)

TLabelsPK = List[Tuple[List[int], int, int]]

@@ -35,6 +43,44 @@ def generate_valid_labels(num: int) -> TLabelsPK:
return labels_pk


def test_balance_class_sampler():

[pep8] reported by reviewdog 🐶: D103 Missing docstring in public function

bs = 32
data = MNIST(os.getcwd(), train=False, download=True)
for mode in ["downsampling", "upsampling", 100, 200, 500]:
sampler = BalanceClassSampler(data.targets.cpu().numpy().tolist(), mode=mode)
loader = DataLoader(data, sampler=sampler, batch_size=bs)
y_list = []
for _, y in loader:
# assert len(x) == len(y) == bs
y_list.extend(y.cpu().numpy().tolist())
        # prior: "downsampling" balances to the smallest class count
        # in the MNIST test set (892), "upsampling" to the largest (1135)
        if mode == "downsampling":
            mode = 892
        if mode == "upsampling":
            mode = 1135
        assert all(
            n == mode for n in Counter(y_list).values()
        ), f"Each class should contain {mode} instances"


def test_batch_balance_class_sampler():

[pep8] reported by reviewdog 🐶: D103 Missing docstring in public function

data = MNIST(os.getcwd(), train=False, download=True)
for num_classes in [2, 3, 5, 10]:
for num_samples in [2, 5, 10, 50]:
sampler = BatchBalanceClassSampler(
data.targets.cpu().numpy().tolist(),
num_classes=num_classes,
num_samples=num_samples,
)
loader = DataLoader(data, batch_sampler=sampler)
for _, y in loader:
stats = Counter(y.cpu().numpy().tolist())
                assert len(stats) == num_classes, f"Each batch should contain {num_classes} classes"
                assert all(
                    n == num_samples for n in stats.values()
                ), f"Each class should contain {num_samples} instances"


@pytest.fixture()
def input_for_balance_batch_sampler() -> TLabelsPK:
"""