
Commit 7aeaacc

feat(bin): add importable segment function
1 parent e3a020b

File tree

README.md
sign_language_segmentation/bin.py

2 files changed: +49 -33


README.md

Lines changed: 3 additions & 1 deletion
@@ -69,7 +69,9 @@ python -m sign_language_segmentation.src.train --dataset=dgs_corpus --pose=holis
 ### E1: Bidirectional BIO Tagger
 We replace the IO tagging heads in E0 with BIO heads to form our baseline. Our preliminary experiments indicate that inputting only the 75 hand and body keypoints and making the LSTM layer bidirectional yields optimal results.
 ```bash
-python -m sign_language_segmentation.src.train --dataset=dgs_corpus --pose=holistic --fps=25 --hidden_dim=256 --encoder_depth=1 --encoder_bidirectional=true
+conda activate segmentation
+export CUDA_VISIBLE_DEVICES=3
+python -m sign_language_segmentation.src.train --dataset=dgs_corpus --pose=holistic --fps=25 --hidden_dim=256 --encoder_depth=4 --encoder_bidirectional=true --no_wandb true
 ```
 Or for the mediapi-skel dataset (only phrase segmentation)
 ```bash

sign_language_segmentation/bin.py

Lines changed: 46 additions & 32 deletions
@@ -1,18 +1,21 @@
 #!/usr/bin/env python
-from pathlib import Path
 import argparse
 import os
+from pathlib import Path
 
 import numpy as np
 import pympi
 import torch
 from pose_format import Pose
 from pose_format.utils.generic import pose_normalization_info, pose_hide_legs, normalize_hands_3d
+from functools import lru_cache
 
 from sign_language_segmentation.src.utils.probs_to_segments import probs_to_segments
 
+DEFAULT_MODEL = "model_E1s-1.pth"
+
 
-def add_optical_flow(pose: Pose)->None:
+def add_optical_flow(pose: Pose) -> None:
     from pose_format.numpy.representation.distance import DistanceRepresentation
     from pose_format.utils.optical_flow import OpticalFlowCalculator
 
@@ -44,6 +47,7 @@ def process_pose(pose: Pose, optical_flow=False, hand_normalization=False) -> Po
     return pose
 
 
+@lru_cache(maxsize=1)
 def load_model(model_path: str):
     model = torch.jit.load(model_path)
     model.eval()
@@ -58,7 +62,7 @@ def predict(model, pose: Pose):
     return model(pose_data)
 
 
-def save_pose_segments(tiers:dict, tier_id:str, input_file_path:Path)->None:
+def save_pose_segments(tiers: dict, tier_id: str, input_file_path: Path) -> None:
     # reload it without any of the processing, so we get all the original points and such.
     with input_file_path.open("rb") as f:
         pose = Pose.read(f.read())
@@ -83,42 +87,64 @@ def get_args():
     )
     parser.add_argument("--video", default=None, required=False, type=str, help="path to video file")
     parser.add_argument("--subtitles", default=None, required=False, type=str, help="path to subtitle file")
-    parser.add_argument("--model", default="model_E1s-1.pth", required=False, type=str, help="path to model file")
+    parser.add_argument("--model", default=DEFAULT_MODEL, required=False, type=str, help="path to model file")
     parser.add_argument("--no-pose-link", action="store_true", help="whether to link the pose file")
 
     return parser.parse_args()
 
 
-def main():
-    args = get_args()
+def segment_pose(pose: Pose, model: str = DEFAULT_MODEL, verbose=True):
+    if "E4" in model:
+        pose = process_pose(pose, optical_flow=True, hand_normalization=True)
+    else:
+        pose = process_pose(pose)
 
-    print("Loading pose ...")
-    with open(args.pose, "rb") as f:
-        pose = Pose.read(f.read())
-    if "E4" in args.model:
-        pose = process_pose(pose, optical_flow=True, hand_normalization=True)
-    else:
-        pose = process_pose(pose)
-
-    print("Loading model ...")
+    if verbose:
+        print("Loading model ...")
     install_dir = str(os.path.dirname(os.path.abspath(__file__)))
-    model = load_model(os.path.join(install_dir, "dist", args.model))
+    model = load_model(os.path.join(install_dir, "dist", model))
 
-    print("Estimating segments ...")
+    if verbose:
+        print("Estimating segments ...")
     probs = predict(model, pose)
 
     sign_segments = probs_to_segments(probs["sign"], 60, 50)
     sentence_segments = probs_to_segments(probs["sentence"], 90, 90)
 
-    print("Building ELAN file ...")
+    if verbose:
+        print("Building ELAN file ...")
+    eaf = pympi.Elan.Eaf(author="sign-language-processing/transcription")
+
+    fps = pose.body.fps
+
     tiers = {
         "SIGN": sign_segments,
         "SENTENCE": sentence_segments,
     }
 
-    fps = pose.body.fps
+    for tier_id, segments in tiers.items():
+        eaf.add_tier(tier_id)
+        for segment in segments:
+            if segment["end"] == segment["start"]:
+                segment["end"] += 1
+
+            # convert frame numbers to millisecond timestamps, for Elan
+            start_time_ms = int(segment["start"] / fps * 1000)
+            end_time_ms = int(segment["end"] / fps * 1000)
+            eaf.add_annotation(tier_id, start_time_ms, end_time_ms)
+
+    return eaf, tiers
+
+
+def main():
+    args = get_args()
+
+    print("Loading pose ...")
+    with open(args.pose, "rb") as f:
+        pose = Pose.read(f.read())
+
+    eaf, tiers = segment_pose(pose, model=args.model)
 
-    eaf = pympi.Elan.Eaf(author="sign-language-processing/transcription")
     if args.video is not None:
         mimetype = None  # pympi is not familiar with mp4 files
         if args.video.endswith(".mp4"):
@@ -128,18 +154,6 @@ def main():
         if not args.no_pose_link:
            eaf.add_linked_file(args.pose, mimetype="application/pose")
 
-    for tier_id, segments in tiers.items():
-        eaf.add_tier(tier_id)
-        for segment in segments:
-            # convert frame numbers to millisecond timestamps, for Elan
-            start_time_ms = int(segment["start"] / fps * 1000)
-            end_time_ms = int(segment["end"] / fps * 1000)
-            eaf.add_annotation(tier_id, start_time_ms, end_time_ms)
-
-    if args.save_segments:
-        print(f"Saving {args.save_segments} cropped .pose files")
-        save_pose_segments(tiers, tier_id=args.save_segments, input_file_path=args.pose)
-
     if args.subtitles and os.path.exists(args.subtitles):
         import srt
 
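The headline change is that segmentation is now a reusable function, `segment_pose`, rather than logic buried in `main()`. A minimal sketch of calling it from Python, assuming the package layout in this repo (`sign_language_segmentation.bin`) and placeholder file names (`example.pose`, `example.eaf`):

```python
# Hypothetical usage of the new importable API; the .pose/.eaf paths
# are placeholders, not files shipped with this commit.
from pose_format import Pose

from sign_language_segmentation.bin import segment_pose

with open("example.pose", "rb") as f:
    pose = Pose.read(f.read())

# Preprocessing, model loading (cached via lru_cache), prediction, and
# ELAN construction all happen inside segment_pose.
eaf, tiers = segment_pose(pose, verbose=False)

eaf.to_file("example.eaf")   # pympi Eaf objects serialize with to_file
print(len(tiers["SIGN"]))    # segments are dicts with "start"/"end" frame indices
```

Caching `load_model` with `lru_cache(maxsize=1)` means repeated `segment_pose` calls in one process reuse the TorchScript model instead of reloading it per call.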
