From bfa4953b27b68539711ab897afd8a90aa62387da Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 27 Sep 2022 11:16:10 +0200 Subject: [PATCH] fix mypy errors after the 0.981 release --- torchvision/models/_api.py | 13 +++++-------- torchvision/prototype/datasets/utils/_internal.py | 6 +++--- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 22073d9c28d..52ac070e6d3 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -112,10 +112,7 @@ def get_weight(name: str) -> WeightsEnum: return weights_enum.from_str(value_name) -W = TypeVar("W", bound=WeightsEnum) - - -def get_model_weights(name: Union[Callable, str]) -> W: +def get_model_weights(name: Union[Callable, str]) -> WeightsEnum: """ Retuns the weights enum class associated to the given model. @@ -125,10 +122,10 @@ def get_model_weights(name: Union[Callable, str]) -> W: name (callable or str): The model builder function or the name under which it is registered. Returns: - weights_enum (W): The weights enum class associated with the model. + weights_enum (WeightsEnum): The weights enum class associated with the model. """ model = get_model_builder(name) if isinstance(name, str) else name - return cast(W, _get_enum_from_fn(model)) + return _get_enum_from_fn(model) def _get_enum_from_fn(fn: Callable) -> WeightsEnum: @@ -199,7 +196,7 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]: return sorted(models) -def get_model_builder(name: str) -> Callable[..., M]: +def get_model_builder(name: str) -> Callable[..., nn.Module]: """ Gets the model name and returns the model builder method. @@ -219,7 +216,7 @@ def get_model_builder(name: str) -> Callable[..., M]: return fn -def get_model(name: str, **config: Any) -> M: +def get_model(name: str, **config: Any) -> nn.Module: """ Gets the model name and configuration and returns an instantiated model. diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index 6768469be67..0385d98c2f5 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -2,7 +2,7 @@ import functools import pathlib import pickle -from typing import Any, BinaryIO, Callable, cast, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union +from typing import Any, BinaryIO, Callable, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union import torch import torch.distributed as dist @@ -72,8 +72,8 @@ def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any: return obj -def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> D: - return cast(D, _getattr_closure(path, attrs=name.split("."))) +def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> Any: + return _getattr_closure(path, attrs=name.split(".")) def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D: