@@ -300,7 +300,8 @@ def forward(self, hidden_states):
300300 hidden_states = F .pad (hidden_states , (self .pad ,) * 2 , self .pad_mode )
301301 weight = hidden_states .new_zeros ([hidden_states .shape [1 ], hidden_states .shape [1 ], self .kernel .shape [0 ]])
302302 indices = torch .arange (hidden_states .shape [1 ], device = hidden_states .device )
303- weight [indices , indices ] = self .kernel .to (weight )
303+ kernel = self .kernel .to (weight )[None , :].expand (hidden_states .shape [1 ], - 1 )
304+ weight [indices , indices ] = kernel
304305 return F .conv1d (hidden_states , weight , stride = 2 )
305306
306307
@@ -316,7 +317,8 @@ def forward(self, hidden_states, temb=None):
316317 hidden_states = F .pad (hidden_states , ((self .pad + 1 ) // 2 ,) * 2 , self .pad_mode )
317318 weight = hidden_states .new_zeros ([hidden_states .shape [1 ], hidden_states .shape [1 ], self .kernel .shape [0 ]])
318319 indices = torch .arange (hidden_states .shape [1 ], device = hidden_states .device )
319- weight [indices , indices ] = self .kernel .to (weight )
320+ kernel = self .kernel .to (weight )[None , :].expand (hidden_states .shape [1 ], - 1 )
321+ weight [indices , indices ] = kernel
320322 return F .conv_transpose1d (hidden_states , weight , stride = 2 , padding = self .pad * 2 + 1 )
321323
322324
0 commit comments