diff --git a/mmcls/datasets/dataset_wrappers.py b/mmcls/datasets/dataset_wrappers.py index 93de60f6084..645b7f00243 100644 --- a/mmcls/datasets/dataset_wrappers.py +++ b/mmcls/datasets/dataset_wrappers.py @@ -175,17 +175,20 @@ class ClassBalancedDataset(object): 1. For each category c, compute the fraction :math:`f(c)` of images that contain it. - 2. For each category c, compute the category-level repeat factor + 2. For each category c, compute the category-level repeat factor. .. math:: r(c) = \max(1, \sqrt{\frac{t}{f(c)}}) - + + where :math:`t` is `oversample_thr`. 3. For each image I and its labels :math:`L(I)`, compute the image-level - repeat factor + repeat factor. .. math:: r(I) = \max_{c \in L(I)} r(c) + Each image repeats :math:`\lceil r(I) \rceil` times. + Args: dataset (:obj:`BaseDataset`): The dataset to be repeated. oversample_thr (float): frequency threshold below which data is @@ -214,8 +217,8 @@ def __init__(self, dataset, oversample_thr): self.flag = np.asarray(flags, dtype=np.uint8) def _get_repeat_factors(self, dataset, repeat_thr): - # 1. For each category c, compute the fraction # of images - # that contain it: f(c) + # 1. For each category c, compute the fraction of images + # that contain it: f(c) category_freq = defaultdict(int) num_images = len(dataset) for idx in range(num_images): @@ -227,15 +230,15 @@ def _get_repeat_factors(self, dataset, repeat_thr): category_freq[k] = v / num_images # 2. For each category c, compute the category-level repeat factor: - # r(c) = max(1, sqrt(t/f(c))) + # r(c) = max(1, sqrt(t/f(c))) category_repeat = { cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq)) for cat_id, cat_freq in category_freq.items() } # 3. For each image I and its labels L(I), compute the image-level - # repeat factor: - # r(I) = max_{c in L(I)} r(c) + # repeat factor: + # r(I) = max_{c in L(I)} r(c) repeat_factors = [] for idx in range(num_images): cat_ids = set(self.dataset.get_cat_ids(idx))