1 change: 1 addition & 0 deletions catalyst/data/__init__.py
@@ -17,6 +17,7 @@
from catalyst.data.sampler import (
BalanceClassSampler,
BalanceBatchSampler,
BatchBalanceClassSampler,
DistributedSamplerWrapper,
DynamicLenBatchSampler,
DynamicBalanceClassSampler,
190 changes: 186 additions & 4 deletions catalyst/data/sampler.py
@@ -2,7 +2,7 @@
from collections import Counter
import logging
from operator import itemgetter
from random import choices, sample
import random

import numpy as np
import torch
@@ -20,6 +20,46 @@ class BalanceClassSampler(Sampler):
labels: list of class labels for each elem in the dataset
mode: Strategy to balance classes.
Must be one of [downsampling, upsampling],
or an integer, interpreted as the number of samples to draw for each class

[pep8] reported by reviewdog 🐶
W293 blank line contains whitespace

Python API examples:

.. code-block:: python

import os
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl
from catalyst.data import ToTensor, BalanceClassSampler
from catalyst.contrib.datasets import MNIST

train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
train_labels = train_data.targets.cpu().numpy().tolist()
train_sampler = BalanceClassSampler(train_labels, mode=5000)
valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())

loaders = {
"train": DataLoader(train_data, sampler=train_sampler, batch_size=32),
"valid": DataLoader(valid_data, batch_size=32),
}

[pep8] reported by reviewdog 🐶
W293 blank line contains whitespace

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

runner = dl.SupervisedRunner()
# model training
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
loaders=loaders,
num_epochs=1,
logdir="./logs",
valid_loader="valid",
valid_metric="loss",
minimize_valid_metric=True,
verbose=True,
)
"""

def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"):
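A reviewer-side note on the integer `mode` used in the example above: an int appears to be taken as the per-class sample count, so the epoch length works out to `mode * num_classes`. This is assumed behavior, inferred from the `Union[str, int]` signature and the `mode=5000` call in the docstring; the sketch below is a sanity check, not part of the PR.

.. code-block:: python

import os
from catalyst.data import ToTensor, BalanceClassSampler
from catalyst.contrib.datasets import MNIST

train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
train_labels = train_data.targets.cpu().numpy().tolist()

# integer mode: draw exactly 5000 indices for each of the 10 digit classes
sampler = BalanceClassSampler(train_labels, mode=5000)
assert len(sampler) == 5000 * 10  # assumed: epoch length = mode * num_classes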
@@ -165,24 +205,165 @@ def __iter__(self) -> Iterator[int]:
"""
inds = []

for cls_id in sample(self._classes, self._num_epoch_classes):
for cls_id in random.sample(self._classes, self._num_epoch_classes):
all_cls_inds = find_value_ids(self._labels, cls_id)

# we've checked in __init__ that this value must be > 1
num_samples_exists = len(all_cls_inds)

if num_samples_exists < self._k:
selected_inds = sample(all_cls_inds, k=num_samples_exists) + choices(
selected_inds = random.sample(all_cls_inds, k=num_samples_exists) + random.choices(
all_cls_inds, k=self._k - num_samples_exists
)
else:
selected_inds = sample(all_cls_inds, k=self._k)
selected_inds = random.sample(all_cls_inds, k=self._k)

inds.extend(selected_inds)

return iter(inds)
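The under-populated branch above is worth a standalone illustration: when a class has fewer than `k` instances, every index is taken once via `random.sample` and the remainder is drawn with replacement via `random.choices`. A minimal sketch with a hypothetical `pick_k` helper (not part of the PR):

.. code-block:: python

import random

def pick_k(pool, k):
    # mirror the branch above: exhaust the pool without replacement,
    # then top up with replacement
    if len(pool) < k:
        return random.sample(pool, k=len(pool)) + random.choices(pool, k=k - len(pool))
    return random.sample(pool, k=k)

out = pick_k([1, 2, 3], k=5)
assert len(out) == 5 and set(out) == {1, 2, 3}  # all three appear, plus repeats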


class BatchBalanceClassSampler(Sampler):
"""
BatchSampler version of BalanceBatchSampler.
This kind of sampler can be used for both metric learning and classification tasks.

BatchSampler with the following strategy for a dataset with C unique classes:
- select `num_classes` of the C classes for each batch
- select `num_samples` instances of each class in the batch
The epoch ends after `num_batches` batches.
So, the batch size is `num_classes` * `num_samples`.

One of the purposes of this sampler is to be used for
forming triplets and pos/neg pairs inside the batch.
To guarantee the existence of these pairs in the batch,
`num_classes` and `num_samples` should be > 1. (1)

This type of sampling can be found in the classic Person Re-ID paper,
where P (`num_classes`) equals 32 and K (`num_samples`) equals 4,
giving a batch size of 128:
`In Defense of the Triplet Loss for Person Re-Identification`_.

Args:
labels: list of class labels for each elem in the dataset
num_classes: number of classes in a batch, should be > 1
num_samples: number of instances of each class in a batch, should be > 1
num_batches: number of batches in epoch
(default = len(labels) // (num_classes * num_samples))

[pep8] reported by reviewdog 🐶
W291 trailing whitespace

.. _In Defense of the Triplet Loss for Person Re-Identification:
https://arxiv.org/abs/1703.07737

[pep8] reported by reviewdog 🐶
W293 blank line contains whitespace

Python API examples:

.. code-block:: python

import os
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl
from catalyst.data import ToTensor, BatchBalanceClassSampler
from catalyst.contrib.datasets import MNIST

train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
train_labels = train_data.targets.cpu().numpy().tolist()
train_sampler = BatchBalanceClassSampler(train_labels, num_classes=10, num_samples=4)
valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())

loaders = {
"train": DataLoader(train_data, batch_sampler=train_sampler),
"valid": DataLoader(valid_data, batch_size=32),
}

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

runner = dl.SupervisedRunner()
# model training
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
loaders=loaders,
num_epochs=1,
logdir="./logs",
valid_loader="valid",
valid_metric="loss",
minimize_valid_metric=True,
verbose=True,
)
"""

def __init__(
self,
labels: Union[List[int], np.ndarray],
num_classes: int,
num_samples: int,
num_batches: int = None,
):
"""Sampler initialisation."""
super().__init__(labels)
classes = set(labels)

assert isinstance(num_classes, int) and isinstance(num_samples, int)
assert (1 < num_classes <= len(classes)) and (1 < num_samples)
assert all(
n > 1 for n in Counter(labels).values()
), "Each class shoud contain at least 2 instances to fit (1)"

labels = np.array(labels)
self._labels = list(set(labels.tolist()))
self._num_classes = num_classes
self._num_samples = num_samples
self._batch_size = self._num_classes * self._num_samples
self._num_batches = num_batches or len(labels) // self._batch_size
self.lbl2idx = {
label: np.arange(len(labels))[labels == label].tolist() for label in set(labels)
}

@property
def batch_size(self) -> int:
"""
Returns:
the resulting batch size, i.e. `num_classes * num_samples`
"""
return self._batch_size

@property
def batches_in_epoch(self) -> int:
"""
Returns:
number of batches in an epoch
"""
return self._num_batches

def __len__(self) -> int:
"""
Returns:
number of batches in an epoch
"""
return self._num_batches # * self._batch_size

def __iter__(self) -> Iterator[List[int]]:
"""
Returns:
batches of indices for sampling dataset elems during an epoch
"""
indices = []
for _ in range(self._num_batches):
batch_indices = []
classes_for_batch = random.sample(self._labels, self._num_classes)
while self._num_classes != len(set(classes_for_batch)):
classes_for_batch = random.sample(self._labels, self._num_classes)
for cls_id in classes_for_batch:
replace_flag = self._num_samples > len(self.lbl2idx[cls_id])
batch_indices += np.random.choice(
self.lbl2idx[cls_id], self._num_samples, replace=replace_flag
).tolist()
indices.append(batch_indices)
return iter(indices)
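A quick contract check for the class above, as a sketch on a toy label list (the asserts restate what the code shown here does; they are not part of the PR). Each yielded element is a full batch of `num_classes * num_samples` indices, which is why the docstring example wires the sampler in through `batch_sampler` rather than `sampler`:

.. code-block:: python

from catalyst.data import BatchBalanceClassSampler  # exported by this PR's __init__.py change

labels = [0] * 4 + [1] * 4 + [2] * 4  # toy dataset: 3 classes, 4 instances each
sampler = BatchBalanceClassSampler(labels, num_classes=2, num_samples=2)

assert sampler.batch_size == 4   # 2 classes * 2 samples
assert len(sampler) == 12 // 4   # default num_batches = len(labels) // batch_size

batch = next(iter(sampler))
assert len(batch) == 4                        # one yielded element is one full batch
assert len({labels[i] for i in batch}) == 2   # exactly num_classes distinct classes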


class DynamicBalanceClassSampler(Sampler):
"""
This kind of sampler can be used for classification tasks with significant
@@ -552,6 +733,7 @@ def __iter__(self) -> Iterator[int]:
__all__ = [
"BalanceClassSampler",
"BalanceBatchSampler",
"BatchBalanceClassSampler",
"DistributedSamplerWrapper",
"DynamicBalanceClassSampler",
"DynamicLenBatchSampler",