From 407fa9e1c9e5b5387d9ad2ff1726e370446bd6a2 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Tue, 23 May 2023 20:54:58 +0000 Subject: [PATCH] explicit broadcasts for assignments --- src/diffusers/models/resnet.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index debe120e8ead..92bc89c80099 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -433,7 +433,8 @@ def forward(self, x): x = F.pad(x, (self.pad,) * 4, self.pad_mode) weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) indices = torch.arange(x.shape[1], device=x.device) - weight[indices, indices] = self.kernel.to(weight) + kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1) + weight[indices, indices] = kernel return F.conv2d(x, weight, stride=2) @@ -449,7 +450,8 @@ def forward(self, x): x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) indices = torch.arange(x.shape[1], device=x.device) - weight[indices, indices] = self.kernel.to(weight) + kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1) + weight[indices, indices] = kernel return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)