Skip to content

Commit 0fd6444

Browse files
authored
[Feature] Add large image demo with sahi (open-mmlab#284)
* add large image demo with sahi * fix some typos * restructure based on reviews * update default patch size * add docstring and update docs * updates based on reviews * print information * add debug, update docs, add large image sample * update docs * update docs * update docs * direct user to install sahi
1 parent 5cee9c9 commit 0fd6444

File tree

5 files changed

+338
-0
lines changed

5 files changed

+338
-0
lines changed

demo/large_image.jpg

168 KB
Loading

demo/large_image_demo.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
"""Perform MMYOLO inference on large images (as satellite imagery) as:
3+
4+
```shell
5+
wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth # noqa: E501, E261.
6+
7+
python demo/large_image_demo.py \
8+
demo/large_image.jpg \
9+
configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
10+
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
11+
```
12+
"""
13+
14+
import os
15+
from argparse import ArgumentParser
16+
17+
import mmcv
18+
from mmdet.apis import inference_detector, init_detector
19+
from mmengine.logging import print_log
20+
from mmengine.utils import ProgressBar
21+
22+
try:
23+
from sahi.slicing import slice_image
24+
except ImportError:
25+
raise ImportError('Please run "pip install -U sahi" '
26+
'to install sahi first for large image inference.')
27+
28+
from mmyolo.registry import VISUALIZERS
29+
from mmyolo.utils import register_all_modules, switch_to_deploy
30+
from mmyolo.utils.large_image import merge_results_by_nms
31+
from mmyolo.utils.misc import get_file_list
32+
33+
34+
def parse_args():
35+
parser = ArgumentParser(
36+
description='Perform MMYOLO inference on large images.')
37+
parser.add_argument(
38+
'img', help='Image path, include image file, dir and URL.')
39+
parser.add_argument('config', help='Config file')
40+
parser.add_argument('checkpoint', help='Checkpoint file')
41+
parser.add_argument(
42+
'--out-dir', default='./output', help='Path to output file')
43+
parser.add_argument(
44+
'--device', default='cuda:0', help='Device used for inference')
45+
parser.add_argument(
46+
'--show', action='store_true', help='Show the detection results')
47+
parser.add_argument(
48+
'--deploy',
49+
action='store_true',
50+
help='Switch model to deployment mode')
51+
parser.add_argument(
52+
'--score-thr', type=float, default=0.3, help='Bbox score threshold')
53+
parser.add_argument(
54+
'--patch-size', type=int, default=640, help='The size of patches')
55+
parser.add_argument(
56+
'--patch-overlap-ratio',
57+
type=int,
58+
default=0.25,
59+
help='Ratio of overlap between two patches')
60+
parser.add_argument(
61+
'--merge-iou-thr',
62+
type=float,
63+
default=0.25,
64+
help='IoU threshould for merging results')
65+
parser.add_argument(
66+
'--merge-nms-type',
67+
type=str,
68+
default='nms',
69+
help='NMS type for merging results')
70+
parser.add_argument(
71+
'--batch-size',
72+
type=int,
73+
default=1,
74+
help='Batch size, must greater than or equal to 1')
75+
parser.add_argument(
76+
'--debug',
77+
action='store_true',
78+
help='Export debug images at each stage for 1 input')
79+
args = parser.parse_args()
80+
return args
81+
82+
83+
def main():
84+
args = parse_args()
85+
86+
# register all modules in mmdet into the registries
87+
register_all_modules()
88+
89+
# build the model from a config file and a checkpoint file
90+
model = init_detector(args.config, args.checkpoint, device=args.device)
91+
92+
if args.deploy:
93+
switch_to_deploy(model)
94+
95+
if not os.path.exists(args.out_dir) and not args.show:
96+
os.mkdir(args.out_dir)
97+
98+
# init visualizer
99+
visualizer = VISUALIZERS.build(model.cfg.visualizer)
100+
visualizer.dataset_meta = model.dataset_meta
101+
102+
# get file list
103+
files, source_type = get_file_list(args.img)
104+
105+
# if debug, only process the first file
106+
if args.debug:
107+
files = files[:1]
108+
109+
# start detector inference
110+
print(f'Performing inference on {len(files)} images... \
111+
This may take a while.')
112+
progress_bar = ProgressBar(len(files))
113+
for file in files:
114+
# read image
115+
img = mmcv.imread(file)
116+
117+
# arrange slices
118+
height, width = img.shape[:2]
119+
sliced_image_object = slice_image(
120+
img,
121+
slice_height=args.patch_size,
122+
slice_width=args.patch_size,
123+
auto_slice_resolution=False,
124+
overlap_height_ratio=args.patch_overlap_ratio,
125+
overlap_width_ratio=args.patch_overlap_ratio,
126+
)
127+
128+
# perform sliced inference
129+
slice_results = []
130+
start = 0
131+
while True:
132+
# prepare batch slices
133+
end = min(start + args.batch_size, len(sliced_image_object))
134+
images = []
135+
for sliced_image in sliced_image_object.images[start:end]:
136+
images.append(sliced_image)
137+
138+
# forward the model
139+
slice_results.extend(inference_detector(model, images))
140+
141+
if end >= len(sliced_image_object):
142+
break
143+
start += args.batch_size
144+
145+
if source_type['is_dir']:
146+
filename = os.path.relpath(file, args.img).replace('/', '_')
147+
else:
148+
filename = os.path.basename(file)
149+
150+
# export debug images
151+
if args.debug:
152+
# export sliced images
153+
for i, image in enumerate(sliced_image_object.images):
154+
image = mmcv.imconvert(image, 'bgr', 'rgb')
155+
out_file = os.path.join(args.out_dir, 'sliced_images',
156+
filename + f'_slice_{i}.jpg')
157+
158+
mmcv.imwrite(image, out_file)
159+
160+
# export sliced image results
161+
for i, slice_result in enumerate(slice_results):
162+
out_file = os.path.join(args.out_dir, 'sliced_image_results',
163+
filename + f'_slice_{i}_result.jpg')
164+
image = mmcv.imconvert(sliced_image_object.images[i], 'bgr',
165+
'rgb')
166+
167+
visualizer.add_datasample(
168+
os.path.basename(out_file),
169+
image,
170+
data_sample=slice_result,
171+
draw_gt=False,
172+
show=args.show,
173+
wait_time=0,
174+
out_file=out_file,
175+
pred_score_thr=args.score_thr,
176+
)
177+
178+
image_result = merge_results_by_nms(
179+
slice_results,
180+
sliced_image_object.starting_pixels,
181+
src_image_shape=(height, width),
182+
nms_cfg={
183+
'type': args.merge_nms_type,
184+
'iou_thr': args.merge_iou_thr
185+
})
186+
187+
img = mmcv.imconvert(img, 'bgr', 'rgb')
188+
out_file = None if args.show else os.path.join(args.out_dir, filename)
189+
190+
visualizer.add_datasample(
191+
os.path.basename(out_file),
192+
img,
193+
data_sample=image_result,
194+
draw_gt=False,
195+
show=args.show,
196+
wait_time=0,
197+
out_file=out_file,
198+
pred_score_thr=args.score_thr,
199+
)
200+
progress_bar.update()
201+
202+
if not args.show:
203+
print_log(
204+
f'\nResults have been saved at {os.path.abspath(args.out_dir)}')
205+
206+
207+
if __name__ == '__main__':
208+
main()

docs/en/user_guides/useful_tools.md

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,59 @@ python tools/analysis_tools/optimize_anchors.py ${CONFIG} \
350350
--output-dir ${OUTPUT_DIR}
351351
```
352352
353+
## Perform inference on large images
354+
355+
First install [`sahi`](https://github.com/obss/sahi) with:
356+
357+
```shell
358+
pip install -U sahi>=0.11.4
359+
```
360+
361+
Perform MMYOLO inference on large images (as satellite imagery) as:
362+
363+
```shell
364+
wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth
365+
366+
python demo/large_image_demo.py \
367+
demo/large_image.jpg \
368+
configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
369+
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
370+
```
371+
372+
Arrange slicing parameters as:
373+
374+
```shell
375+
python demo/large_image_demo.py \
376+
demo/large_image.jpg \
377+
configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
378+
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
379+
--patch-size 512
380+
--patch-overlap-ratio 0.25
381+
```
382+
383+
Export debug visuals while performing inference on large images as:
384+
385+
```shell
386+
python demo/large_image_demo.py \
387+
demo/large_image.jpg \
388+
configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \
389+
checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \
390+
--debug
391+
```
392+
393+
[`sahi`](https://github.com/obss/sahi) citation:
394+
395+
```
396+
@article{akyon2022sahi,
397+
title={Slicing Aided Hyper Inference and Fine-tuning for Small Object Detection},
398+
author={Akyon, Fatih Cagatay and Altinuc, Sinan Onur and Temizel, Alptekin},
399+
journal={2022 IEEE International Conference on Image Processing (ICIP)},
400+
doi={10.1109/ICIP46576.2022.9897990},
401+
pages={966-970},
402+
year={2022}
403+
}
404+
```
405+
353406
## Extracts a subset of COCO
354407
355408
The training dataset of the COCO2017 dataset includes 118K images, and the validation set includes 5K images, which is a relatively large dataset. Loading JSON in debugging or quick verification scenarios will consume more resources and bring slower startup speed.

mmyolo/utils/large_image.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import Sequence, Tuple
3+
4+
from mmcv.ops import batched_nms
5+
from mmdet.structures import DetDataSample, SampleList
6+
from mmengine.structures import InstanceData
7+
8+
9+
def shift_predictions(det_data_samples: SampleList,
10+
offsets: Sequence[Tuple[int, int]],
11+
src_image_shape: Tuple[int, int]) -> SampleList:
12+
"""Shift predictions to the original image.
13+
14+
Args:
15+
det_data_samples (List[:obj:`DetDataSample`]): A list of patch results.
16+
offsets (Sequence[Tuple[int, int]]): Positions of the left top points
17+
of patches.
18+
src_image_shape (Tuple[int, int]): A (height, width) tuple of the large
19+
image's width and height.
20+
Returns:
21+
(List[:obj:`DetDataSample`]): shifted results.
22+
"""
23+
try:
24+
from sahi.slicing import shift_bboxes, shift_masks
25+
except ImportError:
26+
raise ImportError('Please run "pip install -U sahi" '
27+
'to install sahi first for large image inference.')
28+
29+
assert len(det_data_samples) == len(
30+
offsets), 'The `results` should has the ' 'same length with `offsets`.'
31+
shifted_predictions = []
32+
for det_data_sample, offset in zip(det_data_samples, offsets):
33+
pred_inst = det_data_sample.pred_instances.clone()
34+
35+
# shift bboxes and masks
36+
pred_inst.bboxes = shift_bboxes(pred_inst.bboxes, offset)
37+
if 'masks' in det_data_sample:
38+
pred_inst.masks = shift_masks(pred_inst.masks, offset,
39+
src_image_shape)
40+
41+
shifted_predictions.append(pred_inst.clone())
42+
43+
shifted_predictions = InstanceData.cat(shifted_predictions)
44+
45+
return shifted_predictions
46+
47+
48+
def merge_results_by_nms(results: SampleList, offsets: Sequence[Tuple[int,
49+
int]],
50+
src_image_shape: Tuple[int, int],
51+
nms_cfg: dict) -> DetDataSample:
52+
"""Merge patch results by nms.
53+
54+
Args:
55+
results (List[:obj:`DetDataSample`]): A list of patch results.
56+
offsets (Sequence[Tuple[int, int]]): Positions of the left top points
57+
of patches.
58+
src_image_shape (Tuple[int, int]): A (height, width) tuple of the large
59+
image's width and height.
60+
nms_cfg (dict): it should specify nms type and other parameters
61+
like `iou_threshold`.
62+
Returns:
63+
:obj:`DetDataSample`: merged results.
64+
"""
65+
shifted_instances = shift_predictions(results, offsets, src_image_shape)
66+
67+
_, keeps = batched_nms(
68+
boxes=shifted_instances.bboxes,
69+
scores=shifted_instances.scores,
70+
idxs=shifted_instances.labels,
71+
nms_cfg=nms_cfg)
72+
merged_instances = shifted_instances[keeps]
73+
74+
merged_result = results[0].clone()
75+
merged_result.pred_instances = merged_instances
76+
return merged_result

requirements/sahi.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
sahi>=0.11.4

0 commit comments

Comments
 (0)