Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Fix arguments when instantiating `l2l.nn.Scale`.
* Fix `train_loss` logging in `LightningModule` implementations with PyTorch-Lightning 1.5.
* Fix `RandomClassRotation` ([#283](https://github.com/learnables/learn2learn/pull/283)) to incorporate multi-channelled inputs. ([Varad Pimpalkhute](https://github.com/nightlessbaron/))
* Fix memory leak in `maml.py` and `meta_sgd.py` and add tests to `maml_test.py` and `metasgd_test.py` to check for possible future memory leaks. ([#284](https://github.com/learnables/learn2learn/issues/284)) ([Kevin Zhang](https://github.com/kzhang2))


## v0.1.6
Expand Down
4 changes: 4 additions & 0 deletions learn2learn/algorithms/meta_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,16 @@ def meta_sgd_update(model, lrs=None, grads=None):
p = model._parameters[param_key]
if p is not None and p.grad is not None:
model._parameters[param_key] = p - p._lr * p.grad
p.grad = None
p._lr = None

# Second, handle the buffers if necessary
for buffer_key in model._buffers:
buff = model._buffers[buffer_key]
if buff is not None and buff.grad is not None and buff._lr is not None:
model._buffers[buffer_key] = buff - buff._lr * buff.grad
buff.grad = None
buff._lr = None

# Then, recurse for each submodule
for module_key in model._modules:
Expand Down
18 changes: 10 additions & 8 deletions learn2learn/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,22 +279,24 @@ def update_module(module, updates=None, memo=None):
# Update the params
for param_key in module._parameters:
p = module._parameters[param_key]
if p is not None and hasattr(p, 'update') and p.update is not None:
if p in memo:
module._parameters[param_key] = memo[p]
else:
if p in memo:
module._parameters[param_key] = memo[p]
else:
if p is not None and hasattr(p, 'update') and p.update is not None:
updated = p + p.update
p.update = None
memo[p] = updated
module._parameters[param_key] = updated

# Second, handle the buffers if necessary
for buffer_key in module._buffers:
buff = module._buffers[buffer_key]
if buff is not None and hasattr(buff, 'update') and buff.update is not None:
if buff in memo:
module._buffers[buffer_key] = memo[buff]
else:
if buff in memo:
module._buffers[buffer_key] = memo[buff]
else:
if buff is not None and hasattr(buff, 'update') and buff.update is not None:
updated = buff + buff.update
buff.update = None
memo[buff] = updated
module._buffers[buffer_key] = updated

Expand Down
31 changes: 31 additions & 0 deletions tests/unit/algorithms/maml_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,37 @@ def forward(self, x):
loss = sum(p.norm(p=2) for p in clone.parameters())
loss.backward()

def test_memory_consumption(self):

if torch.cuda.is_available() and torch.cuda.device_count() > 0:
def get_memory():
return torch.cuda.memory_allocated(0)

BSZ = 1024
INPUT_SIZE = 128
N_STEPS = 5
N_EVAL = 5

device = torch.device("cuda")
model = torch.nn.Sequential(*[
torch.nn.Linear(INPUT_SIZE, INPUT_SIZE) for _ in range(10)
])
maml = l2l.algorithms.MAML(model, lr=0.0001)
maml.to(device)

memory_usages = []

for evaluation in range(N_EVAL):
learner = maml.clone()
X = torch.randn(BSZ, INPUT_SIZE, device=device)
for step in range(N_STEPS):
learner.adapt(torch.norm(learner(X)))

memory_usages.append(get_memory())

for i in range(1, len(memory_usages)):
self.assertTrue(memory_usages[0] == memory_usages[i])


# Allow running this test module directly (e.g. `python maml_test.py`).
if __name__ == '__main__':
    unittest.main()
30 changes: 30 additions & 0 deletions tests/unit/algorithms/metasgd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,36 @@ def test_adaptation(self):
self.assertTrue(hasattr(p, 'grad'))
self.assertTrue(p.grad.norm(p=2).item() > 0.0)

def test_memory_consumption(self):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kzhang2 and shield this test too.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


if torch.cuda.is_available() and torch.cuda.device_count() > 0:
def get_memory():
return torch.cuda.memory_allocated(0)

BSZ = 1024
INPUT_SIZE = 128
N_STEPS = 5
N_EVAL = 5

device = torch.device("cuda")
model = torch.nn.Sequential(*[
torch.nn.Linear(INPUT_SIZE, INPUT_SIZE) for _ in range(10)
])
maml = l2l.algorithms.MetaSGD(model, lr=0.0001)
maml.to(device)

memory_usages = []

for evaluation in range(N_EVAL):
learner = maml.clone()
X = torch.randn(BSZ, INPUT_SIZE, device=device)
for step in range(N_STEPS):
learner.adapt(torch.norm(learner(X)))
memory_usages.append(get_memory())

for i in range(1, len(memory_usages)):
self.assertTrue(memory_usages[0] == memory_usages[i])


# Allow running this test module directly (e.g. `python metasgd_test.py`).
if __name__ == '__main__':
    unittest.main()
84 changes: 44 additions & 40 deletions tests/unit/data/metadataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,50 +62,54 @@ def test_labels_to_indices(self):
self.assertEqual(dict_label_to_indices[key][0], ord(key) - 97)

def test_union_metadataset(self):
for ds_class in [
l2l.vision.datasets.FC100,
l2l.vision.datasets.CIFARFS,
]:
datasets = [
ds_class('~/data', mode='train', download=True),
ds_class('~/data', mode='validation', download=True),
ds_class('~/data', mode='test', download=True),
]
datasets = [l2l.data.MetaDataset(ds) for ds in datasets]
union = l2l.data.UnionMetaDataset(datasets)
self.assertEqual(len(union), sum([len(ds) for ds in datasets]))
self.assertTrue(len(union.labels) == sum([len(ds.labels) for ds in datasets]))
self.assertTrue(len(union.indices_to_labels) == sum([len(ds.indices_to_labels) for ds in datasets]))
ref = datasets[1][23]
item = union[len(datasets[0]) + 23]
# self.assertTrue(item[1] == ref[1]) # Would fail, because labels are remapped.
self.assertTrue(np.linalg.norm(np.array(item[0]) - np.array(ref[0])) <= 1e-6)
ref = datasets[1][0]
item = union[len(datasets[0]) + 0]
# self.assertTrue(item[1] == ref[1]) # Would fail, because labels are remapped.
self.assertTrue(np.linalg.norm(np.array(item[0]) - np.array(ref[0])) <= 1e-6)
datasets = [
self.ds.get_mnist(),
self.ds.get_omniglot(),
]
datasets = [l2l.data.MetaDataset(ds) for ds in datasets]
union = l2l.data.UnionMetaDataset(datasets)
self.assertEqual(len(union), sum([len(ds) for ds in datasets]))
self.assertTrue(len(union.labels) == sum([len(ds.labels) for ds in datasets]))
self.assertTrue(len(union.indices_to_labels) == sum([len(ds.indices_to_labels) for ds in datasets]))
ref = datasets[1][23]
item = union[len(datasets[0]) + 23]
# self.assertTrue(item[1] == ref[1]) # Would fail, because labels are remapped.
self.assertTrue(np.linalg.norm(np.array(item[0]) - np.array(ref[0])) <= 1e-6)
ref = datasets[1][0]
item = union[len(datasets[0]) + 0]
# self.assertTrue(item[1] == ref[1]) # Would fail, because labels are remapped.
self.assertTrue(np.linalg.norm(np.array(item[0]) - np.array(ref[0])) <= 1e-6)

def test_filtered_metadataset(self):
for ds_class in [
l2l.vision.datasets.FC100,
l2l.vision.datasets.CIFARFS,
for dataset in [
self.ds.get_omniglot(),
self.ds.get_mnist(),
]:
datasets = [
ds_class('~/data', mode='train', download=True),
ds_class('~/data', mode='validation', download=True),
ds_class('~/data', mode='test', download=True),
]
datasets = [l2l.data.MetaDataset(ds) for ds in datasets]
union = l2l.data.UnionMetaDataset(datasets)
classes = datasets[1].labels
filtered = l2l.data.FilteredMetaDataset(union, classes)
self.assertEqual(len(filtered.labels), len(datasets[1].labels))
self.assertEqual(len(filtered), len(datasets[1]))
for label in filtered.labels:
self.assertTrue(label in datasets[1].labels)
dataset = l2l.data.MetaDataset(dataset)
all_classes = dataset.labels
even_classes = [i for i in all_classes if i % 2 == 0]
odd_classes = [i for i in all_classes if i % 2 == 1]
evens = l2l.data.FilteredMetaDataset(dataset, even_classes)
odds = l2l.data.FilteredMetaDataset(dataset, odd_classes)

self.assertEqual(sorted(even_classes), sorted(evens.labels))
self.assertEqual(sorted(odd_classes), sorted(odds.labels))

union = l2l.data.UnionMetaDataset((evens, odds))
self.assertEqual(sorted(union.labels), sorted(all_classes))

for label in evens.labels:
self.assertTrue(label in even_classes)
self.assertEqual(
len(evens.labels_to_indices[label]),
len(dataset.labels_to_indices[label])
)

for label in odds.labels:
self.assertTrue(label in odd_classes)
self.assertEqual(
len(filtered.labels_to_indices[label]),
len(datasets[1].labels_to_indices[label])
len(odds.labels_to_indices[label]),
len(dataset.labels_to_indices[label])
)


Expand Down