|
| 1 | +# Copyright (c) OpenMMLab. All rights reserved. |
| 2 | +"""Perform MMYOLO inference on large images (as satellite imagery) as: |
| 3 | +
|
| 4 | +```shell |
| 5 | +wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth syncbn_fast_8xb16-300e_coco/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth # noqa: E501, E261. |
| 6 | +
|
| 7 | +python demo/large_image_demo.py \ |
| 8 | + demo/large_image.jpg \ |
| 9 | + configs/yolov5/yolov5_m-v61_syncbn_fast_8xb16-300e_coco.py \ |
| 10 | + checkpoint/yolov5_m-v61_syncbn_fast_8xb16-300e_coco_20220917_204944-516a710f.pth \ |
| 11 | +``` |
| 12 | +""" |
| 13 | + |
| 14 | +import os |
| 15 | +from argparse import ArgumentParser |
| 16 | + |
| 17 | +import mmcv |
| 18 | +from mmdet.apis import inference_detector, init_detector |
| 19 | +from mmengine.logging import print_log |
| 20 | +from mmengine.utils import ProgressBar |
| 21 | + |
| 22 | +try: |
| 23 | + from sahi.slicing import slice_image |
| 24 | +except ImportError: |
| 25 | + raise ImportError('Please run "pip install -U sahi" ' |
| 26 | + 'to install sahi first for large image inference.') |
| 27 | + |
| 28 | +from mmyolo.registry import VISUALIZERS |
| 29 | +from mmyolo.utils import register_all_modules, switch_to_deploy |
| 30 | +from mmyolo.utils.large_image import merge_results_by_nms |
| 31 | +from mmyolo.utils.misc import get_file_list |
| 32 | + |
| 33 | + |
| 34 | +def parse_args(): |
| 35 | + parser = ArgumentParser( |
| 36 | + description='Perform MMYOLO inference on large images.') |
| 37 | + parser.add_argument( |
| 38 | + 'img', help='Image path, include image file, dir and URL.') |
| 39 | + parser.add_argument('config', help='Config file') |
| 40 | + parser.add_argument('checkpoint', help='Checkpoint file') |
| 41 | + parser.add_argument( |
| 42 | + '--out-dir', default='./output', help='Path to output file') |
| 43 | + parser.add_argument( |
| 44 | + '--device', default='cuda:0', help='Device used for inference') |
| 45 | + parser.add_argument( |
| 46 | + '--show', action='store_true', help='Show the detection results') |
| 47 | + parser.add_argument( |
| 48 | + '--deploy', |
| 49 | + action='store_true', |
| 50 | + help='Switch model to deployment mode') |
| 51 | + parser.add_argument( |
| 52 | + '--score-thr', type=float, default=0.3, help='Bbox score threshold') |
| 53 | + parser.add_argument( |
| 54 | + '--patch-size', type=int, default=640, help='The size of patches') |
| 55 | + parser.add_argument( |
| 56 | + '--patch-overlap-ratio', |
| 57 | + type=int, |
| 58 | + default=0.25, |
| 59 | + help='Ratio of overlap between two patches') |
| 60 | + parser.add_argument( |
| 61 | + '--merge-iou-thr', |
| 62 | + type=float, |
| 63 | + default=0.25, |
| 64 | + help='IoU threshould for merging results') |
| 65 | + parser.add_argument( |
| 66 | + '--merge-nms-type', |
| 67 | + type=str, |
| 68 | + default='nms', |
| 69 | + help='NMS type for merging results') |
| 70 | + parser.add_argument( |
| 71 | + '--batch-size', |
| 72 | + type=int, |
| 73 | + default=1, |
| 74 | + help='Batch size, must greater than or equal to 1') |
| 75 | + parser.add_argument( |
| 76 | + '--debug', |
| 77 | + action='store_true', |
| 78 | + help='Export debug images at each stage for 1 input') |
| 79 | + args = parser.parse_args() |
| 80 | + return args |
| 81 | + |
| 82 | + |
| 83 | +def main(): |
| 84 | + args = parse_args() |
| 85 | + |
| 86 | + # register all modules in mmdet into the registries |
| 87 | + register_all_modules() |
| 88 | + |
| 89 | + # build the model from a config file and a checkpoint file |
| 90 | + model = init_detector(args.config, args.checkpoint, device=args.device) |
| 91 | + |
| 92 | + if args.deploy: |
| 93 | + switch_to_deploy(model) |
| 94 | + |
| 95 | + if not os.path.exists(args.out_dir) and not args.show: |
| 96 | + os.mkdir(args.out_dir) |
| 97 | + |
| 98 | + # init visualizer |
| 99 | + visualizer = VISUALIZERS.build(model.cfg.visualizer) |
| 100 | + visualizer.dataset_meta = model.dataset_meta |
| 101 | + |
| 102 | + # get file list |
| 103 | + files, source_type = get_file_list(args.img) |
| 104 | + |
| 105 | + # if debug, only process the first file |
| 106 | + if args.debug: |
| 107 | + files = files[:1] |
| 108 | + |
| 109 | + # start detector inference |
| 110 | + print(f'Performing inference on {len(files)} images... \ |
| 111 | +This may take a while.') |
| 112 | + progress_bar = ProgressBar(len(files)) |
| 113 | + for file in files: |
| 114 | + # read image |
| 115 | + img = mmcv.imread(file) |
| 116 | + |
| 117 | + # arrange slices |
| 118 | + height, width = img.shape[:2] |
| 119 | + sliced_image_object = slice_image( |
| 120 | + img, |
| 121 | + slice_height=args.patch_size, |
| 122 | + slice_width=args.patch_size, |
| 123 | + auto_slice_resolution=False, |
| 124 | + overlap_height_ratio=args.patch_overlap_ratio, |
| 125 | + overlap_width_ratio=args.patch_overlap_ratio, |
| 126 | + ) |
| 127 | + |
| 128 | + # perform sliced inference |
| 129 | + slice_results = [] |
| 130 | + start = 0 |
| 131 | + while True: |
| 132 | + # prepare batch slices |
| 133 | + end = min(start + args.batch_size, len(sliced_image_object)) |
| 134 | + images = [] |
| 135 | + for sliced_image in sliced_image_object.images[start:end]: |
| 136 | + images.append(sliced_image) |
| 137 | + |
| 138 | + # forward the model |
| 139 | + slice_results.extend(inference_detector(model, images)) |
| 140 | + |
| 141 | + if end >= len(sliced_image_object): |
| 142 | + break |
| 143 | + start += args.batch_size |
| 144 | + |
| 145 | + if source_type['is_dir']: |
| 146 | + filename = os.path.relpath(file, args.img).replace('/', '_') |
| 147 | + else: |
| 148 | + filename = os.path.basename(file) |
| 149 | + |
| 150 | + # export debug images |
| 151 | + if args.debug: |
| 152 | + # export sliced images |
| 153 | + for i, image in enumerate(sliced_image_object.images): |
| 154 | + image = mmcv.imconvert(image, 'bgr', 'rgb') |
| 155 | + out_file = os.path.join(args.out_dir, 'sliced_images', |
| 156 | + filename + f'_slice_{i}.jpg') |
| 157 | + |
| 158 | + mmcv.imwrite(image, out_file) |
| 159 | + |
| 160 | + # export sliced image results |
| 161 | + for i, slice_result in enumerate(slice_results): |
| 162 | + out_file = os.path.join(args.out_dir, 'sliced_image_results', |
| 163 | + filename + f'_slice_{i}_result.jpg') |
| 164 | + image = mmcv.imconvert(sliced_image_object.images[i], 'bgr', |
| 165 | + 'rgb') |
| 166 | + |
| 167 | + visualizer.add_datasample( |
| 168 | + os.path.basename(out_file), |
| 169 | + image, |
| 170 | + data_sample=slice_result, |
| 171 | + draw_gt=False, |
| 172 | + show=args.show, |
| 173 | + wait_time=0, |
| 174 | + out_file=out_file, |
| 175 | + pred_score_thr=args.score_thr, |
| 176 | + ) |
| 177 | + |
| 178 | + image_result = merge_results_by_nms( |
| 179 | + slice_results, |
| 180 | + sliced_image_object.starting_pixels, |
| 181 | + src_image_shape=(height, width), |
| 182 | + nms_cfg={ |
| 183 | + 'type': args.merge_nms_type, |
| 184 | + 'iou_thr': args.merge_iou_thr |
| 185 | + }) |
| 186 | + |
| 187 | + img = mmcv.imconvert(img, 'bgr', 'rgb') |
| 188 | + out_file = None if args.show else os.path.join(args.out_dir, filename) |
| 189 | + |
| 190 | + visualizer.add_datasample( |
| 191 | + os.path.basename(out_file), |
| 192 | + img, |
| 193 | + data_sample=image_result, |
| 194 | + draw_gt=False, |
| 195 | + show=args.show, |
| 196 | + wait_time=0, |
| 197 | + out_file=out_file, |
| 198 | + pred_score_thr=args.score_thr, |
| 199 | + ) |
| 200 | + progress_bar.update() |
| 201 | + |
| 202 | + if not args.show: |
| 203 | + print_log( |
| 204 | + f'\nResults have been saved at {os.path.abspath(args.out_dir)}') |
| 205 | + |
| 206 | + |
| 207 | +if __name__ == '__main__': |
| 208 | + main() |
0 commit comments