Skip to content

Commit a13bd2f

Browse files
authored
Add render_openposes_json API endpoint (#2325)
* Add render_openpose_json API endpoint * nits
1 parent 0d08416 commit a13bd2f

File tree

5 files changed

+270
-8
lines changed

5 files changed

+270
-8
lines changed

annotator/openpose/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,23 +62,19 @@ def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, dr
6262
return canvas
6363

6464

65-
def decode_json_as_poses(json_string: str, normalize_coords: bool = False) -> Tuple[List[PoseResult], int, int]:
65+
def decode_json_as_poses(pose_json: dict) -> Tuple[List[PoseResult], int, int]:
6666
""" Decode the json_string complying with the openpose JSON output format
6767
to poses that controlnet recognizes.
6868
https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md
6969
7070
Args:
7171
json_string: The json string to decode.
72-
normalize_coords: Whether to normalize coordinates of each keypoint by canvas height/width.
73-
`draw_pose` only accepts normalized keypoints. Set this param to True if
74-
the input coords are not normalized.
75-
72+
7673
Returns:
7774
poses
7875
canvas_height
7976
canvas_width
8077
"""
81-
pose_json = json.loads(json_string)
8278
height = pose_json['canvas_height']
8379
width = pose_json['canvas_width']
8480

scripts/api.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from typing import List
1+
from typing import List, Optional
22

33
import numpy as np
44
from fastapi import FastAPI, Body
55
from fastapi.exceptions import HTTPException
6+
from pydantic import BaseModel
7+
68
from PIL import Image
79

810
import gradio as gr
@@ -13,6 +15,7 @@
1315
from scripts import external_code, global_state
1416
from scripts.processor import preprocessor_filters
1517
from scripts.logging import logger
18+
from annotator.openpose import draw_poses, decode_json_as_poses
1619

1720

1821
def encode_to_base64(image):
@@ -148,6 +151,32 @@ def accept(self, json_dict: dict) -> None:
148151

149152
return res
150153

154+
class Person(BaseModel):
155+
pose_keypoints_2d: List[float]
156+
hand_right_keypoints_2d: Optional[List[float]]
157+
hand_left_keypoints_2d: Optional[List[float]]
158+
face_keypoints_2d: Optional[List[float]]
159+
160+
class PoseData(BaseModel):
161+
people: List[Person]
162+
canvas_width: int
163+
canvas_height: int
164+
165+
@app.post("/controlnet/render_openpose_json")
166+
async def render_openpose_json(
167+
pose_data: List[PoseData] = Body([], title="Pose json files to render.")
168+
):
169+
if not pose_data:
170+
return {"info": "No pose data detected."}
171+
else:
172+
return {
173+
"images": [
174+
encode_to_base64(draw_poses(*decode_json_as_poses(pose.dict())))
175+
for pose in pose_data
176+
],
177+
"info": "Success",
178+
}
179+
151180

152181
try:
153182
import modules.script_callbacks as script_callbacks

scripts/controlnet_ui/openpose_editor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import base64
22
import gradio as gr
3+
import json
34
from typing import List, Dict, Any, Tuple
45

56
from annotator.openpose import decode_json_as_poses, draw_poses
@@ -80,7 +81,7 @@ def register_callbacks(
8081
):
8182
def render_pose(pose_url: str) -> Tuple[Dict, Dict]:
8283
json_string = parse_data_url(pose_url).decode('utf-8')
83-
poses, height, weight = decode_json_as_poses(json_string)
84+
poses, height, weight = decode_json_as_poses(json.loads(json_string))
8485
logger.info("Preview as input is enabled.")
8586
return (
8687
# Generated image.

tests/web_api/pose.json

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
{
2+
"people": [
3+
{
4+
"pose_keypoints_2d": [
5+
275.2506064884899,
6+
196.32469357280343,
7+
1,
8+
303.3188016469506,
9+
272.70982071889466,
10+
1,
11+
244.98447950024644,
12+
292.09994638829477,
13+
1,
14+
236.38292745027104,
15+
517.7037729015278,
16+
1,
17+
168.3984479500246,
18+
418.0022744632577,
19+
1,
20+
412.90526047751445,
21+
257.2121425016039,
22+
1,
23+
403.17894535813576,
24+
510.14290732520925,
25+
1,
26+
294.43481869004336,
27+
376.4781345848482,
28+
1,
29+
265.25747216047955,
30+
562.5137576822758,
31+
1,
32+
0,
33+
0,
34+
0,
35+
0,
36+
0,
37+
0,
38+
359.0078938961762,
39+
562.0711206495608,
40+
1,
41+
0,
42+
0,
43+
0,
44+
0,
45+
0,
46+
0,
47+
240.097671925037,
48+
184.513914838073,
49+
1,
50+
308.33409148775263,
51+
161.22089906296208,
52+
1,
53+
204.7558201874076,
54+
213.05887067565308,
55+
1,
56+
366.61934701298674,
57+
148.9278832878512,
58+
1
59+
],
60+
"hand_right_keypoints_2d": [
61+
168.39790150130915,
62+
418.0005271461072,
63+
1,
64+
181.79401055357368,
65+
399.3767976307846,
66+
1,
67+
184.8627576873498,
68+
384.75709227716266,
69+
1,
70+
198.118414869015,
71+
381.90007483819153,
72+
1,
73+
215.15048903799024,
74+
386.8527024571639,
75+
1,
76+
180.1325238303079,
77+
346.88417988394843,
78+
1,
79+
178.10795487568174,
80+
321.1018085790239,
81+
1,
82+
190.70710955474613,
83+
320.5669160600145,
84+
1,
85+
203.3062827659017,
86+
325.5456061326966,
87+
1,
88+
172.3536896669116,
89+
350.7525276652412,
90+
1,
91+
170.000622912838,
92+
325.0336533262108,
93+
1,
94+
185.70972426755964,
95+
323.476117950832,
96+
1,
97+
208.50724912413568,
98+
333.4635128502484,
99+
1,
100+
163.50256975319706,
101+
356.1325737292637,
102+
1,
103+
162.59147123450128,
104+
335.0197021116338,
105+
1,
106+
183.9828354600188,
107+
328.4553078224726,
108+
1,
109+
201.57171021013423,
110+
337.94383954551654,
111+
1,
112+
152.9805889462092,
113+
357.94403689402753,
114+
1,
115+
167.09651929267878,
116+
341.7437601311937,
117+
1,
118+
180.9402668216081,
119+
337.10207327446744,
120+
1,
121+
194.60028340222678,
122+
343.0449246874465,
123+
1
124+
],
125+
"hand_left_keypoints_2d": [
126+
294.4393772120137,
127+
376.476024395234,
128+
1,
129+
271.70933825161165,
130+
384.48117305399165,
131+
1,
132+
257.2452829806548,
133+
374.58948859472207,
134+
1,
135+
238.26122936397638,
136+
375.2887100029166,
137+
1,
138+
219.89983184668415,
139+
382.69322630254595,
140+
1,
141+
263.0323651487124,
142+
320.1279349241104,
143+
1,
144+
246.94602107917282,
145+
309.8099960810156,
146+
1,
147+
233.73717716804694,
148+
314.1485136789638,
149+
1,
150+
224.27755744411303,
151+
322.7892154545116,
152+
1,
153+
264.97558037166135,
154+
334.6319791090978,
155+
1,
156+
254.35598193615226,
157+
315.5629746257517,
158+
1,
159+
238.25810853722876,
160+
321.2812182403252,
161+
1,
162+
223.57727818251382,
163+
328.39525394113423,
164+
1,
165+
278.15831452661644,
166+
337.1533682086847,
167+
1,
168+
265.12624946042416,
169+
323.3619430418993,
170+
1,
171+
250.5919197031302,
172+
325.30525908324694,
173+
1,
174+
235.2500911122877,
175+
332.6721359855453,
176+
1,
177+
285.9427695830851,
178+
341.50671458478496,
179+
1,
180+
274.50497773130155,
181+
333.1376809270594,
182+
1,
183+
261.49768784257105,
184+
328.7012203257942,
185+
1,
186+
248.90495501067193,
187+
332.0535195828255,
188+
1
189+
]
190+
}
191+
],
192+
"canvas_width": 512,
193+
"canvas_height": 512
194+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import requests
2+
import unittest
3+
import importlib
4+
import json
5+
from pathlib import Path
6+
7+
utils = importlib.import_module("extensions.sd-webui-controlnet.tests.utils", "utils")
8+
9+
10+
def render(poses):
11+
return requests.post(
12+
utils.BASE_URL + "/controlnet/render_openpose_json", json=poses
13+
).json()
14+
15+
16+
with open(Path(__file__).parent / "pose.json", "r") as f:
17+
pose = json.load(f)
18+
19+
20+
class TestDetectEndpointWorking(unittest.TestCase):
21+
def test_render_single(self):
22+
res = render([pose])
23+
self.assertEqual(res["info"], "Success")
24+
self.assertEqual(len(res["images"]), 1)
25+
26+
def test_render_multiple(self):
27+
res = render([pose, pose])
28+
self.assertEqual(res["info"], "Success")
29+
self.assertEqual(len(res["images"]), 2)
30+
31+
def test_render_no_pose(self):
32+
res = render([])
33+
self.assertNotEqual(res["info"], "Success")
34+
35+
def test_render_invalid_pose(self):
36+
res = render([{"foo": 10, "bar": 100}])
37+
self.assertNotIn("info", res)
38+
self.assertNotIn("images", res)
39+
40+
41+
if __name__ == "__main__":
42+
unittest.main()

0 commit comments

Comments
 (0)