@@ -72,13 +72,13 @@ class UNet(nn.Module):
72
72
73
73
Args:
74
74
in_channels (int, optional): number of channels in input image
75
- out_channels (int, optional): number of channels in output segmentation
75
+ num_classes (int, optional): number of classes in output segmentation
76
76
start_channels (int, optional): power of 2 channels to start with
77
77
depth (int, optional): number of contractions/expansions
78
78
p (float, optional): dropout probability
79
79
"""
80
80
81
- def __init__ (self , in_channels = 1 , out_channels = 2 , start_channels = 6 ,
81
+ def __init__ (self , in_channels = 1 , num_classes = 2 , start_channels = 6 ,
82
82
depth = 4 , p = 0.5 ):
83
83
super (UNet , self ).__init__ ()
84
84
@@ -96,7 +96,7 @@ def __init__(self, in_channels=1, out_channels=2, start_channels=6,
96
96
Expand (2 ** d , 2 ** (d - 1 )) for d in range (
97
97
start_channels + depth , start_channels , - 1 )
98
98
])
99
- self .conv2 = nn .Conv2d (2 ** start_channels , out_channels , 1 )
99
+ self .conv2 = nn .Conv2d (2 ** start_channels , num_classes , 1 )
100
100
self .softmax = nn .LogSoftmax (dim = 1 )
101
101
102
102
# Initialize weights
0 commit comments