We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b71f0d3 commit 6e0cc4eCopy full SHA for 6e0cc4e
torchvision/csrc/ROIPool.h
@@ -28,13 +28,15 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_autocast(
28
const int64_t pooled_height,
29
const int64_t pooled_width) {
30
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
31
- return roi_pool(
32
- at::autocast::cached_cast(at::kFloat, input),
33
- at::autocast::cached_cast(at::kFloat, rois),
34
- spatial_scale,
35
- pooled_height,
36
- pooled_width)
37
- .to(input.scalar_type());
+ auto result = roi_pool(
+ at::autocast::cached_cast(at::kFloat, input),
+ at::autocast::cached_cast(at::kFloat, rois),
+ spatial_scale,
+ pooled_height,
+ pooled_width);
+
38
+ return std::make_tuple(
39
+ result[0].to(input.scalar_type()), result[1].to(input.scalar_type()));
40
}
41
#endif
42
0 commit comments