Weights enums cannot be pickled #7099

Closed
mike0sv opened this issue Jan 17, 2023 · 1 comment · Fixed by #7107
Comments

mike0sv commented Jan 17, 2023

🐛 Describe the bug

Unpickling a pickled weights enum object raises an error inside the enum implementation (it cannot match the reconstructed object against the existing enum values). Here is a reproducible example.

import pickle

from torchvision.models import ResNet50_Weights


def main():
    w = ResNet50_Weights.DEFAULT
    kek = pickle.dumps(w)
    lol = pickle.loads(kek)
    print(lol)


if __name__ == '__main__':
    main()

This happens because, internally, enum tries to match the newly constructed enum value against the declared values. This fails because the Weights dataclass has a transforms field, which is a callable. If that callable is a functools.partial, the unpickled one is technically a different object, even though it was constructed from the same original function with the same partial arguments, and functools.partial objects compare by identity. This snippet fixes the error, so it should give an idea of a probable fix:

import functools
from torchvision.models._api import Weights

def new_eq(self, other):
    if not isinstance(other, type(self)):
        return False
    if self.meta != other.meta or self.url != other.url:
        return False
    # Unless both transforms are partials, fall back to the default comparison.
    if not isinstance(self.transforms, functools.partial) or not isinstance(other.transforms, functools.partial):
        return self.transforms == other.transforms
    # Two partials are equal if they wrap the same function with the same arguments.
    return all(getattr(self.transforms, a) == getattr(other.transforms, a) for a in ["func", "args", "keywords", "__dict__"])

Weights.__eq__ = new_eq
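
To check the idea without torchvision, here is a minimal stdlib-only sketch. `FakeWeights`, `FakeModelWeights`, `partial_parts` and the example URL are hypothetical stand-ins and not torchvision's API; the sketch only assumes that the real `Weights` is an unhashable dataclass whose `transforms` field holds a `functools.partial`. It shows that comparing the parts of the partial lets the enum's by-value lookup find the member again.

```python
import functools
import pickle
from dataclasses import dataclass
from enum import Enum
from typing import Callable


def partial_parts(fn):
    """Return the pieces that define a partial, or the callable itself."""
    if isinstance(fn, functools.partial):
        return (fn.func, fn.args, fn.keywords)
    return fn


@dataclass(eq=False)  # eq=False keeps the hand-written __eq__ below
class FakeWeights:
    """Hypothetical stand-in for torchvision's Weights dataclass."""
    url: str
    transforms: Callable

    def __eq__(self, other):
        if not isinstance(other, FakeWeights):
            return NotImplemented
        return (self.url == other.url
                and partial_parts(self.transforms) == partial_parts(other.transforms))

    # Stay unhashable so Enum falls back to an equality-based member search.
    __hash__ = None


class FakeModelWeights(Enum):
    # A partial over a builtin, so it pickles regardless of where this code runs.
    V1 = FakeWeights("https://example.invalid/v1.pth", functools.partial(int, base=2))


# Round-trip the partial through pickle: equivalent, but a distinct object.
restored_transforms = pickle.loads(pickle.dumps(FakeModelWeights.V1.value.transforms))
restored_value = FakeWeights("https://example.invalid/v1.pth", restored_transforms)

# The by-value lookup now succeeds because __eq__ compares the partial's parts.
member = FakeModelWeights(restored_value)
print(member is FakeModelWeights.V1)
```

Note that the value has to stay unhashable for this to work: with an identity-based hash, the enum would try its value-to-member dictionary first and never reach the equality comparison.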

Versions

Collecting environment information...
PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.0.1 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: version 3.25.0
Libc version: N/A

Python version: 3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:00:33)  [Clang 13.0.1 ] (64-bit runtime)
Python platform: macOS-13.0.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==0.971
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.4
[pip3] torch==1.12.1
[pip3] torchvision==0.13.1
[conda] numpy                     1.22.4           py39h7df2422_0    conda-forge
[conda] torch                     1.12.1                   pypi_0    pypi
[conda] torchvision               0.13.1                   pypi_0    pypi

NicolasHug (Member) commented

Thanks for the report, @mike0sv. @pmeier and I are looking into this.

Note for self: the snippet above fails with

ValueError: Weights(url=...) is not a valid ResNet50_Weights
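
The same kind of error can be reproduced with the standard library alone, which may help when narrowing down the cause. This is a sketch under assumptions: `FakeWeights` and `FakeModelWeights` are hypothetical stand-ins for the torchvision classes, and the URL is made up. It shows that a pickle round trip of a `functools.partial` keeps its function and arguments but yields an object that does not compare equal to the original, so the enum's by-value lookup fails.

```python
import functools
import pickle
from dataclasses import dataclass
from enum import Enum
from typing import Callable


@dataclass
class FakeWeights:
    """Hypothetical stand-in for torchvision's Weights dataclass."""
    url: str
    transforms: Callable


# A partial over a builtin, so it pickles regardless of where this code runs.
original_transforms = functools.partial(int, base=2)


class FakeModelWeights(Enum):
    V1 = FakeWeights(url="https://example.invalid/v1.pth", transforms=original_transforms)


# The round trip preserves the function and arguments...
restored_transforms = pickle.loads(pickle.dumps(original_transforms))
same_parts = (
    restored_transforms.func is original_transforms.func
    and restored_transforms.args == original_transforms.args
    and restored_transforms.keywords == original_transforms.keywords
)
# ...but partial defines no __eq__, so the comparison is identity-based.
partials_equal = restored_transforms == original_transforms

# Rebuild the value as unpickling would, then look the member up by value.
restored_value = FakeWeights(url="https://example.invalid/v1.pth", transforms=restored_transforms)
try:
    FakeModelWeights(restored_value)
    lookup_error = None
except ValueError as exc:
    lookup_error = str(exc)

print(same_parts, partials_equal)
print(lookup_error)
```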
