2
2
import json
3
3
import jsonlines
4
4
import os
5
+ from args import args
5
6
from pathlib import Path
6
7
from PIL import Image
7
8
from utils import get_datasets
8
9
9
10
10
- # TODO: pass gs_url as a command line flag
11
- # see https://cloud.google.com/docs/authentication/provide-credentials-adc to authorize
12
- gs_url = "gs://shark-datasets/portraits"
13
-
14
11
shark_root = Path (__file__ ).parent .parent
15
12
demo_css = shark_root .joinpath ("web/demo.css" ).resolve ()
16
13
nodlogo_loc = shark_root .joinpath (
30
27
elem_id = "top_logo" ,
31
28
).style (width = 150 , height = 100 )
32
29
33
- datasets , images = get_datasets (gs_url )
30
+ datasets , images , ds_w_prompts = get_datasets (args . gs_url )
34
31
prompt_data = dict ()
35
32
36
33
with gr .Row (elem_id = "ui_body" ):
37
- # TODO: add multiselect dataset
34
+ # TODO: add multiselect dataset, there is a gradio version conflict
38
35
dataset = gr .Dropdown (label = "Dataset" , choices = datasets )
39
36
image_name = gr .Dropdown (label = "Image" , choices = [])
40
37
41
- with gr .Row (elem_id = "ui_body" , visible = True ):
38
+ with gr .Row (elem_id = "ui_body" ):
42
39
# TODO: add ability to search image by typing
43
40
with gr .Column (scale = 1 , min_width = 600 ):
44
41
image = gr .Image (type = "filepath" ).style (height = 512 )
61
58
finish = gr .Button ("Finish" )
62
59
63
60
def filter_datasets (dataset ):
64
- # TODO: execute finish process when switching dataset
65
61
if dataset is None :
66
62
return gr .Dropdown .update (value = None , choices = [])
67
63
68
64
# create the dataset dir if doesn't exist and download prompt file
69
65
dataset_path = str (shark_root ) + "/dataset/" + dataset
70
- # TODO: check if metadata.jsonl exists
71
- prompt_gs_path = gs_url + "/" + dataset + "/metadata.jsonl"
72
66
if not os .path .exists (dataset_path ):
73
67
os .mkdir (dataset_path )
74
- os .system (f'gsutil cp "{ prompt_gs_path } " "{ dataset_path } "/' )
75
68
76
69
# read prompt jsonlines file
77
70
prompt_data .clear ()
78
- with jsonlines .open (dataset_path + "/metadata.jsonl" ) as reader :
79
- for line in reader .iter (type = dict , skip_invalid = True ):
80
- prompt_data [line ["file_name" ]] = (
81
- [line ["text" ]]
82
- if type (line ["text" ]) is str
83
- else line ["text" ]
84
- )
71
+ if dataset in ds_w_prompts :
72
+ prompt_gs_path = args .gs_url + "/" + dataset + "/metadata.jsonl"
73
+ os .system (f'gsutil cp "{ prompt_gs_path } " "{ dataset_path } "/' )
74
+ with jsonlines .open (dataset_path + "/metadata.jsonl" ) as reader :
75
+ for line in reader .iter (type = dict , skip_invalid = True ):
76
+ prompt_data [line ["file_name" ]] = (
77
+ [line ["text" ]]
78
+ if type (line ["text" ]) is str
79
+ else line ["text" ]
80
+ )
85
81
86
82
return gr .Dropdown .update (choices = images [dataset ])
87
83
@@ -92,8 +88,7 @@ def display_image(dataset, image_name):
92
88
return gr .Image .update (value = None ), gr .Dropdown .update (value = None )
93
89
94
90
# download and load the image
95
- # TODO: remove previous image if change image from dropdown
96
- img_gs_path = gs_url + "/" + dataset + "/" + image_name
91
+ img_gs_path = args .gs_url + "/" + dataset + "/" + image_name
97
92
img_sub_path = "/" .join (image_name .split ("/" )[:- 1 ])
98
93
img_dst_path = (
99
94
str (shark_root ) + "/dataset/" + dataset + "/" + img_sub_path + "/"
@@ -103,6 +98,8 @@ def display_image(dataset, image_name):
103
98
os .system (f'gsutil cp "{ img_gs_path } " "{ img_dst_path } "' )
104
99
img = Image .open (img_dst_path + image_name .split ("/" )[- 1 ])
105
100
101
+ if image_name not in prompt_data .keys ():
102
+ prompt_data [image_name ] = []
106
103
prompt_choices = ["Add new" ]
107
104
prompt_choices += prompt_data [image_name ]
108
105
return gr .Image .update (value = img ), gr .Dropdown .update (
@@ -144,6 +141,8 @@ def save_prompt(dataset, image_name, prompts, prompt):
144
141
# write prompt jsonlines file
145
142
with open (prompt_path , "w" ) as f :
146
143
for key , value in prompt_data .items ():
144
+ if not value :
145
+ continue
147
146
v = value if len (value ) > 1 else value [0 ]
148
147
f .write (json .dumps ({"file_name" : key , "text" : v }))
149
148
f .write ("\n " )
@@ -171,6 +170,8 @@ def delete_prompt(dataset, image_name, prompts):
171
170
# write prompt jsonlines file
172
171
with open (prompt_path , "w" ) as f :
173
172
for key , value in prompt_data .items ():
173
+ if not value :
174
+ continue
174
175
v = value if len (value ) > 1 else value [0 ]
175
176
f .write (json .dumps ({"file_name" : key , "text" : v }))
176
177
f .write ("\n " )
@@ -227,7 +228,7 @@ def finish_annotation(dataset):
227
228
228
229
# upload prompt and remove local data
229
230
dataset_path = str (shark_root ) + "/dataset/" + dataset
230
- dataset_gs_path = gs_url + "/" + dataset + "/"
231
+ dataset_gs_path = args . gs_url + "/" + dataset + "/"
231
232
os .system (
232
233
f'gsutil cp "{ dataset_path } /metadata.jsonl" "{ dataset_gs_path } "'
233
234
)
@@ -240,8 +241,8 @@ def finish_annotation(dataset):
240
241
241
242
if __name__ == "__main__" :
242
243
shark_web .launch (
243
- share = False ,
244
+ share = args . share ,
244
245
inbrowser = True ,
245
246
server_name = "0.0.0.0" ,
246
- server_port = 8080 ,
247
+ server_port = args . server_port ,
247
248
)
0 commit comments