Skip to content

Commit 01a88be

Browse files
authored
support coco-wholebody visualization in pose_tracker python demo (#2450)
* update * update
1 parent 062abd9 commit 01a88be

File tree

1 file changed

+98
-23
lines changed

1 file changed

+98
-23
lines changed

demo/python/pose_tracker.py

Lines changed: 98 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44

55
import cv2
6+
import numpy as np
67
from mmdeploy_runtime import PoseTracker
78

89

@@ -18,39 +19,111 @@ def parse_args():
1819
help='path of mmdeploy SDK model dumped by model converter')
1920
parser.add_argument('video', help='video path or camera index')
2021
parser.add_argument('--output_dir', help='output directory', default=None)
22+
parser.add_argument(
23+
'--skeleton',
24+
default='coco',
25+
choices=['coco', 'coco_wholebody'],
26+
help='skeleton for keypoints')
27+
2128
args = parser.parse_args()
2229
if args.video.isnumeric():
2330
args.video = int(args.video)
2431
return args
2532

2633

27-
def visualize(frame, results, output_dir, frame_id, thr=0.5, resize=1280):
28-
skeleton = [(15, 13), (13, 11), (16, 14), (14, 12), (11, 12), (5, 11),
29-
(6, 12), (5, 6), (5, 7), (6, 8), (7, 9), (8, 10), (1, 2),
30-
(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6)]
31-
palette = [(255, 128, 0), (255, 153, 51), (255, 178, 102), (230, 230, 0),
32-
(255, 153, 255), (153, 204, 255), (255, 102, 255),
33-
(255, 51, 255), (102, 178, 255),
34-
(51, 153, 255), (255, 153, 153), (255, 102, 102), (255, 51, 51),
35-
(153, 255, 153), (102, 255, 102), (51, 255, 51), (0, 255, 0),
36-
(0, 0, 255), (255, 0, 0), (255, 255, 255)]
37-
link_color = [
38-
0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16
39-
]
40-
point_color = [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]
34+
VISUALIZATION_CFG = dict(
35+
coco=dict(
36+
skeleton=[(15, 13), (13, 11), (16, 14), (14, 12), (11, 12), (5, 11),
37+
(6, 12), (5, 6), (5, 7), (6, 8), (7, 9), (8, 10), (1, 2),
38+
(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6)],
39+
palette=[(255, 128, 0), (255, 153, 51), (255, 178, 102), (230, 230, 0),
40+
(255, 153, 255), (153, 204, 255), (255, 102, 255),
41+
(255, 51, 255), (102, 178, 255), (51, 153, 255),
42+
(255, 153, 153), (255, 102, 102), (255, 51, 51),
43+
(153, 255, 153), (102, 255, 102), (51, 255, 51), (0, 255, 0),
44+
(0, 0, 255), (255, 0, 0), (255, 255, 255)],
45+
link_color=[
46+
0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16
47+
],
48+
point_color=[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0],
49+
sigmas=[
50+
0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072,
51+
0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089
52+
]),
53+
coco_wholebody=dict(
54+
skeleton=[(15, 13), (13, 11), (16, 14), (14, 12), (11, 12), (5, 11),
55+
(6, 12), (5, 6), (5, 7), (6, 8), (7, 9), (8, 10), (1, 2),
56+
(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (15, 17),
57+
(15, 18), (15, 19), (16, 20), (16, 21), (16, 22), (91, 92),
58+
(92, 93), (93, 94), (94, 95), (91, 96), (96, 97), (97, 98),
59+
(98, 99), (91, 100), (100, 101), (101, 102), (102, 103),
60+
(91, 104), (104, 105), (105, 106), (106, 107), (91, 108),
61+
(108, 109), (109, 110), (110, 111), (112, 113), (113, 114),
62+
(114, 115), (115, 116), (112, 117), (117, 118), (118, 119),
63+
(119, 120), (112, 121), (121, 122), (122, 123), (123, 124),
64+
(112, 125), (125, 126), (126, 127), (127, 128), (112, 129),
65+
(129, 130), (130, 131), (131, 132)],
66+
palette=[(51, 153, 255), (0, 255, 0), (255, 128, 0), (255, 255, 255),
67+
(255, 153, 255), (102, 178, 255), (255, 51, 51)],
68+
link_color=[
69+
1, 1, 2, 2, 0, 0, 0, 0, 1, 2, 1, 2, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
70+
2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 1, 1, 1,
71+
1, 2, 2, 2, 2, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 1, 1, 1, 1
72+
],
73+
point_color=[
74+
0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2,
75+
2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
76+
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
77+
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
78+
3, 3, 3, 3, 2, 2, 2, 2, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 1, 1,
79+
1, 1, 3, 2, 2, 2, 2, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 1, 1, 1, 1
80+
],
81+
sigmas=[
82+
0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072,
83+
0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089, 0.068,
84+
0.066, 0.066, 0.092, 0.094, 0.094, 0.042, 0.043, 0.044, 0.043,
85+
0.040, 0.035, 0.031, 0.025, 0.020, 0.023, 0.029, 0.032, 0.037,
86+
0.038, 0.043, 0.041, 0.045, 0.013, 0.012, 0.011, 0.011, 0.012,
87+
0.012, 0.011, 0.011, 0.013, 0.015, 0.009, 0.007, 0.007, 0.007,
88+
0.012, 0.009, 0.008, 0.016, 0.010, 0.017, 0.011, 0.009, 0.011,
89+
0.009, 0.007, 0.013, 0.008, 0.011, 0.012, 0.010, 0.034, 0.008,
90+
0.008, 0.009, 0.008, 0.008, 0.007, 0.010, 0.008, 0.009, 0.009,
91+
0.009, 0.007, 0.007, 0.008, 0.011, 0.008, 0.008, 0.008, 0.01,
92+
0.008, 0.029, 0.022, 0.035, 0.037, 0.047, 0.026, 0.025, 0.024,
93+
0.035, 0.018, 0.024, 0.022, 0.026, 0.017, 0.021, 0.021, 0.032,
94+
0.02, 0.019, 0.022, 0.031, 0.029, 0.022, 0.035, 0.037, 0.047,
95+
0.026, 0.025, 0.024, 0.035, 0.018, 0.024, 0.022, 0.026, 0.017,
96+
0.021, 0.021, 0.032, 0.02, 0.019, 0.022, 0.031
97+
]))
98+
99+
100+
def visualize(frame,
101+
results,
102+
output_dir,
103+
frame_id,
104+
thr=0.5,
105+
resize=1280,
106+
skeleton_type='coco'):
107+
108+
skeleton = VISUALIZATION_CFG[skeleton_type]['skeleton']
109+
palette = VISUALIZATION_CFG[skeleton_type]['palette']
110+
link_color = VISUALIZATION_CFG[skeleton_type]['link_color']
111+
point_color = VISUALIZATION_CFG[skeleton_type]['point_color']
112+
41113
scale = resize / max(frame.shape[0], frame.shape[1])
42114
keypoints, bboxes, _ = results
43115
scores = keypoints[..., 2]
44116
keypoints = (keypoints[..., :2] * scale).astype(int)
45117
bboxes *= scale
46118
img = cv2.resize(frame, (0, 0), fx=scale, fy=scale)
47119
for kpts, score, bbox in zip(keypoints, scores, bboxes):
48-
show = [0] * len(kpts)
120+
show = [1] * len(kpts)
49121
for (u, v), color in zip(skeleton, link_color):
50122
if score[u] > thr and score[v] > thr:
51123
cv2.line(img, kpts[u], tuple(kpts[v]), palette[color], 1,
52124
cv2.LINE_AA)
53-
show[u] = show[v] = 1
125+
else:
126+
show[u] = show[v] = 0
54127
for kpt, show, color in zip(kpts, show, point_color):
55128
if show:
56129
cv2.circle(img, kpt, 1, palette[color], 2, cv2.LINE_AA)
@@ -64,7 +137,7 @@ def visualize(frame, results, output_dir, frame_id, thr=0.5, resize=1280):
64137

65138
def main():
66139
args = parse_args()
67-
140+
np.set_printoptions(precision=4, suppress=True)
68141
video = cv2.VideoCapture(args.video)
69142

70143
tracker = PoseTracker(
@@ -73,12 +146,9 @@ def main():
73146
device_name=args.device_name)
74147

75148
# optionally use OKS for keypoints similarity comparison
76-
coco_sigmas = [
77-
0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062,
78-
0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089
79-
]
149+
sigmas = VISUALIZATION_CFG[args.skeleton]['sigmas']
80150
state = tracker.create_state(
81-
det_interval=1, det_min_bbox_size=100, keypoint_sigmas=coco_sigmas)
151+
det_interval=1, det_min_bbox_size=100, keypoint_sigmas=sigmas)
82152

83153
if args.output_dir:
84154
os.makedirs(args.output_dir, exist_ok=True)
@@ -89,7 +159,12 @@ def main():
89159
if not success:
90160
break
91161
results = tracker(state, frame, detect=-1)
92-
if not visualize(frame, results, args.output_dir, frame_id):
162+
if not visualize(
163+
frame,
164+
results,
165+
args.output_dir,
166+
frame_id,
167+
skeleton_type=args.skeleton):
93168
break
94169
frame_id += 1
95170

0 commit comments

Comments
 (0)