Skip to content

Commit 95dedaf

Browse files
committed
Adding typing.
1 parent 4502083 commit 95dedaf

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torchvision/models/efficientnet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
}
2020

2121

22-
def drop_connect(x: Tensor, rate: float):
22+
def drop_connect(x: Tensor, rate: float) -> Tensor:
2323
keep = torch.rand(size=(x.size(0), ), dtype=x.dtype, device=x.device) > rate
2424
keep = keep[(None, ) * (x.ndim - 1)].T
2525
return (x / (1.0 - rate)) * keep
@@ -29,7 +29,7 @@ class MBConvConfig:
2929
def __init__(self,
3030
kernel: int, stride: int, dilation: int,
3131
input_channels: int, out_channels: int, expand_ratio: float,
32-
width_mult: float):
32+
width_mult: float) -> None:
3333
self.kernel = kernel
3434
self.stride = stride
3535
self.dilation = dilation
@@ -38,13 +38,13 @@ def __init__(self,
3838
self.expanded_channels = self.adjust_channels(input_channels, expand_ratio * width_mult)
3939

4040
@staticmethod
41-
def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None):
41+
def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
4242
return _make_divisible(channels * width_mult, 8, min_value)
4343

4444

4545
class MBConv(nn.Module):
4646
def __init__(self, cnf: MBConvConfig, norm_layer: Callable[..., nn.Module],
47-
se_layer: Callable[..., nn.Module] = SqueezeExcitation):
47+
se_layer: Callable[..., nn.Module] = SqueezeExcitation) -> None:
4848
super().__init__()
4949
if not (1 <= cnf.stride <= 2):
5050
raise ValueError('illegal stride value')

0 commit comments

Comments
 (0)