Skip to content

Commit e02d5dc

Browse files
committed
out_channels -> num_classes
1 parent 4d570a1 commit e02d5dc

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torchvision/models/unet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,13 @@ class UNet(nn.Module):
7272
7373
Args:
7474
in_channels (int, optional): number of channels in input image
75-
out_channels (int, optional): number of channels in output segmentation
75+
num_classes (int, optional): number of classes in output segmentation
7676
start_channels (int, optional): power of 2 channels to start with
7777
depth (int, optional): number of contractions/expansions
7878
p (float, optional): dropout probability
7979
"""
8080

81-
def __init__(self, in_channels=1, out_channels=2, start_channels=6,
81+
def __init__(self, in_channels=1, num_classes=2, start_channels=6,
8282
depth=4, p=0.5):
8383
super(UNet, self).__init__()
8484

@@ -96,7 +96,7 @@ def __init__(self, in_channels=1, out_channels=2, start_channels=6,
9696
Expand(2 ** d, 2 ** (d - 1)) for d in range(
9797
start_channels + depth, start_channels, -1)
9898
])
99-
self.conv2 = nn.Conv2d(2 ** start_channels, out_channels, 1)
99+
self.conv2 = nn.Conv2d(2 ** start_channels, num_classes, 1)
100100
self.softmax = nn.LogSoftmax(dim=1)
101101

102102
# Initialize weights

0 commit comments

Comments
 (0)