diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 5505f5b5e4e..068d3e238f9 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -1,5 +1,6 @@ import copy import os +import pickle import pytest import test_models as TM @@ -73,10 +74,32 @@ def test_get_model_weights(name, weight): ], ) def test_weights_copyable(copy_fn, name): - model_weights = models.get_model_weights(name) - for weights in list(model_weights): - copied_weights = copy_fn(weights) - assert copied_weights is weights + for weights in list(models.get_model_weights(name)): + # It is somewhat surprising that (deep-)copying is an identity operation here, but this is the default behavior + # of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances + # Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop + # support for the identity operation in the future. + assert copy_fn(weights) is weights + + +@pytest.mark.parametrize( + "name", + [ + "resnet50", + "retinanet_resnet50_fpn_v2", + "raft_large", + "quantized_resnet50", + "lraspp_mobilenet_v3_large", + "mvit_v1_b", + ], +) +def test_weights_deserializable(name): + for weights in list(models.get_model_weights(name)): + # It is somewhat surprising that deserialization is an identity operation here, but this is the default behavior + # of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances + # Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop + # support for the identity operation in the future. + assert pickle.loads(pickle.dumps(weights)) is weights @pytest.mark.parametrize( diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 992ebbbaeb2..7c9ef341508 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -2,6 +2,7 @@ import inspect import sys from dataclasses import dataclass, fields +from functools import partial from inspect import signature from types import ModuleType from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union @@ -37,6 +38,32 @@ class Weights: transforms: Callable meta: Dict[str, Any] + def __eq__(self, other: Any) -> bool: + # We need this custom implementation for correct deep-copy and deserialization behavior. + # TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it, + # involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often + # defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling + # for it, the check against the defined members would fail and effectively prevent the weights from being + # deep-copied or deserialized. + # See https://github.com/pytorch/vision/pull/7107 for details. + if not isinstance(other, Weights): + return NotImplemented + + if self.url != other.url: + return False + + if self.meta != other.meta: + return False + + if isinstance(self.transforms, partial) and isinstance(other.transforms, partial): + return ( + self.transforms.func == other.transforms.func + and self.transforms.args == other.transforms.args + and self.transforms.keywords == other.transforms.keywords + ) + else: + return self.transforms == other.transforms + class WeightsEnum(StrEnum): """ @@ -75,9 +102,6 @@ def __getattr__(self, name): return object.__getattribute__(self.value, name) return super().__getattr__(name) - def __deepcopy__(self, memodict=None): - return self - def get_weight(name: str) -> WeightsEnum: """