Skip to content

Commit 6e0cc4e

Browse files
committed
Fixing return casting with autocast.
1 parent b71f0d3 commit 6e0cc4e

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

torchvision/csrc/ROIPool.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_autocast(
2828
const int64_t pooled_height,
2929
const int64_t pooled_width) {
3030
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
31-
return roi_pool(
32-
at::autocast::cached_cast(at::kFloat, input),
33-
at::autocast::cached_cast(at::kFloat, rois),
34-
spatial_scale,
35-
pooled_height,
36-
pooled_width)
37-
.to(input.scalar_type());
31+
auto result = roi_pool(
32+
at::autocast::cached_cast(at::kFloat, input),
33+
at::autocast::cached_cast(at::kFloat, rois),
34+
spatial_scale,
35+
pooled_height,
36+
pooled_width);
37+
38+
return std::make_tuple(
39+
result[0].to(input.scalar_type()), result[1].to(input.scalar_type()));
3840
}
3941
#endif
4042

0 commit comments

Comments
 (0)