Skip to content

Commit c3a641f

Browse files
authored
Address TODOs for dataset annotator (huggingface#872)
- add args usage, pass gs_url by CL flag - add support for no existing prompts
1 parent aafe7c4 commit c3a641f

File tree

4 files changed

+69
-28
lines changed

4 files changed

+69
-28
lines changed

dataset/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pip install -r requirements.txt
1616
python annotation_tool.py
1717
```
1818

19-
<img width="1308" alt="annotator" src="https://user-images.githubusercontent.com/49575973/214191759-24cc5fe6-cd53-4099-87f6-707068f8888d.png">
19+
<img width="1280" alt="annotator" src="https://user-images.githubusercontent.com/49575973/214521137-7ef6ae10-7cd8-46e6-b270-b6c0445157f1.png">
2020

2121
* Select a dataset from `Dataset` dropdown list
2222
* Select an image from `Image` dropdown list

dataset/annotation_tool.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,12 @@
22
import json
33
import jsonlines
44
import os
5+
from args import args
56
from pathlib import Path
67
from PIL import Image
78
from utils import get_datasets
89

910

10-
# TODO: pass gs_url as a command line flag
11-
# see https://cloud.google.com/docs/authentication/provide-credentials-adc to authorize
12-
gs_url = "gs://shark-datasets/portraits"
13-
1411
shark_root = Path(__file__).parent.parent
1512
demo_css = shark_root.joinpath("web/demo.css").resolve()
1613
nodlogo_loc = shark_root.joinpath(
@@ -30,15 +27,15 @@
3027
elem_id="top_logo",
3128
).style(width=150, height=100)
3229

33-
datasets, images = get_datasets(gs_url)
30+
datasets, images, ds_w_prompts = get_datasets(args.gs_url)
3431
prompt_data = dict()
3532

3633
with gr.Row(elem_id="ui_body"):
37-
# TODO: add multiselect dataset
34+
# TODO: add multiselect dataset, there is a gradio version conflict
3835
dataset = gr.Dropdown(label="Dataset", choices=datasets)
3936
image_name = gr.Dropdown(label="Image", choices=[])
4037

41-
with gr.Row(elem_id="ui_body", visible=True):
38+
with gr.Row(elem_id="ui_body"):
4239
# TODO: add ability to search image by typing
4340
with gr.Column(scale=1, min_width=600):
4441
image = gr.Image(type="filepath").style(height=512)
@@ -61,27 +58,26 @@
6158
finish = gr.Button("Finish")
6259

6360
def filter_datasets(dataset):
64-
# TODO: execute finish process when switching dataset
6561
if dataset is None:
6662
return gr.Dropdown.update(value=None, choices=[])
6763

6864
# create the dataset dir if doesn't exist and download prompt file
6965
dataset_path = str(shark_root) + "/dataset/" + dataset
70-
# TODO: check if metadata.jsonl exists
71-
prompt_gs_path = gs_url + "/" + dataset + "/metadata.jsonl"
7266
if not os.path.exists(dataset_path):
7367
os.mkdir(dataset_path)
74-
os.system(f'gsutil cp "{prompt_gs_path}" "{dataset_path}"/')
7568

7669
# read prompt jsonlines file
7770
prompt_data.clear()
78-
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
79-
for line in reader.iter(type=dict, skip_invalid=True):
80-
prompt_data[line["file_name"]] = (
81-
[line["text"]]
82-
if type(line["text"]) is str
83-
else line["text"]
84-
)
71+
if dataset in ds_w_prompts:
72+
prompt_gs_path = args.gs_url + "/" + dataset + "/metadata.jsonl"
73+
os.system(f'gsutil cp "{prompt_gs_path}" "{dataset_path}"/')
74+
with jsonlines.open(dataset_path + "/metadata.jsonl") as reader:
75+
for line in reader.iter(type=dict, skip_invalid=True):
76+
prompt_data[line["file_name"]] = (
77+
[line["text"]]
78+
if type(line["text"]) is str
79+
else line["text"]
80+
)
8581

8682
return gr.Dropdown.update(choices=images[dataset])
8783

@@ -92,8 +88,7 @@ def display_image(dataset, image_name):
9288
return gr.Image.update(value=None), gr.Dropdown.update(value=None)
9389

9490
# download and load the image
95-
# TODO: remove previous image if change image from dropdown
96-
img_gs_path = gs_url + "/" + dataset + "/" + image_name
91+
img_gs_path = args.gs_url + "/" + dataset + "/" + image_name
9792
img_sub_path = "/".join(image_name.split("/")[:-1])
9893
img_dst_path = (
9994
str(shark_root) + "/dataset/" + dataset + "/" + img_sub_path + "/"
@@ -103,6 +98,8 @@ def display_image(dataset, image_name):
10398
os.system(f'gsutil cp "{img_gs_path}" "{img_dst_path}"')
10499
img = Image.open(img_dst_path + image_name.split("/")[-1])
105100

101+
if image_name not in prompt_data.keys():
102+
prompt_data[image_name] = []
106103
prompt_choices = ["Add new"]
107104
prompt_choices += prompt_data[image_name]
108105
return gr.Image.update(value=img), gr.Dropdown.update(
@@ -144,6 +141,8 @@ def save_prompt(dataset, image_name, prompts, prompt):
144141
# write prompt jsonlines file
145142
with open(prompt_path, "w") as f:
146143
for key, value in prompt_data.items():
144+
if not value:
145+
continue
147146
v = value if len(value) > 1 else value[0]
148147
f.write(json.dumps({"file_name": key, "text": v}))
149148
f.write("\n")
@@ -171,6 +170,8 @@ def delete_prompt(dataset, image_name, prompts):
171170
# write prompt jsonlines file
172171
with open(prompt_path, "w") as f:
173172
for key, value in prompt_data.items():
173+
if not value:
174+
continue
174175
v = value if len(value) > 1 else value[0]
175176
f.write(json.dumps({"file_name": key, "text": v}))
176177
f.write("\n")
@@ -227,7 +228,7 @@ def finish_annotation(dataset):
227228

228229
# upload prompt and remove local data
229230
dataset_path = str(shark_root) + "/dataset/" + dataset
230-
dataset_gs_path = gs_url + "/" + dataset + "/"
231+
dataset_gs_path = args.gs_url + "/" + dataset + "/"
231232
os.system(
232233
f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"'
233234
)
@@ -240,8 +241,8 @@ def finish_annotation(dataset):
240241

241242
if __name__ == "__main__":
242243
shark_web.launch(
243-
share=False,
244+
share=args.share,
244245
inbrowser=True,
245246
server_name="0.0.0.0",
246-
server_port=8080,
247+
server_port=args.server_port,
247248
)

dataset/args.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import argparse
2+
3+
p = argparse.ArgumentParser(
4+
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
5+
)
6+
7+
##############################################################################
8+
### Dataset Annotator flags
9+
##############################################################################
10+
11+
p.add_argument(
12+
"--gs_url",
13+
type=str,
14+
required=True,
15+
help="URL to datasets in GS bucket",
16+
)
17+
18+
p.add_argument(
19+
"--share",
20+
default=False,
21+
action=argparse.BooleanOptionalAction,
22+
help="flag for generating a public URL",
23+
)
24+
25+
p.add_argument(
26+
"--server_port",
27+
type=int,
28+
default=8080,
29+
help="flag for setting server port",
30+
)
31+
32+
##############################################################################
33+
34+
args = p.parse_args()

dataset/utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
def get_datasets(gs_url):
55
datasets = set()
66
images = dict()
7+
ds_w_prompts = []
78

89
storage_client = storage.Client()
910
bucket_name = gs_url.split("/")[2]
@@ -12,12 +13,17 @@ def get_datasets(gs_url):
1213

1314
for blob in blobs:
1415
dataset_name = blob.name.split("/")[1]
16+
if dataset_name == "":
17+
continue
1518
datasets.add(dataset_name)
16-
file_sub_path = "/".join(blob.name.split("/")[2:])
19+
if dataset_name not in images.keys():
20+
images[dataset_name] = []
21+
1722
# check if image or jsonl
23+
file_sub_path = "/".join(blob.name.split("/")[2:])
1824
if "/" in file_sub_path:
19-
if dataset_name not in images.keys():
20-
images[dataset_name] = []
2125
images[dataset_name] += [file_sub_path]
26+
elif "metadata.jsonl" in file_sub_path:
27+
ds_w_prompts.append(dataset_name)
2228

23-
return list(datasets), images
29+
return list(datasets), images, ds_w_prompts

0 commit comments

Comments
 (0)