Skip to content

Commit db56f8a

Browse files
explicit broadcasts for assignments (#3535)
1 parent c13dbd5 commit db56f8a

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/diffusers/models/resnet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,8 @@ def forward(self, x):
433433
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
434434
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
435435
indices = torch.arange(x.shape[1], device=x.device)
436-
weight[indices, indices] = self.kernel.to(weight)
436+
kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
437+
weight[indices, indices] = kernel
437438
return F.conv2d(x, weight, stride=2)
438439

439440

@@ -449,7 +450,8 @@ def forward(self, x):
449450
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
450451
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
451452
indices = torch.arange(x.shape[1], device=x.device)
452-
weight[indices, indices] = self.kernel.to(weight)
453+
kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
454+
weight[indices, indices] = kernel
453455
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
454456

455457

0 commit comments

Comments
 (0)