@@ -433,7 +433,8 @@ def forward(self, x):
433
433
x = F .pad (x , (self .pad ,) * 4 , self .pad_mode )
434
434
weight = x .new_zeros ([x .shape [1 ], x .shape [1 ], self .kernel .shape [0 ], self .kernel .shape [1 ]])
435
435
indices = torch .arange (x .shape [1 ], device = x .device )
436
- weight [indices , indices ] = self .kernel .to (weight )
436
+ kernel = self .kernel .to (weight )[None , :].expand (x .shape [1 ], - 1 , - 1 )
437
+ weight [indices , indices ] = kernel
437
438
return F .conv2d (x , weight , stride = 2 )
438
439
439
440
@@ -449,7 +450,8 @@ def forward(self, x):
449
450
x = F .pad (x , ((self .pad + 1 ) // 2 ,) * 4 , self .pad_mode )
450
451
weight = x .new_zeros ([x .shape [1 ], x .shape [1 ], self .kernel .shape [0 ], self .kernel .shape [1 ]])
451
452
indices = torch .arange (x .shape [1 ], device = x .device )
452
- weight [indices , indices ] = self .kernel .to (weight )
453
+ kernel = self .kernel .to (weight )[None , :].expand (x .shape [1 ], - 1 , - 1 )
454
+ weight [indices , indices ] = kernel
453
455
return F .conv_transpose2d (x , weight , stride = 2 , padding = self .pad * 2 + 1 )
454
456
455
457
0 commit comments