Skip to content

Commit 9cfb0b7

Browse files
NicolasHugvfdev-5
andauthored
Fixes device mismatch issue while building docs (#5428) (#5429)
Co-authored-by: vfdev <[email protected]>
1 parent 2662797 commit 9cfb0b7

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torchvision/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,9 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
449449
"""
450450

451451
N, _, H, W = normalized_flow.shape
452-
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8)
453-
colorwheel = _make_colorwheel() # shape [55x3]
452+
device = normalized_flow.device
453+
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
454+
colorwheel = _make_colorwheel().to(device) # shape [55x3]
454455
num_cols = colorwheel.shape[0]
455456
norm = torch.sum(normalized_flow ** 2, dim=1).sqrt()
456457
a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi

0 commit comments

Comments
 (0)