diff --git a/models/license_plate_detection_yunet/lpd_yunet.py b/models/license_plate_detection_yunet/lpd_yunet.py index 36d89613..49c56855 100644 --- a/models/license_plate_detection_yunet/lpd_yunet.py +++ b/models/license_plate_detection_yunet/lpd_yunet.py @@ -61,10 +61,17 @@ def infer(self, image): def _postprocess(self, blob): # Decode dets = self._decode(blob) + + # convert to x, y, w, h format + bboxes = np.zeros((dets.shape[0], 4)) + bboxes[:, 0] = np.min(dets[:, [0, 2, 4, 6]], axis=1) # top-left x + bboxes[:, 1] = np.min(dets[:, [1, 3, 5, 7]], axis=1) # top-left y + bboxes[:, 2] = np.max(dets[:, [0, 2, 4, 6]], axis=1) - bboxes[:, 0] # width + bboxes[:, 3] = np.max(dets[:, [1, 3, 5, 7]], axis=1) - bboxes[:, 1] # height # NMS keepIdx = cv.dnn.NMSBoxes( - bboxes=dets[:, 0:4].tolist(), + bboxes=bboxes.tolist(), scores=dets[:, -1].tolist(), score_threshold=self.confidence_threshold, nms_threshold=self.nms_threshold,