From 5bf0f04f6ea837326c94a3fe12e6dca300b46316 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 18 Jan 2023 08:29:56 +0100 Subject: [PATCH 01/10] [PoC] properly support deepcopying and serialization of model weights --- torchvision/models/_api.py | 44 ++++++++++++++++++++++++++++++------ torchvision/models/resnet.py | 4 ++-- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 992ebbbaeb2..e965ef4fcfe 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import importlib import inspect import sys from dataclasses import dataclass, fields from inspect import signature from types import ModuleType -from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union +from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Tuple, TypeVar, Union from torch import nn @@ -12,8 +14,39 @@ from .._internally_replaced_utils import load_state_dict_from_url - -__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_builder", "get_model_weights", "get_weight", "list_models"] +__all__ = [ + "TransformsFactory", + "WeightsEnum", + "Weights", + "get_model", + "get_model_builder", + "get_model_weights", + "get_weight", + "list_models", +] + + +@dataclass(init=False) +class TransformsFactory: + fn: Callable[..., nn.Module] + args: Tuple[Any, ...] + kwargs: Dict[str, Any] + + def __init__(self, fn: Callable[..., nn.Module], *args: Any, **kwargs: Any) -> None: + self.fn = fn + self.args = args + self.kwargs = kwargs + + def __call__(self) -> nn.Module: + return self.fn(*self.args, **self.kwargs) + + # FIXME: it seems we don't even need this. I'll leave it here until I'm sure we don't + # @classmethod + # def __simple_new__(cls, fn, args, kwargs) -> TransformsFactory: + # return cls(fn, *args, **kwargs) + # + # def __reduce_ex__(self, protocol: int): + # return self.__simple_new__, (self.fn, self.args, self.kwargs) @dataclass @@ -34,7 +67,7 @@ class Weights: """ url: str - transforms: Callable + transforms: TransformsFactory meta: Dict[str, Any] @@ -75,9 +108,6 @@ def __getattr__(self, name): return object.__getattribute__(self.value, name) return super().__getattr__(name) - def __deepcopy__(self, memodict=None): - return self - def get_weight(name: str) -> WeightsEnum: """ diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 1d3638917fe..4cb03811b73 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -7,7 +7,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import register_model, Weights, WeightsEnum +from ._api import register_model, TransformsFactory, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -356,7 +356,7 @@ class ResNet34_Weights(WeightsEnum): class ResNet50_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet50-0676ba61.pth", - transforms=partial(ImageClassification, crop_size=224), + transforms=TransformsFactory(ImageClassification, crop_size=123), meta={ **_COMMON_META, "num_params": 25557032, From d8c2226edd1d179fb33cceef28b0f4029127e18c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 18 Jan 2023 08:31:31 +0100 Subject: [PATCH 02/10] cleanup --- torchvision/models/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 4cb03811b73..87b946355fc 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -356,7 +356,7 @@ class ResNet34_Weights(WeightsEnum): class ResNet50_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet50-0676ba61.pth", - transforms=TransformsFactory(ImageClassification, crop_size=123), + transforms=TransformsFactory(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 25557032, From 7ee05e6fba114906c16735681d2d4364aac5ea00 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 19 Jan 2023 13:09:02 +0100 Subject: [PATCH 03/10] use custom equality definition rather than factory --- torchvision/models/_api.py | 48 +++++++++++++++++------------------- torchvision/models/resnet.py | 4 +-- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index e965ef4fcfe..bf4efe35906 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -1,12 +1,13 @@ from __future__ import annotations +import functools import importlib import inspect import sys from dataclasses import dataclass, fields from inspect import signature from types import ModuleType -from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union from torch import nn @@ -15,7 +16,6 @@ from .._internally_replaced_utils import load_state_dict_from_url __all__ = [ - "TransformsFactory", "WeightsEnum", "Weights", "get_model", @@ -26,29 +26,6 @@ ] -@dataclass(init=False) -class TransformsFactory: - fn: Callable[..., nn.Module] - args: Tuple[Any, ...] - kwargs: Dict[str, Any] - - def __init__(self, fn: Callable[..., nn.Module], *args: Any, **kwargs: Any) -> None: - self.fn = fn - self.args = args - self.kwargs = kwargs - - def __call__(self) -> nn.Module: - return self.fn(*self.args, **self.kwargs) - - # FIXME: it seems we don't even need this. I'll leave it here until I'm sure we don't - # @classmethod - # def __simple_new__(cls, fn, args, kwargs) -> TransformsFactory: - # return cls(fn, *args, **kwargs) - # - # def __reduce_ex__(self, protocol: int): - # return self.__simple_new__, (self.fn, self.args, self.kwargs) - - @dataclass class Weights: """ @@ -67,9 +44,28 @@ class Weights: """ url: str - transforms: TransformsFactory + transforms: Callable meta: Dict[str, Any] + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Weights): + return NotImplemented + + if self.url != other.url: + return False + + if self.meta != other.meta: + return False + + if isinstance(self.transforms, functools.partial) and isinstance(other.transforms, functools.partial): + return ( + self.transforms.func == other.transforms.func + and self.transforms.args == other.transforms.args + and self.transforms.keywords == other.transforms.keywords + ) + else: + return self.transforms == other.transforms + class WeightsEnum(StrEnum): """ diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 87b946355fc..1d3638917fe 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -7,7 +7,7 @@ from ..transforms._presets import ImageClassification from ..utils import _log_api_usage_once -from ._api import register_model, TransformsFactory, Weights, WeightsEnum +from ._api import register_model, Weights, WeightsEnum from ._meta import _IMAGENET_CATEGORIES from ._utils import _ovewrite_named_param, handle_legacy_interface @@ -356,7 +356,7 @@ class ResNet34_Weights(WeightsEnum): class ResNet50_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet50-0676ba61.pth", - transforms=TransformsFactory(ImageClassification, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 25557032, From e5076b98679cc741cbfba0fdf8c5f3bfd841a4a0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 19 Jan 2023 13:10:17 +0100 Subject: [PATCH 04/10] cleanup --- torchvision/models/_api.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index bf4efe35906..530f33cb170 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -1,10 +1,8 @@ -from __future__ import annotations - -import functools import importlib import inspect import sys from dataclasses import dataclass, fields +from functools import partial from inspect import signature from types import ModuleType from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union @@ -15,15 +13,7 @@ from .._internally_replaced_utils import load_state_dict_from_url -__all__ = [ - "WeightsEnum", - "Weights", - "get_model", - "get_model_builder", - "get_model_weights", - "get_weight", - "list_models", -] +__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_builder", "get_model_weights", "get_weight", "list_models"] @dataclass @@ -57,7 +47,7 @@ def __eq__(self, other: Any) -> bool: if self.meta != other.meta: return False - if isinstance(self.transforms, functools.partial) and isinstance(other.transforms, functools.partial): + if isinstance(self.transforms, partial) and isinstance(other.transforms, partial): return ( self.transforms.func == other.transforms.func and self.transforms.args == other.transforms.args From 9692c5ca147049dc551e27885557e62dc464edfa Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 19 Jan 2023 13:22:53 +0100 Subject: [PATCH 05/10] more cleanup --- torchvision/models/_api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 530f33cb170..acfe9ab18c8 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -13,6 +13,7 @@ from .._internally_replaced_utils import load_state_dict_from_url + __all__ = ["WeightsEnum", "Weights", "get_model", "get_model_builder", "get_model_weights", "get_weight", "list_models"] From 30d5e7750ddcccea26583a06c1f93391d740d3de Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 19 Jan 2023 13:36:20 +0100 Subject: [PATCH 06/10] add explanation --- torchvision/models/_api.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index acfe9ab18c8..30ae2f28f15 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -39,6 +39,11 @@ class Weights: meta: Dict[str, Any] def __eq__(self, other: Any) -> bool: + # We need this custom implementation for correct deep-copy and deserialization behavior. + # TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it, + # involves an equality check against the defined members. Unfortunately, + # `fn = partial(...); assert deepcopy(fn) != fn` and thus this check fails without custom handling. + # See https://github.com/pytorch/vision/pull/7107 for details. if not isinstance(other, Weights): return NotImplemented From 8e291cb238a1cbb313cbe51938fdda97432676ff Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 19 Jan 2023 13:40:13 +0100 Subject: [PATCH 07/10] add tests --- test/test_extended_models.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 5505f5b5e4e..23215e4c76e 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -1,5 +1,6 @@ import copy import os +import pickle import pytest import test_models as TM @@ -73,10 +74,24 @@ def test_get_model_weights(name, weight): ], ) def test_weights_copyable(copy_fn, name): - model_weights = models.get_model_weights(name) - for weights in list(model_weights): - copied_weights = copy_fn(weights) - assert copied_weights is weights + for weights in list(models.get_model_weights(name)): + assert copy_fn(weights) == weights + + +@pytest.mark.parametrize( + "name", + [ + "resnet50", + "retinanet_resnet50_fpn_v2", + "raft_large", + "quantized_resnet50", + "lraspp_mobilenet_v3_large", + "mvit_v1_b", + ], +) +def test_weights_de_serializable(name): + for weights in list(models.get_model_weights(name)): + assert pickle.loads(pickle.dumps(weights)) == weights @pytest.mark.parametrize( From 527de17866f5c381954443503b86aae0fdec05be Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 19 Jan 2023 14:14:55 +0100 Subject: [PATCH 08/10] imporve comment --- torchvision/models/_api.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 30ae2f28f15..7c9ef341508 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -41,8 +41,10 @@ class Weights: def __eq__(self, other: Any) -> bool: # We need this custom implementation for correct deep-copy and deserialization behavior. # TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it, - # involves an equality check against the defined members. Unfortunately, - # `fn = partial(...); assert deepcopy(fn) != fn` and thus this check fails without custom handling. + # involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often + # defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling + # for it, the check against the defined members would fail and effectively prevent the weights from being + # deep-copied or deserialized. # See https://github.com/pytorch/vision/pull/7107 for details. if not isinstance(other, Weights): return NotImplemented From 568f0439d93adec7fcf9b950e402a66b6078d31f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 19 Jan 2023 14:18:42 +0100 Subject: [PATCH 09/10] add test comments --- test/test_extended_models.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 23215e4c76e..7c86cfa9837 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -75,7 +75,10 @@ def test_get_model_weights(name, weight): ) def test_weights_copyable(copy_fn, name): for weights in list(models.get_model_weights(name)): - assert copy_fn(weights) == weights + # It is somewhat surprising that (deep-)copying is an identity operation here, + # but this is the default behavior of enums. + # See https://github.com/pytorch/vision/pull/7107 for details. + assert copy_fn(weights) is weights @pytest.mark.parametrize( @@ -89,9 +92,12 @@ def test_weights_copyable(copy_fn, name): "mvit_v1_b", ], ) -def test_weights_de_serializable(name): +def test_weights_deserializable(name): for weights in list(models.get_model_weights(name)): - assert pickle.loads(pickle.dumps(weights)) == weights + # It is somewhat surprising that deserialization is an identity operation here, + # but this is the default behavior of enums. + # See https://github.com/pytorch/vision/pull/7107 for details. + assert pickle.loads(pickle.dumps(weights)) is weights @pytest.mark.parametrize( From 2b558f6647e06f868c9c8291d533d5c8bc2e0561 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 19 Jan 2023 14:39:18 +0100 Subject: [PATCH 10/10] improve comments --- test/test_extended_models.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 7c86cfa9837..068d3e238f9 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -75,9 +75,10 @@ def test_get_model_weights(name, weight): ) def test_weights_copyable(copy_fn, name): for weights in list(models.get_model_weights(name)): - # It is somewhat surprising that (deep-)copying is an identity operation here, - # but this is the default behavior of enums. - # See https://github.com/pytorch/vision/pull/7107 for details. + # It is somewhat surprising that (deep-)copying is an identity operation here, but this is the default behavior + # of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances + # Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop + # support for the identity operation in the future. assert copy_fn(weights) is weights @@ -94,9 +95,10 @@ def test_weights_copyable(copy_fn, name): ) def test_weights_deserializable(name): for weights in list(models.get_model_weights(name)): - # It is somewhat surprising that deserialization is an identity operation here, - # but this is the default behavior of enums. - # See https://github.com/pytorch/vision/pull/7107 for details. + # It is somewhat surprising that deserialization is an identity operation here, but this is the default behavior + # of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances + # Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop + # support for the identity operation in the future. assert pickle.loads(pickle.dumps(weights)) is weights