12
12
from convert_to_diffusers import main as convert_to_diffusers
13
13
from download_checkpoint import main as download_checkpoint
14
14
from status import status
15
+ import asyncio
15
16
16
17
USE_DREAMBOOTH = os .environ .get ("USE_DREAMBOOTH" )
17
18
HF_AUTH_TOKEN = os .environ .get ("HF_AUTH_TOKEN" )
23
24
24
25
25
26
# i.e. don't run during build
26
- def send (type : str , status : str , payload : dict = {}, send_opts : dict = {}):
27
+ async def send (type : str , status : str , payload : dict = {}, send_opts : dict = {}):
27
28
if RUNTIME_DOWNLOADS :
28
29
from send import send as _send
29
30
30
- _send (type , status , payload , send_opts )
31
+ await _send (type , status , payload , send_opts )
31
32
32
33
33
34
def normalize_model_id (model_id : str , model_revision ):
@@ -37,7 +38,7 @@ def normalize_model_id(model_id: str, model_revision):
37
38
return normalized_model_id
38
39
39
40
40
- def download_model (
41
+ async def download_model (
41
42
model_url = None ,
42
43
model_id = None ,
43
44
model_revision = None ,
@@ -109,14 +110,14 @@ def download_model(
109
110
# This would be quicker to just model.to(device) afterwards, but
110
111
# this conveniently logs all the timings (and doesn't happen often)
111
112
print ("download" )
112
- send ("download" , "start" , {}, send_opts )
113
+ await send ("download" , "start" , {}, send_opts )
113
114
model = loadModel (
114
115
hf_model_id ,
115
116
False ,
116
117
precision = model_precision ,
117
118
revision = model_revision ,
118
119
) # download
119
- send ("download" , "done" , {}, send_opts )
120
+ await send ("download" , "done" , {}, send_opts )
120
121
121
122
print ("load" )
122
123
model = loadModel (
@@ -127,19 +128,19 @@ def download_model(
127
128
model .save_pretrained (dir , safe_serialization = True )
128
129
129
130
# This is all duped from train_dreambooth, need to refactor TODO XXX
130
- send ("compress" , "start" , {}, send_opts )
131
+ await send ("compress" , "start" , {}, send_opts )
131
132
subprocess .run (
132
133
f"tar cvf - -C { dir } . | zstd -o { model_file } " ,
133
134
shell = True ,
134
135
check = True , # TODO, rather don't raise and return an error in JSON
135
136
)
136
137
137
- send ("compress" , "done" , {}, send_opts )
138
+ await send ("compress" , "done" , {}, send_opts )
138
139
subprocess .run (["ls" , "-l" , model_file ])
139
140
140
- send ("upload" , "start" , {}, send_opts )
141
+ await send ("upload" , "start" , {}, send_opts )
141
142
upload_result = storage .upload_file (model_file , filename )
142
- send ("upload" , "done" , {}, send_opts )
143
+ await send ("upload" , "done" , {}, send_opts )
143
144
print (upload_result )
144
145
os .remove (model_file )
145
146
@@ -185,12 +186,14 @@ def download_model(
185
186
186
187
187
188
if __name__ == "__main__" :
188
- download_model (
189
- model_url = os .environ .get ("MODEL_URL" ),
190
- model_id = os .environ .get ("MODEL_ID" ),
191
- hf_model_id = os .environ .get ("HF_MODEL_ID" ),
192
- model_revision = os .environ .get ("MODEL_REVISION" ),
193
- model_precision = os .environ .get ("MODEL_PRECISION" ),
194
- checkpoint_url = os .environ .get ("CHECKPOINT_URL" ),
195
- checkpoint_config_url = os .environ .get ("CHECKPOINT_CONFIG_URL" ),
189
+ asyncio .run (
190
+ download_model (
191
+ model_url = os .environ .get ("MODEL_URL" ),
192
+ model_id = os .environ .get ("MODEL_ID" ),
193
+ hf_model_id = os .environ .get ("HF_MODEL_ID" ),
194
+ model_revision = os .environ .get ("MODEL_REVISION" ),
195
+ model_precision = os .environ .get ("MODEL_PRECISION" ),
196
+ checkpoint_url = os .environ .get ("CHECKPOINT_URL" ),
197
+ checkpoint_config_url = os .environ .get ("CHECKPOINT_CONFIG_URL" ),
198
+ )
196
199
)
0 commit comments