import argparse
import os
from typing import List, Optional, Union

import numpy as np
import torch
import torchvision.ops.boxes as bops

import norfair
from norfair import Detection, Tracker, Video, Paths

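# Matching thresholds for the two tracking modes: 1/IoU distance for bounding
# boxes and pixel distance for centroids; MAX_DISTANCE stands in for "no overlap".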
DISTANCE_THRESHOLD_BBOX: float = 3.33
DISTANCE_THRESHOLD_CENTROID: int = 30
MAX_DISTANCE: int = 10000


class YOLO:
    def __init__(self, model_path: str, device: Optional[str] = None):
        if device is not None and "cuda" in device and not torch.cuda.is_available():
            raise Exception(
                "Selected device='cuda', but cuda is not available to Pytorch."
            )
        # automatically set device if it's None
        elif device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # download the official weights if the model file is not present locally
        if not os.path.exists(model_path):
            os.system(
                f"wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/{os.path.basename(model_path)} -O {model_path}"
            )

        # load model
        try:
            self.model = torch.hub.load("WongKinYiu/yolov7", "custom", model_path)
        except Exception as exc:
            raise Exception(f"Failed to load model from {model_path}") from exc

    def __call__(
        self,
        img: Union[str, np.ndarray],
        conf_threshold: float = 0.25,
        iou_threshold: float = 0.45,
        image_size: int = 720,
        classes: Optional[List[int]] = None,
    ) -> torch.Tensor:
        # run YOLOv7 inference on a single image or frame
        self.model.conf = conf_threshold
        self.model.iou = iou_threshold
        if classes is not None:
            self.model.classes = classes
        detections = self.model(img, size=image_size)
        return detections
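# A minimal usage sketch (hypothetical file names, assuming the hub entry above):
#   model = YOLO("yolov7.pt", device="cpu")
#   results = model("image.jpg", conf_threshold=0.5, classes=[0])  # 0 = person in COCO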


def euclidean_distance(detection, tracked_object):
    # Euclidean distance between the detection points and the tracker's estimate
    return np.linalg.norm(detection.points - tracked_object.estimate)


def center(points):
    # Centroid of a set of points, used by Paths to anchor the drawn trails
    return [np.mean(np.array(points), axis=0)]


def iou_pytorch(detection, tracked_object):
    # Slower but simpler version of iou

    detection_points = np.concatenate([detection.points[0], detection.points[1]])
    tracked_object_points = np.concatenate(
        [tracked_object.estimate[0], tracked_object.estimate[1]]
    )

    box_a = torch.tensor([detection_points], dtype=torch.float)
    box_b = torch.tensor([tracked_object_points], dtype=torch.float)
    iou = bops.box_iou(box_a, box_b).item()

    # Since 0 <= IoU <= 1, we define 1/IoU as a distance.
    # Distance values will be in [1, inf)
    return 1 / iou if iou else MAX_DISTANCE


def iou(detection, tracked_object):
    # Detection points will be box A.
    # Tracked object's estimate will be box B.

    box_a = np.concatenate([detection.points[0], detection.points[1]])
    box_b = np.concatenate([tracked_object.estimate[0], tracked_object.estimate[1]])

    x_a = max(box_a[0], box_b[0])
    y_a = max(box_a[1], box_b[1])
    x_b = min(box_a[2], box_b[2])
    y_b = min(box_a[3], box_b[3])

    # Compute the area of the intersection rectangle
    inter_area = max(0, x_b - x_a + 1) * max(0, y_b - y_a + 1)

    # Compute the area of both the prediction and tracker rectangles
    box_a_area = (box_a[2] - box_a[0] + 1) * (box_a[3] - box_a[1] + 1)
    box_b_area = (box_b[2] - box_b[0] + 1) * (box_b[3] - box_b[1] + 1)

    # Compute the intersection over union by taking the intersection
    # area and dividing it by the sum of prediction + tracker
    # areas - the intersection area
    iou = inter_area / float(box_a_area + box_b_area - inter_area)

    # Since 0 <= IoU <= 1, we define 1/IoU as a distance.
    # Distance values will be in [1, inf)
    return 1 / iou if iou else MAX_DISTANCE
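# Worked example (a sketch, not executed by the demo): for boxes (0, 0, 10, 10)
# and (5, 5, 15, 15), the +1 pixel convention gives an intersection of 6 * 6 = 36
# and an area of 11 * 11 = 121 for each box, so IoU = 36 / (121 + 121 - 36) ~= 0.175
# and the distance is 1 / 0.175 ~= 5.72, which is above DISTANCE_THRESHOLD_BBOX,
# so such a pair would not be matched.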


def yolo_detections_to_norfair_detections(
    yolo_detections: torch.Tensor,
    track_points: str = "centroid",  # bbox or centroid
) -> List[Detection]:
    """Convert YOLOv7 detections to norfair detections.

    For "centroid" each detection carries a single (x, y) point with one score;
    for "bbox" it carries the (top-left, bottom-right) corner points with the
    confidence score repeated once per point.
    """
    norfair_detections: List[Detection] = []

    if track_points == "centroid":
        detections_as_xywh = yolo_detections.xywh[0]
        for detection_as_xywh in detections_as_xywh:
            centroid = np.array(
                [detection_as_xywh[0].item(), detection_as_xywh[1].item()]
            )
            scores = np.array([detection_as_xywh[4].item()])
            norfair_detections.append(Detection(points=centroid, scores=scores))
    elif track_points == "bbox":
        detections_as_xyxy = yolo_detections.xyxy[0]
        for detection_as_xyxy in detections_as_xyxy:
            bbox = np.array(
                [
                    [detection_as_xyxy[0].item(), detection_as_xyxy[1].item()],
                    [detection_as_xyxy[2].item(), detection_as_xyxy[3].item()],
                ]
            )
            scores = np.array(
                [detection_as_xyxy[4].item(), detection_as_xyxy[4].item()]
            )
            norfair_detections.append(Detection(points=bbox, scores=scores))

    return norfair_detections


parser = argparse.ArgumentParser(description="Track objects in a video.")
parser.add_argument("files", type=str, nargs="+", help="Video files to process")
parser.add_argument("--detector-path", type=str, default="/yolov7.pt", help="YOLOv7 model path")
parser.add_argument("--img-size", type=int, default=720, help="YOLOv7 inference size (pixels)")
parser.add_argument("--conf-threshold", type=float, default=0.25, help="YOLOv7 object confidence threshold")
parser.add_argument("--iou-threshold", type=float, default=0.45, help="YOLOv7 IOU threshold for NMS")
parser.add_argument("--classes", nargs="+", type=int, help="Filter by class: --classes 0, or --classes 0 2 3")
parser.add_argument("--device", type=str, default=None, help="Inference device: 'cpu' or 'cuda'")
parser.add_argument("--track-points", type=str, default="centroid", help="Track points: 'centroid' or 'bbox'")
args = parser.parse_args()
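# Example invocation (hypothetical script and file names):
#   python yolov7_demo.py video.mp4 --detector-path ./yolov7.pt --track-points bbox --classes 0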

model = YOLO(args.detector_path, device=args.device)

for input_path in args.files:
    video = Video(input_path=input_path)

    distance_function = iou if args.track_points == "bbox" else euclidean_distance
    distance_threshold = (
        DISTANCE_THRESHOLD_BBOX
        if args.track_points == "bbox"
        else DISTANCE_THRESHOLD_CENTROID
    )

    tracker = Tracker(
        distance_function=distance_function,
        distance_threshold=distance_threshold,
    )
    paths_drawer = Paths(center, attenuation=0.01)

    for frame in video:
        yolo_detections = model(
            frame,
            conf_threshold=args.conf_threshold,
            iou_threshold=args.iou_threshold,
            image_size=args.img_size,
            classes=args.classes,
        )
        detections = yolo_detections_to_norfair_detections(
            yolo_detections, track_points=args.track_points
        )
        tracked_objects = tracker.update(detections=detections)
        if args.track_points == "centroid":
            norfair.draw_points(frame, detections)
            norfair.draw_tracked_objects(frame, tracked_objects)
        elif args.track_points == "bbox":
            norfair.draw_boxes(frame, detections)
            norfair.draw_tracked_boxes(frame, tracked_objects)
        frame = paths_drawer.draw(frame, tracked_objects)
        video.write(frame)