Skip to content

fix mypy errors after the 0.981 release #6652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions torchvision/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,7 @@ def get_weight(name: str) -> WeightsEnum:
return weights_enum.from_str(value_name)


W = TypeVar("W", bound=WeightsEnum)


def get_model_weights(name: Union[Callable, str]) -> W:
def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
"""
Retuns the weights enum class associated to the given model.

Expand All @@ -125,10 +122,10 @@ def get_model_weights(name: Union[Callable, str]) -> W:
name (callable or str): The model builder function or the name under which it is registered.

Returns:
weights_enum (W): The weights enum class associated with the model.
weights_enum (WeightsEnum): The weights enum class associated with the model.
"""
model = get_model_builder(name) if isinstance(name, str) else name
return cast(W, _get_enum_from_fn(model))
return _get_enum_from_fn(model)


def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
Expand Down Expand Up @@ -199,7 +196,7 @@ def list_models(module: Optional[ModuleType] = None) -> List[str]:
return sorted(models)


def get_model_builder(name: str) -> Callable[..., M]:
def get_model_builder(name: str) -> Callable[..., nn.Module]:
Copy link
Member

@NicolasHug NicolasHug Sep 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity why do we need to change this one, and not all the other usages of M e.g. in register_model() above?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TypeVar is used to create a "link" between parameters. For example I can do something like

def do_nothing(foo: M) -> M:
    return foo

This tells mypy "whatever we have as input type, we return as output type as well". This is useful in case the function can handle multiple types, e.g.

T = TypeVar("T", int, float)

def add(a: T, b: T) -> T:
    return a + b

In the case above, there is no input with the type M and so mypy complains that it can't figure out the M in the output.

Correct me if I'm wrong @datumbox, but I guess the annotation was intended to mean "we return any kind of model" here. Since M is defined as

M = TypeVar("M", bound=nn.Module)

we can simply use nn.Module here.

"""
Gets the model name and returns the model builder method.

Expand All @@ -219,7 +216,7 @@ def get_model_builder(name: str) -> Callable[..., M]:
return fn


def get_model(name: str, **config: Any) -> M:
def get_model(name: str, **config: Any) -> nn.Module:
"""
Gets the model name and configuration and returns an instantiated model.

Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/datasets/utils/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import functools
import pathlib
import pickle
from typing import Any, BinaryIO, Callable, cast, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union
from typing import Any, BinaryIO, Callable, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -72,8 +72,8 @@ def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
return obj


def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> D:
return cast(D, _getattr_closure(path, attrs=name.split(".")))
def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> Any:
return _getattr_closure(path, attrs=name.split("."))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this change needed?
Do you remember why we annotated the return with D in the first place?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this change needed?

Same as above. We no longer can use a TypeVar as return annotation if we don't have any input with the same type.

Do you remember why we annotated the return with D in the first place?

IIRC, this is a remnant when this function was a local one inside

def path_accessor(getter: Union[str, Callable[[pathlib.Path], D]]) -> Callable[[Tuple[str, Any]], D]:

In there the annotation was correct and ok, since we have a D in the input parameters. At some point we factored them out since torchdata doesn't like local functions, and I probably forgot to fix this since mypy didn't complain.



def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D:
Expand Down