
Commit 0c464f6

Merge pull request pytorch#3 from sprt/tpu_compat
Mask R-CNN on TPU
2 parents f2031cc + 3906ba9, commit 0c464f6

File tree

9 files changed (+51, -29 lines)


references/detection/train_tpu.py

Lines changed: 1 addition & 0 deletions
@@ -133,6 +133,7 @@ def do_prediction(image_path, use_cpu=False):
     model.eval()

     image = Image.open(image_path)
+    image = torchvision.transforms.functional.resize(image, (800, 600))
     image_tensor = torchvision.transforms.functional.to_tensor(image)

     if not use_cpu:
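
Note: the added resize pins every input image to a fixed 800x600 size before it is turned into a tensor. A plausible motivation (an assumption, not stated in the commit) is that XLA compiles a separate graph per tensor shape, so fixed-size inputs avoid recompiling for each new image. A minimal sketch of the same preprocessing, with "sample.jpg" standing in for a real image path:

import torchvision
from PIL import Image

# Hypothetical image path; any RGB image works here.
image = Image.open("sample.jpg")
# Fixed (height, width) so every input tensor has the same shape.
image = torchvision.transforms.functional.resize(image, (800, 600))
image_tensor = torchvision.transforms.functional.to_tensor(image)
print(image_tensor.shape)  # torch.Size([3, 800, 600])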

torchvision/csrc/cpu/nms_cpu.cpp

Lines changed: 15 additions & 4 deletions
@@ -4,11 +4,15 @@ template <typename scalar_t>
 at::Tensor nms_cpu_kernel(
     const at::Tensor& dets,
     const at::Tensor& scores,
-    const float iou_threshold) {
+    const double iou_threshold,
+    const long post_nms_top_n) {
   AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor");
   AT_ASSERTM(!scores.type().is_cuda(), "scores must be a CPU tensor");
   AT_ASSERTM(
       dets.type() == scores.type(), "dets should have the same type as scores");
+  AT_ASSERTM(
+      dets.size(0) >= post_nms_top_n,
+      "should have at least post_nms_top_n boxes");

   if (dets.numel() == 0)
     return at::empty({0}, dets.options().dtype(at::kLong));
@@ -41,7 +45,11 @@ at::Tensor nms_cpu_kernel(
     auto i = order[_i];
     if (suppressed[i] == 1)
       continue;
+
     keep[num_to_keep++] = i;
+    if (num_to_keep == post_nms_top_n)
+      goto end;
+
     auto ix1 = x1[i];
     auto iy1 = y1[i];
     auto ix2 = x2[i];
@@ -65,17 +73,20 @@ at::Tensor nms_cpu_kernel(
         suppressed[j] = 1;
     }
   }
-  return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
+
+end:
+  return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/post_nms_top_n);
 }

 at::Tensor nms_cpu(
     const at::Tensor& dets,
     const at::Tensor& scores,
-    const float iou_threshold) {
+    const double iou_threshold,
+    const long post_nms_top_n) {
   auto result = at::empty({0}, dets.options());

   AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
-    result = nms_cpu_kernel<scalar_t>(dets, scores, iou_threshold);
+    result = nms_cpu_kernel<scalar_t>(dets, scores, iou_threshold, post_nms_top_n);
   });
   return result;
 }
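
Taken together, these hunks change nms_cpu_kernel from returning a variable number of indices to returning exactly post_nms_top_n of them: the new assert requires at least post_nms_top_n input boxes, the loop exits via `goto end` as soon as that many are kept, and the final narrow always slices a fixed-length result. A likely motivation (an assumption, not stated in the diff) is that a constant output shape avoids shape-dependent recompilation on TPU. A hedged Python sketch of the same algorithm, for illustration only (it is not the C++ kernel; box_iou from torchvision.ops stands in for the hand-written overlap code):

import torch
from torchvision.ops import box_iou  # boxes are (x1, y1, x2, y2)

def nms_top_n_sketch(boxes, scores, iou_threshold, post_nms_top_n):
    assert boxes.size(0) >= post_nms_top_n, "should have at least post_nms_top_n boxes"
    order = scores.argsort(descending=True)
    suppressed = torch.zeros(boxes.size(0), dtype=torch.bool)
    # Fixed-length output: slots that are never filled keep their initial value.
    keep = torch.zeros(post_nms_top_n, dtype=torch.int64)
    num_to_keep = 0
    for i in order.tolist():
        if suppressed[i]:
            continue
        keep[num_to_keep] = i
        num_to_keep += 1
        if num_to_keep == post_nms_top_n:
            break  # mirrors the `goto end;` early exit
        # Suppress remaining boxes that overlap the kept box too strongly.
        ious = box_iou(boxes[i:i + 1], boxes).squeeze(0)
        suppressed |= ious > iou_threshold
    return keep

When suppression leaves fewer than post_nms_top_n survivors, the trailing slots stay at their initial value, so callers always see an index tensor of constant shape.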

torchvision/csrc/cpu/vision_cpu.h

Lines changed: 2 additions & 1 deletion
@@ -83,4 +83,5 @@ at::Tensor PSROIAlign_backward_cpu(
 at::Tensor nms_cpu(
     const at::Tensor& dets,
     const at::Tensor& scores,
-    const float iou_threshold);
+    const double iou_threshold,
+    const long post_nms_top_n);

torchvision/csrc/nms.h

Lines changed: 3 additions & 2 deletions
@@ -8,7 +8,8 @@
 at::Tensor nms(
     const at::Tensor& dets,
     const at::Tensor& scores,
-    const double iou_threshold) {
+    const double iou_threshold,
+    const long post_nms_top_n) {
   if (dets.device().is_cuda()) {
 #ifdef WITH_CUDA
     if (dets.numel() == 0) {
@@ -21,6 +22,6 @@ at::Tensor nms(
 #endif
   }

-  at::Tensor result = nms_cpu(dets, scores, iou_threshold);
+  at::Tensor result = nms_cpu(dets, scores, iou_threshold, post_nms_top_n);
   return result;
 }

torchvision/models/detection/_utils.py

Lines changed: 4 additions & 4 deletions
@@ -202,10 +202,10 @@ def decode_single(self, rel_codes, boxes):
         pred_w = torch.exp(dw) * widths[:, None]
         pred_h = torch.exp(dh) * heights[:, None]

-        pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w
-        pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h
-        pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w
-        pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h
+        pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_ctr_x.device) * pred_w
+        pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_ctr_y.device) * pred_h
+        pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_ctr_x.device) * pred_w
+        pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_ctr_y.device) * pred_h
         pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
         return pred_boxes
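
The only change here is the explicit device= argument: torch.tensor(0.5, dtype=...) allocates the scalar on the CPU by default, and mixing that CPU scalar with tensors living on the XLA device presumably either fails or forces an extra transfer, so the constant is now created directly on the operands' device. A minimal sketch of the pattern, with pred_ctr_x and pred_w standing in for the decoder's intermediate tensors (on CPU here so the snippet runs anywhere):

import torch

# Stand-ins for the decoder's intermediates; in the real code these live on the XLA device.
pred_ctr_x = torch.rand(8, 1)
pred_w = torch.rand(8, 1)

# Create the 0.5 constant with the same dtype and on the same device as its operands.
half = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_ctr_x.device)
pred_boxes1 = pred_ctr_x - half * pred_w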

torchvision/models/detection/roi_heads.py

Lines changed: 1 addition & 3 deletions
@@ -603,9 +603,7 @@ def postprocess_detections(self, class_logits, box_regression, proposals, image_
             boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

             # non-maximum suppression, independently done per class
-            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
-            # keep only topk scoring predictions
-            keep = keep[:self.detections_per_img]
+            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh, self.detections_per_img)

             # keep.shape = [0]
             boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

torchvision/models/detection/rpn.py

Lines changed: 7 additions & 10 deletions
@@ -60,7 +60,7 @@ def __init__(
         self.sizes = sizes
         self.aspect_ratios = aspect_ratios
         self.cell_anchors = None
-        self._cache = {}
+        # self._cache = {}

     @staticmethod
     def generate_anchors(scales, aspect_ratios, dtype=torch.float32, device="cpu"):
@@ -76,8 +76,8 @@ def generate_anchors(scales, aspect_ratios, dtype=torch.float32, device="cpu"):
         return base_anchors.round()

     def set_cell_anchors(self, dtype, device):
-        if self.cell_anchors is not None:
-            return self.cell_anchors
+        # if self.cell_anchors is not None:
+        #     return self.cell_anchors
         cell_anchors = [
             self.generate_anchors(
                 sizes,
@@ -123,10 +123,10 @@ def grid_anchors(self, grid_sizes, strides):

     def cached_grid_anchors(self, grid_sizes, strides):
         key = tuple(grid_sizes) + tuple(strides)
-        if key in self._cache:
-            return self._cache[key]
+        # if key in self._cache:
+        #     return self._cache[key]
         anchors = self.grid_anchors(grid_sizes, strides)
-        self._cache[key] = anchors
+        # self._cache[key] = anchors
         return anchors

     def forward(self, image_list, feature_maps):
@@ -355,12 +355,9 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_
         final_scores = []
         for boxes, scores, lvl, img_shape in zip(proposals, objectness, levels, image_shapes):
             boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
-            keep = box_ops.remove_small_boxes(boxes, self.min_size)
-            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
             # non-maximum suppression, independently done per level
-            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
+            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh, self.post_nms_top_n)
             # keep only topk scoring predictions
-            keep = keep[:self.post_nms_top_n]
             boxes, scores = boxes[keep], scores[keep]
             final_boxes.append(boxes)
             final_scores.append(scores)

torchvision/ops/boxes.py

Lines changed: 11 additions & 4 deletions
@@ -1,7 +1,8 @@
 import torch
+import torch_xla


-def nms(boxes, scores, iou_threshold):
+def nms(boxes, scores, iou_threshold, post_nms_top_n):
     """
     Performs non-maximum suppression (NMS) on the boxes according
     to their intersection-over-union (IoU).
@@ -28,10 +29,16 @@ def nms(boxes, scores, iou_threshold):
         of the elements that have been kept
         by NMS, sorted in decreasing order of scores
     """
-    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
+    device = boxes.device
+    torch_xla._XLAC._xla_sync_multi([boxes, scores], devices=[])
+    boxes_cpu = boxes.cpu().clone()
+    scores_cpu = scores.cpu().clone()
+    keep = torch.ops.torchvision.nms(boxes_cpu, scores_cpu, iou_threshold)
+    keep = keep.to(device=device)
+    return keep


-def batched_nms(boxes, scores, idxs, iou_threshold):
+def batched_nms(boxes, scores, idxs, iou_threshold, post_nms_top_n):
     """
     Performs non-maximum suppression in a batched fashion.

@@ -67,7 +74,7 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
     max_coordinate = boxes.max()
     offsets = idxs.to(boxes) * (max_coordinate + 1)
     boxes_for_nms = boxes + offsets[:, None]
-    keep = nms(boxes_for_nms, scores, iou_threshold)
+    keep = nms(boxes_for_nms, scores, iou_threshold, post_nms_top_n)
     print("ops/boxes.py; keep.shape: {}".format(keep.shape))
     return keep
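
nms() now implements a host fallback for lazy XLA tensors: torch_xla._XLAC._xla_sync_multi forces any pending computation on boxes and scores to materialize, the inputs are copied to CPU, the existing torchvision kernel runs there, and the kept indices are moved back to the original device. Two things worth noting from the hunk itself: the print(...) call is leftover debug output, and post_nms_top_n is accepted by both Python wrappers but not forwarded to torch.ops.torchvision.nms here, even though the C++ entry point above now takes it; whether a default is registered elsewhere is not visible in this diff. Below is a hedged sketch of the same sync-copy-compute-copy-back pattern applied to an arbitrary operator; run_on_cpu and the stand-in op are illustrative names, not part of the commit:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

def run_on_cpu(op, *tensors):
    """Materialize lazy XLA tensors, run `op` on host copies, move the result back."""
    device = tensors[0].device
    # Flush pending lazy computations so .cpu() sees up-to-date values.
    torch_xla._XLAC._xla_sync_multi(list(tensors), devices=[])
    cpu_args = [t.cpu().clone() for t in tensors]
    result = op(*cpu_args)
    return result.to(device=device)

device = xm.xla_device()
boxes = torch.rand(100, 4, device=device) * 100
scores = torch.rand(100, device=device)
# Stand-in op: indices of the ten highest-scoring boxes, computed on the CPU.
keep = run_on_cpu(lambda b, s: s.argsort(descending=True)[:10], boxes, scores)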

torchvision/ops/poolers.py

Lines changed: 7 additions & 1 deletion
@@ -184,10 +184,16 @@ def forward(self, x, boxes, image_shapes):
             idx_in_level = torch.nonzero(levels == level).squeeze(1)
             rois_per_level = rois[idx_in_level]

+            xla_device = per_level_feature.device
+            torch_xla._XLAC._xla_sync_multi([per_level_feature, rois_per_level], devices=[])
+
+            per_level_feature_cpu = per_level_feature.cpu().clone()
+            rois_per_level_cpu = rois_per_level.cpu().clone()
             result_idx_in_level = roi_align(
                 per_level_feature, rois_per_level,
                 output_size=self.output_size,
-                spatial_scale=scale, sampling_ratio=self.sampling_ratio)
+                spatial_scale=scale, sampling_ratio=self.sampling_ratio
+            ).to(xla_device)

             if torchvision._is_tracing():
                 results.append(result_idx_in_level.to(dtype))
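
The pooler applies the same idea to roi_align: sync the XLA tensors, then move the result of the call back to the XLA device. As written in this hunk, though, per_level_feature_cpu and rois_per_level_cpu are created but never used; roi_align still receives the original per_level_feature and rois_per_level, and only the output is relocated with .to(xla_device). Whether the CPU copies were meant to be passed instead cannot be determined from this diff alone. For reference, a small self-contained sketch of roi_align running on host tensors, with the feature map and ROI values made up for illustration:

import torch
from torchvision.ops import roi_align

# Made-up CPU stand-ins for one FPN level and one region of interest.
per_level_feature_cpu = torch.rand(1, 256, 50, 50)                 # (N, C, H, W)
rois_per_level_cpu = torch.tensor([[0.0, 0.0, 0.0, 20.0, 20.0]])   # (batch_idx, x1, y1, x2, y2)

result = roi_align(
    per_level_feature_cpu, rois_per_level_cpu,
    output_size=(7, 7),
    spatial_scale=0.25, sampling_ratio=2)
print(result.shape)  # torch.Size([1, 256, 7, 7])
# In the pooler, the result would then be moved back with .to(xla_device).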
