Skip to content

Commit 144f098

Browse files
authored
Fixes device mismatch issue while building docs (#5428)
1 parent 8bf46d4 commit 144f098

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torchvision/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,9 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
449449
"""
450450

451451
N, _, H, W = normalized_flow.shape
452-
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8)
453-
colorwheel = _make_colorwheel() # shape [55x3]
452+
device = normalized_flow.device
453+
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
454+
colorwheel = _make_colorwheel().to(device) # shape [55x3]
454455
num_cols = colorwheel.shape[0]
455456
norm = torch.sum(normalized_flow ** 2, dim=1).sqrt()
456457
a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi

0 commit comments

Comments
 (0)