2424
2525
2626class Upsample1D (nn .Module ):
27- """
28- An upsampling layer with an optional convolution.
27+ """A 1D upsampling layer with an optional convolution.
2928
3029 Parameters:
31- channels: channels in the inputs and outputs.
32- use_conv: a bool determining if a convolution is applied.
33- use_conv_transpose:
34- out_channels:
30+ channels (`int`):
31+ number of channels in the inputs and outputs.
32+ use_conv (`bool`, default `False`):
33+ option to use a convolution.
34+ use_conv_transpose (`bool`, default `False`):
35+ option to use a convolution transpose.
36+ out_channels (`int`, optional):
37+ number of output channels. Defaults to `channels`.
3538 """
3639
3740 def __init__ (self , channels , use_conv = False , use_conv_transpose = False , out_channels = None , name = "conv" ):
@@ -62,14 +65,17 @@ def forward(self, x):
6265
6366
6467class Downsample1D (nn .Module ):
65- """
66- A downsampling layer with an optional convolution.
68+ """A 1D downsampling layer with an optional convolution.
6769
6870 Parameters:
69- channels: channels in the inputs and outputs.
70- use_conv: a bool determining if a convolution is applied.
71- out_channels:
72- padding:
71+ channels (`int`):
72+ number of channels in the inputs and outputs.
73+ use_conv (`bool`, default `False`):
74+ option to use a convolution.
75+ out_channels (`int`, optional):
76+ number of output channels. Defaults to `channels`.
77+ padding (`int`, default `1`):
78+ padding for the convolution.
7379 """
7480
7581 def __init__ (self , channels , use_conv = False , out_channels = None , padding = 1 , name = "conv" ):
@@ -93,14 +99,17 @@ def forward(self, x):
9399
94100
95101class Upsample2D (nn .Module ):
96- """
97- An upsampling layer with an optional convolution.
102+ """A 2D upsampling layer with an optional convolution.
98103
99104 Parameters:
100- channels: channels in the inputs and outputs.
101- use_conv: a bool determining if a convolution is applied.
102- use_conv_transpose:
103- out_channels:
105+ channels (`int`):
106+ number of channels in the inputs and outputs.
107+ use_conv (`bool`, default `False`):
108+ option to use a convolution.
109+ use_conv_transpose (`bool`, default `False`):
110+ option to use a convolution transpose.
111+ out_channels (`int`, optional):
112+ number of output channels. Defaults to `channels`.
104113 """
105114
106115 def __init__ (self , channels , use_conv = False , use_conv_transpose = False , out_channels = None , name = "conv" ):
@@ -162,14 +171,17 @@ def forward(self, hidden_states, output_size=None):
162171
163172
164173class Downsample2D (nn .Module ):
165- """
166- A downsampling layer with an optional convolution.
174+ """A 2D downsampling layer with an optional convolution.
167175
168176 Parameters:
169- channels: channels in the inputs and outputs.
170- use_conv: a bool determining if a convolution is applied.
171- out_channels:
172- padding:
177+ channels (`int`):
178+ number of channels in the inputs and outputs.
179+ use_conv (`bool`, default `False`):
180+ option to use a convolution.
181+ out_channels (`int`, optional):
182+ number of output channels. Defaults to `channels`.
183+ padding (`int`, default `1`):
184+ padding for the convolution.
173185 """
174186
175187 def __init__ (self , channels , use_conv = False , out_channels = None , padding = 1 , name = "conv" ):
@@ -209,6 +221,19 @@ def forward(self, hidden_states):
209221
210222
211223class FirUpsample2D (nn .Module ):
224+ """A 2D FIR upsampling layer with an optional convolution.
225+
226+ Parameters:
227+ channels (`int`):
228+ number of channels in the inputs and outputs.
229+ use_conv (`bool`, default `False`):
230+ option to use a convolution.
231+ out_channels (`int`, optional):
232+ number of output channels. Defaults to `channels`.
233+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
234+ kernel for the FIR filter.
235+ """
236+
212237 def __init__ (self , channels = None , out_channels = None , use_conv = False , fir_kernel = (1 , 3 , 3 , 1 )):
213238 super ().__init__ ()
214239 out_channels = out_channels if out_channels else channels
@@ -309,6 +334,19 @@ def forward(self, hidden_states):
309334
310335
311336class FirDownsample2D (nn .Module ):
337+ """A 2D FIR downsampling layer with an optional convolution.
338+
339+ Parameters:
340+ channels (`int`):
341+ number of channels in the inputs and outputs.
342+ use_conv (`bool`, default `False`):
343+ option to use a convolution.
344+ out_channels (`int`, optional):
345+ number of output channels. Defaults to `channels`.
346+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
347+ kernel for the FIR filter.
348+ """
349+
312350 def __init__ (self , channels = None , out_channels = None , use_conv = False , fir_kernel = (1 , 3 , 3 , 1 )):
313351 super ().__init__ ()
314352 out_channels = out_channels if out_channels else channels
0 commit comments