Skip to content

Commit 184fb12

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] fix mypy errors after the 0.981 release (#6652)
Reviewed By: YosuaMichael Differential Revision: D39885431 fbshipit-source-id: a5b82afc1a54e78c50f36fadc0a2f9bd977edcb6
1 parent c3eb098 commit 184fb12

File tree

2 files changed

+8
-11
lines changed

2 files changed

+8
-11
lines changed

torchvision/models/_api.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,7 @@ def get_weight(name: str) -> WeightsEnum:
112112
return weights_enum.from_str(value_name)
113113

114114

115-
W = TypeVar("W", bound=WeightsEnum)
116-
117-
118-
def get_model_weights(name: Union[Callable, str]) -> W:
115+
def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
119116
"""
120117
Retuns the weights enum class associated to the given model.
121118
@@ -125,10 +122,10 @@ def get_model_weights(name: Union[Callable, str]) -> W:
125122
name (callable or str): The model builder function or the name under which it is registered.
126123
127124
Returns:
128-
weights_enum (W): The weights enum class associated with the model.
125+
weights_enum (WeightsEnum): The weights enum class associated with the model.
129126
"""
130127
model = get_model_builder(name) if isinstance(name, str) else name
131-
return cast(W, _get_enum_from_fn(model))
128+
return _get_enum_from_fn(model)
132129

133130

134131
def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
@@ -199,7 +196,7 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]:
199196
return sorted(models)
200197

201198

202-
def get_model_builder(name: str) -> Callable[..., M]:
199+
def get_model_builder(name: str) -> Callable[..., nn.Module]:
203200
"""
204201
Gets the model name and returns the model builder method.
205202
@@ -219,7 +216,7 @@ def get_model_builder(name: str) -> Callable[..., M]:
219216
return fn
220217

221218

222-
def get_model(name: str, **config: Any) -> M:
219+
def get_model(name: str, **config: Any) -> nn.Module:
223220
"""
224221
Gets the model name and configuration and returns an instantiated model.
225222

torchvision/prototype/datasets/utils/_internal.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import functools
33
import pathlib
44
import pickle
5-
from typing import Any, BinaryIO, Callable, cast, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union
5+
from typing import Any, BinaryIO, Callable, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union
66

77
import torch
88
import torch.distributed as dist
@@ -72,8 +72,8 @@ def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
7272
return obj
7373

7474

75-
def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> D:
76-
return cast(D, _getattr_closure(path, attrs=name.split(".")))
75+
def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> Any:
76+
return _getattr_closure(path, attrs=name.split("."))
7777

7878

7979
def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D:

0 commit comments

Comments
 (0)