Skip to content

Commit e3a020b

Browse files
committed
fix(): include missing utilities to get the training to run
1 parent 1891628 commit e3a020b

File tree

5 files changed

+254
-4
lines changed

5 files changed

+254
-4
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
.idea/
1+
.idea/
2+
.env

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ authors = [
99
readme = "README.md"
1010
dependencies = [
1111
"pose-format>=0.3.2",
12-
"numpy",
12+
"numpy<2.0.0",
1313
"pympi-ling", # Working with ELAN files in CLI
1414
"torch",
1515
]
@@ -19,7 +19,8 @@ dev = [
1919
"pytest",
2020
"pylint",
2121
"pytorch-lightning",
22-
"sign-language-datasets",
22+
"mediapipe",
23+
"sign_language_datasets @ git+https://github.com/sign-language-processing/datasets",
2324
"wandb",
2425
"matplotlib",
2526
"scikit-learn",

sign_language_segmentation/src/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import numpy.ma as ma
77
import torch
8-
from _shared.tfds_dataset import ProcessedPoseDatum, get_tfds_dataset
8+
from .tfds_dataset import ProcessedPoseDatum, get_tfds_dataset
99
from pose_format import Pose
1010
from pose_format.numpy.representation.distance import DistanceRepresentation
1111
from pose_format.utils.generic import normalize_hands_3d
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
from typing import Tuple
2+
3+
import numpy as np
4+
from numpy import ma
5+
from pose_format import Pose
6+
from pose_format.numpy import NumPyPoseBody
7+
from pose_format.pose_header import PoseHeader, PoseHeaderDimensions
8+
from pose_format.utils.normalization_3d import PoseNormalizer
9+
from pose_format.utils.openpose import OpenPose_Components
10+
11+
12+
def pose_hide_legs(pose: Pose):
13+
if pose.header.components[0].name == "POSE_LANDMARKS":
14+
point_names = ["KNEE", "ANKLE", "HEEL", "FOOT_INDEX"]
15+
# pylint: disable=protected-access
16+
points = [
17+
pose.header._get_point_index("POSE_LANDMARKS", side + "_" + n)
18+
for n in point_names
19+
for side in ["LEFT", "RIGHT"]
20+
]
21+
pose.body.data[:, :, points, :] = 0
22+
pose.body.confidence[:, :, points] = 0
23+
elif pose.header.components[0].name == "pose_keypoints_2d":
24+
point_names = ["Hip", "Knee", "Ankle", "BigToe", "SmallToe", "Heel"]
25+
# pylint: disable=protected-access
26+
points = [
27+
pose.header._get_point_index("pose_keypoints_2d", side + n)
28+
for n in point_names
29+
for side in ["L", "R"]
30+
]
31+
pose.body.data[:, :, points, :] = 0
32+
pose.body.confidence[:, :, points] = 0
33+
else:
34+
raise ValueError("Unknown pose header schema for hiding legs")
35+
36+
37+
def pose_shoulders(pose_header: PoseHeader):
38+
if pose_header.components[0].name == "POSE_LANDMARKS":
39+
return ("POSE_LANDMARKS", "RIGHT_SHOULDER"), ("POSE_LANDMARKS", "LEFT_SHOULDER")
40+
41+
if pose_header.components[0].name == "BODY_135":
42+
return ("BODY_135", "RShoulder"), ("BODY_135", "LShoulder")
43+
44+
if pose_header.components[0].name == "pose_keypoints_2d":
45+
return ("pose_keypoints_2d", "RShoulder"), ("pose_keypoints_2d", "LShoulder")
46+
47+
raise ValueError("Unknown pose header schema for normalization")
48+
49+
50+
def hands_indexes(pose_header: PoseHeader):
51+
if pose_header.components[0].name == "POSE_LANDMARKS":
52+
return [pose_header._get_point_index("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"),
53+
pose_header._get_point_index("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP")]
54+
55+
if pose_header.components[0].name == "pose_keypoints_2d":
56+
return [pose_header._get_point_index("hand_left_keypoints_2d", "M_CMC"),
57+
pose_header._get_point_index("hand_right_keypoints_2d", "M_CMC")]
58+
59+
60+
def pose_normalization_info(pose_header: PoseHeader):
61+
(c1, p1), (c2, p2) = pose_shoulders(pose_header)
62+
return pose_header.normalization_info(p1=(c1, p1), p2=(c2, p2))
63+
64+
65+
def hands_components(pose_header: PoseHeader):
66+
if pose_header.components[0].name in ["POSE_LANDMARKS", "LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"]:
67+
return ("LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"), \
68+
("WRIST", "PINKY_MCP", "INDEX_FINGER_MCP"), \
69+
("WRIST", "MIDDLE_FINGER_MCP")
70+
71+
if pose_header.components[0].name in ["pose_keypoints_2d", "hand_left_keypoints_2d", "hand_right_keypoints_2d"]:
72+
return ("hand_left_keypoints_2d", "hand_right_keypoints_2d"), \
73+
("BASE", "P_CMC", "I_CMC"), \
74+
("BASE", "M_CMC")
75+
76+
raise ValueError("Unknown pose header")
77+
78+
79+
def normalize_component_3d(pose, component_name: str, plane: Tuple[str, str, str], line: Tuple[str, str]):
80+
hand_pose = pose.get_components([component_name])
81+
plane = hand_pose.header.normalization_info(p1=(component_name, plane[0]),
82+
p2=(component_name, plane[1]),
83+
p3=(component_name, plane[2]))
84+
line = hand_pose.header.normalization_info(p1=(component_name, line[0]),
85+
p2=(component_name, line[1]))
86+
normalizer = PoseNormalizer(plane=plane, line=line)
87+
normalized_hand = normalizer(hand_pose.body.data)
88+
89+
# Add normalized hand to pose
90+
pose.body.data = ma.concatenate([pose.body.data, normalized_hand], axis=2).astype(np.float32)
91+
pose.body.confidence = np.concatenate([pose.body.confidence, hand_pose.body.confidence], axis=2)
92+
93+
94+
def normalize_hands_3d(pose: Pose, left_hand=True, right_hand=True):
95+
(left_hand_component, right_hand_component), plane, line = hands_components(pose.header)
96+
if left_hand:
97+
normalize_component_3d(pose, left_hand_component, plane, line)
98+
if right_hand:
99+
normalize_component_3d(pose, right_hand_component, plane, line)
100+
101+
102+
def fake_pose(num_frames: int, fps=25, dims=2, components=OpenPose_Components):
103+
dimensions = PoseHeaderDimensions(width=1, height=1, depth=1)
104+
header = PoseHeader(version=0.1, dimensions=dimensions, components=components)
105+
106+
total_points = header.total_points()
107+
data = np.random.randn(num_frames, 1, total_points, dims)
108+
confidence = np.random.randn(num_frames, 1, total_points)
109+
masked_data = ma.masked_array(data)
110+
111+
body = NumPyPoseBody(fps=int(fps), data=masked_data, confidence=confidence)
112+
113+
return Pose(header, body)
114+
115+
116+
def correct_wrist(pose: Pose, hand: str) -> Pose:
117+
wrist_index = pose.header._get_point_index(f'{hand}_HAND_LANDMARKS', 'WRIST')
118+
wrist = pose.body.data[:, :, wrist_index]
119+
wrist_conf = pose.body.confidence[:, :, wrist_index]
120+
121+
body_wrist_index = pose.header._get_point_index('POSE_LANDMARKS', f'{hand}_WRIST')
122+
body_wrist = pose.body.data[:, :, body_wrist_index]
123+
body_wrist_conf = pose.body.confidence[:, :, body_wrist_index]
124+
125+
new_wrist_data = ma.where(wrist.data == 0, body_wrist, wrist)
126+
new_wrist_conf = ma.where(wrist_conf == 0, body_wrist_conf, wrist_conf)
127+
128+
pose.body.data[:, :, body_wrist_index] = ma.masked_equal(new_wrist_data, 0)
129+
pose.body.confidence[:, :, body_wrist_index] = new_wrist_conf
130+
return pose
131+
132+
133+
def correct_wrists(pose: Pose) -> Pose:
134+
pose = correct_wrist(pose, 'LEFT')
135+
pose = correct_wrist(pose, 'RIGHT')
136+
return pose
137+
138+
139+
def reduce_holistic(pose: Pose) -> Pose:
140+
if pose.header.components[0].name != "POSE_LANDMARKS":
141+
return pose
142+
143+
import mediapipe as mp
144+
points_set = set([p for p_tup in list(mp.solutions.holistic.FACEMESH_CONTOURS) for p in p_tup])
145+
face_contours = [str(p) for p in sorted(points_set)]
146+
147+
ignore_names = [
148+
"EAR", "NOSE", "MOUTH", "EYE", # Face
149+
"THUMB", "PINKY", "INDEX", # Hands
150+
"KNEE", "ANKLE", "HEEL", "FOOT_INDEX" # Feet
151+
]
152+
153+
body_component = [c for c in pose.header.components if c.name == 'POSE_LANDMARKS'][0]
154+
body_no_face_no_hands = [p for p in body_component.points if all([i not in p for i in ignore_names])]
155+
156+
components = [c.name for c in pose.header.components if c.name != 'POSE_WORLD_LANDMARKS']
157+
return pose.get_components(components, {
158+
"FACE_LANDMARKS": face_contours,
159+
"POSE_LANDMARKS": body_no_face_no_hands
160+
})
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import importlib
2+
from typing import Dict, List, TypedDict, Union
3+
4+
import tensorflow_datasets as tfds
5+
from pose_format import Pose
6+
from pose_format.numpy.pose_body import NumPyPoseBody
7+
from pose_format.pose_header import PoseHeader
8+
from pose_format.utils.reader import BufferReader
9+
from sign_language_datasets.datasets.config import SignDatasetConfig
10+
from sign_language_datasets.datasets.dgs_corpus import DgsCorpusConfig
11+
from tqdm import tqdm
12+
import mediapipe as mp
13+
14+
from .pose_utils import pose_hide_legs, pose_normalization_info
15+
16+
17+
mp_holistic = mp.solutions.holistic
18+
FACEMESH_CONTOURS_POINTS = [str(p) for p in sorted(set([p for p_tup in list(mp_holistic.FACEMESH_CONTOURS) for p in p_tup]))]
19+
20+
class ProcessedPoseDatum(TypedDict):
21+
id: str
22+
pose: Union[Pose, Dict[str, Pose]]
23+
tf_datum: dict
24+
25+
26+
def get_tfds_dataset(name,
27+
poses="holistic",
28+
fps=25,
29+
split="train",
30+
components: List[str] = None,
31+
reduce_face=False,
32+
data_dir=None,
33+
version="1.0.0",
34+
filter_func=None):
35+
dataset_module = importlib.import_module("sign_language_datasets.datasets." + name + "." + name)
36+
37+
config_kwargs = dict(
38+
name=poses + "-" + str(fps),
39+
version=version, # Specific version
40+
include_video=False, # Download and load dataset videos
41+
fps=fps, # Load videos at constant fps
42+
include_pose=poses)
43+
44+
# Loading a dataset with custom configuration
45+
if name == "dgs_corpus":
46+
config = DgsCorpusConfig(**config_kwargs, split="3.0.0-uzh-document")
47+
else:
48+
config = SignDatasetConfig(**config_kwargs)
49+
50+
tfds_dataset = tfds.load(name=name, builder_kwargs=dict(config=config), split=split, data_dir=data_dir)
51+
52+
# pylint: disable=protected-access
53+
with open(dataset_module._POSE_HEADERS[poses], "rb") as buffer:
54+
pose_header = PoseHeader.read(BufferReader(buffer.read()))
55+
56+
normalization_info = pose_normalization_info(pose_header)
57+
return [process_datum(datum, pose_header, normalization_info, components, reduce_face)
58+
for datum in tqdm(tfds_dataset, desc="Loading dataset")
59+
if filter_func is None or filter_func(datum)]
60+
61+
62+
def process_datum(datum,
63+
pose_header: PoseHeader,
64+
normalization_info,
65+
components: List[str] = None,
66+
reduce_face=False) -> ProcessedPoseDatum:
67+
tf_poses = {"": datum["pose"]} if "pose" in datum else datum["poses"]
68+
poses = {}
69+
for key, tf_pose in tf_poses.items():
70+
fps = int(tf_pose["fps"].numpy())
71+
pose_body = NumPyPoseBody(fps, tf_pose["data"].numpy(), tf_pose["conf"].numpy())
72+
pose = Pose(pose_header, pose_body)
73+
74+
# Get subset of components if needed
75+
if reduce_face:
76+
pose = pose.get_components(components, {"FACE_LANDMARKS": FACEMESH_CONTOURS_POINTS})
77+
elif components and len(components) != len(pose_header.components):
78+
pose = pose.get_components(components)
79+
80+
pose = pose.normalize(normalization_info)
81+
pose_hide_legs(pose)
82+
poses[key] = pose
83+
84+
return {
85+
"id": datum["id"].numpy().decode('utf-8'),
86+
"pose": poses[""] if "pose" in datum else poses,
87+
"tf_datum": datum
88+
}

0 commit comments

Comments
 (0)