diff --git a/CHANGELOG.md b/CHANGELOG.md
index 59867925..64e2ccc1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 * Fix arguments when instantiating `l2l.nn.Scale`.
 * Fix `train_loss` logging in `LightningModule` implementations with PyTorch-Lightning 1.5.
 * Fix `RandomClassRotation` ([#283](https://github.com/learnables/learn2learn/pull/283)) to incorporate multi-channelled inputs. ([Varad Pimpalkhute](https://github.com/nightlessbaron/))
+* Fix memory leak in `maml.py` and `meta_sgd.py`, and add tests to `maml_test.py` and `metasgd_test.py` to check for possible future memory leaks. ([#284](https://github.com/learnables/learn2learn/issues/284)) ([Kevin Zhang](https://github.com/kzhang2))
 
 ## v0.1.6
 
diff --git a/learn2learn/algorithms/meta_sgd.py b/learn2learn/algorithms/meta_sgd.py
index d9aa9e94..bf0002e5 100644
--- a/learn2learn/algorithms/meta_sgd.py
+++ b/learn2learn/algorithms/meta_sgd.py
@@ -46,12 +46,16 @@ def meta_sgd_update(model, lrs=None, grads=None):
         p = model._parameters[param_key]
         if p is not None and p.grad is not None:
             model._parameters[param_key] = p - p._lr * p.grad
+            p.grad = None
+            p._lr = None
 
     # Second, handle the buffers if necessary
     for buffer_key in model._buffers:
         buff = model._buffers[buffer_key]
         if buff is not None and buff.grad is not None and buff._lr is not None:
             model._buffers[buffer_key] = buff - buff._lr * buff.grad
+            buff.grad = None
+            buff._lr = None
 
     # Then, recurse for each submodule
     for module_key in model._modules:
diff --git a/learn2learn/utils/__init__.py b/learn2learn/utils/__init__.py
index 03592838..c95d690f 100644
--- a/learn2learn/utils/__init__.py
+++ b/learn2learn/utils/__init__.py
@@ -279,22 +279,24 @@ def update_module(module, updates=None, memo=None):
     # Update the params
     for param_key in module._parameters:
         p = module._parameters[param_key]
-        if p is not None and hasattr(p, 'update') and p.update is not None:
-            if p in memo:
-                module._parameters[param_key] = memo[p]
-            else:
+        if p in memo:
+            module._parameters[param_key] = memo[p]
+        else:
+            if p is not None and hasattr(p, 'update') and p.update is not None:
                 updated = p + p.update
+                p.update = None
                 memo[p] = updated
                 module._parameters[param_key] = updated
 
     # Second, handle the buffers if necessary
     for buffer_key in module._buffers:
         buff = module._buffers[buffer_key]
-        if buff is not None and hasattr(buff, 'update') and buff.update is not None:
-            if buff in memo:
-                module._buffers[buffer_key] = memo[buff]
-            else:
+        if buff in memo:
+            module._buffers[buffer_key] = memo[buff]
+        else:
+            if buff is not None and hasattr(buff, 'update') and buff.update is not None:
                 updated = buff + buff.update
+                buff.update = None
                 memo[buff] = updated
                 module._buffers[buffer_key] = updated
 
diff --git a/tests/unit/algorithms/maml_test.py b/tests/unit/algorithms/maml_test.py
index 9cc54480..41259d9d 100644
--- a/tests/unit/algorithms/maml_test.py
+++ b/tests/unit/algorithms/maml_test.py
@@ -188,6 +188,37 @@ def forward(self, x):
         loss = sum(p.norm(p=2) for p in clone.parameters())
         loss.backward()
 
+    def test_memory_consumption(self):
+
+        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+            def get_memory():
+                return torch.cuda.memory_allocated(0)
+
+            BSZ = 1024
+            INPUT_SIZE = 128
+            N_STEPS = 5
+            N_EVAL = 5
+
+            device = torch.device("cuda")
+            model = torch.nn.Sequential(*[
+                torch.nn.Linear(INPUT_SIZE, INPUT_SIZE) for _ in range(10)
+            ])
+            maml = l2l.algorithms.MAML(model, lr=0.0001)
+            maml.to(device)
+
+            memory_usages = []
+
+            for evaluation in range(N_EVAL):
+                learner = maml.clone()
+                X = torch.randn(BSZ, INPUT_SIZE, device=device)
+                for step in range(N_STEPS):
+                    learner.adapt(torch.norm(learner(X)))
+
+                memory_usages.append(get_memory())
+
+            for i in range(1, len(memory_usages)):
+                self.assertTrue(memory_usages[0] == memory_usages[i])
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/unit/algorithms/metasgd_test.py b/tests/unit/algorithms/metasgd_test.py
index f04c2cb5..0ee71d29 100644
--- a/tests/unit/algorithms/metasgd_test.py
+++ b/tests/unit/algorithms/metasgd_test.py
@@ -80,6 +80,36 @@ def test_adaptation(self):
             self.assertTrue(hasattr(p, 'grad'))
             self.assertTrue(p.grad.norm(p=2).item() > 0.0)
 
+    def test_memory_consumption(self):
+
+        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+            def get_memory():
+                return torch.cuda.memory_allocated(0)
+
+            BSZ = 1024
+            INPUT_SIZE = 128
+            N_STEPS = 5
+            N_EVAL = 5
+
+            device = torch.device("cuda")
+            model = torch.nn.Sequential(*[
+                torch.nn.Linear(INPUT_SIZE, INPUT_SIZE) for _ in range(10)
+            ])
+            maml = l2l.algorithms.MetaSGD(model, lr=0.0001)
+            maml.to(device)
+
+            memory_usages = []
+
+            for evaluation in range(N_EVAL):
+                learner = maml.clone()
+                X = torch.randn(BSZ, INPUT_SIZE, device=device)
+                for step in range(N_STEPS):
+                    learner.adapt(torch.norm(learner(X)))
+                memory_usages.append(get_memory())
+
+            for i in range(1, len(memory_usages)):
+                self.assertTrue(memory_usages[0] == memory_usages[i])
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/unit/data/metadataset_test.py b/tests/unit/data/metadataset_test.py
index 15f9a482..a6425203 100644
--- a/tests/unit/data/metadataset_test.py
+++ b/tests/unit/data/metadataset_test.py
@@ -62,50 +62,54 @@ def test_labels_to_indices(self):
         self.assertEqual(dict_label_to_indices[key][0], ord(key) - 97)
 
     def test_union_metadataset(self):
-        for ds_class in [
-            l2l.vision.datasets.FC100,
-            l2l.vision.datasets.CIFARFS,
-        ]:
-            datasets = [
-                ds_class('~/data', mode='train', download=True),
-                ds_class('~/data', mode='validation', download=True),
-                ds_class('~/data', mode='test', download=True),
-            ]
-            datasets = [l2l.data.MetaDataset(ds) for ds in datasets]
-            union = l2l.data.UnionMetaDataset(datasets)
-            self.assertEqual(len(union), sum([len(ds) for ds in datasets]))
-            self.assertTrue(len(union.labels) == sum([len(ds.labels) for ds in datasets]))
-            self.assertTrue(len(union.indices_to_labels) == sum([len(ds.indices_to_labels) for ds in datasets]))
-            ref = datasets[1][23]
-            item = union[len(datasets[0]) + 23]
-            # self.assertTrue(item[1] == ref[1])  # Would fail, because labels are remapped.
-            self.assertTrue(np.linalg.norm(np.array(item[0]) - np.array(ref[0])) <= 1e-6)
-            ref = datasets[1][0]
-            item = union[len(datasets[0]) + 0]
-            # self.assertTrue(item[1] == ref[1])  # Would fail, because labels are remapped.
-            self.assertTrue(np.linalg.norm(np.array(item[0]) - np.array(ref[0])) <= 1e-6)
+        datasets = [
+            self.ds.get_mnist(),
+            self.ds.get_omniglot(),
+        ]
+        datasets = [l2l.data.MetaDataset(ds) for ds in datasets]
+        union = l2l.data.UnionMetaDataset(datasets)
+        self.assertEqual(len(union), sum([len(ds) for ds in datasets]))
+        self.assertTrue(len(union.labels) == sum([len(ds.labels) for ds in datasets]))
+        self.assertTrue(len(union.indices_to_labels) == sum([len(ds.indices_to_labels) for ds in datasets]))
+        ref = datasets[1][23]
+        item = union[len(datasets[0]) + 23]
+        # self.assertTrue(item[1] == ref[1])  # Would fail, because labels are remapped.
+        self.assertTrue(np.linalg.norm(np.array(item[0]) - np.array(ref[0])) <= 1e-6)
+        ref = datasets[1][0]
+        item = union[len(datasets[0]) + 0]
+        # self.assertTrue(item[1] == ref[1])  # Would fail, because labels are remapped.
+        self.assertTrue(np.linalg.norm(np.array(item[0]) - np.array(ref[0])) <= 1e-6)
 
     def test_filtered_metadataset(self):
-        for ds_class in [
-            l2l.vision.datasets.FC100,
-            l2l.vision.datasets.CIFARFS,
+        for dataset in [
+            self.ds.get_omniglot(),
+            self.ds.get_mnist(),
         ]:
-            datasets = [
-                ds_class('~/data', mode='train', download=True),
-                ds_class('~/data', mode='validation', download=True),
-                ds_class('~/data', mode='test', download=True),
-            ]
-            datasets = [l2l.data.MetaDataset(ds) for ds in datasets]
-            union = l2l.data.UnionMetaDataset(datasets)
-            classes = datasets[1].labels
-            filtered = l2l.data.FilteredMetaDataset(union, classes)
-            self.assertEqual(len(filtered.labels), len(datasets[1].labels))
-            self.assertEqual(len(filtered), len(datasets[1]))
-            for label in filtered.labels:
-                self.assertTrue(label in datasets[1].labels)
+            dataset = l2l.data.MetaDataset(dataset)
+            all_classes = dataset.labels
+            even_classes = [i for i in all_classes if i % 2 == 0]
+            odd_classes = [i for i in all_classes if i % 2 == 1]
+            evens = l2l.data.FilteredMetaDataset(dataset, even_classes)
+            odds = l2l.data.FilteredMetaDataset(dataset, odd_classes)
+
+            self.assertEqual(sorted(even_classes), sorted(evens.labels))
+            self.assertEqual(sorted(odd_classes), sorted(odds.labels))
+
+            union = l2l.data.UnionMetaDataset((evens, odds))
+            self.assertEqual(sorted(union.labels), sorted(all_classes))
+
+            for label in evens.labels:
+                self.assertTrue(label in even_classes)
+                self.assertEqual(
+                    len(evens.labels_to_indices[label]),
+                    len(dataset.labels_to_indices[label])
+                )
+
+            for label in odds.labels:
+                self.assertTrue(label in odd_classes)
                 self.assertEqual(
-                    len(filtered.labels_to_indices[label]),
-                    len(datasets[1].labels_to_indices[label])
+                    len(odds.labels_to_indices[label]),
+                    len(dataset.labels_to_indices[label])
                 )
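
Note on the fix, for readers of the diff: `meta_sgd_update` and `update_module` now clear the temporary `grad`, `_lr`, and `update` attributes once they have been folded into the new parameter or buffer, so the tensors they referenced (and the autograd graph behind them) can be released between adaptation steps. The sketch below is not part of the patch; it only restates the new `test_memory_consumption` check as a standalone script, reusing the constructor arguments and calls that appear in the tests. The helper name `check_constant_memory` is invented for illustration, and a CUDA device is assumed so that `torch.cuda.memory_allocated` can be queried.

# Standalone sketch (not part of the patch) of the memory check added in
# maml_test.py and metasgd_test.py; requires a CUDA device and learn2learn.
import torch
import learn2learn as l2l


def check_constant_memory(algorithm_cls, n_eval=5, n_steps=5, bsz=1024, input_size=128):
    # Hypothetical helper mirroring test_memory_consumption.
    device = torch.device("cuda")
    model = torch.nn.Sequential(
        *[torch.nn.Linear(input_size, input_size) for _ in range(10)]
    )
    meta = algorithm_cls(model, lr=0.0001)
    meta.to(device)

    usages = []
    for _ in range(n_eval):
        learner = meta.clone()
        X = torch.randn(bsz, input_size, device=device)
        for _ in range(n_steps):
            learner.adapt(torch.norm(learner(X)))
        usages.append(torch.cuda.memory_allocated(0))

    # A leak shows up as readings that grow from one evaluation to the next;
    # with the stale references cleared, every reading should be identical.
    return all(u == usages[0] for u in usages)


if torch.cuda.is_available():
    assert check_constant_memory(l2l.algorithms.MAML)
    assert check_constant_memory(l2l.algorithms.MetaSGD)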