Skip to content

Commit 0dcbd16

Browse files
committed
fix(app): async fixes for download, train_dreambooth
1 parent d1cd39e commit 0dcbd16

File tree

3 files changed

+41
-30
lines changed

3 files changed

+41
-30
lines changed

api/app.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,17 @@ def init():
6262
global dummy_safety_checker
6363
global always_normalize_model_id
6464

65-
send(
66-
"init",
67-
"start",
68-
{
69-
"device": device_name,
70-
"hostname": os.getenv("HOSTNAME"),
71-
"model_id": MODEL_ID,
72-
"diffusers": __version__,
73-
},
65+
asyncio.run(
66+
send(
67+
"init",
68+
"start",
69+
{
70+
"device": device_name,
71+
"hostname": os.getenv("HOSTNAME"),
72+
"model_id": MODEL_ID,
73+
"diffusers": __version__,
74+
},
75+
)
7476
)
7577

7678
dummy_safety_checker = DummySafetyChecker()
@@ -96,7 +98,7 @@ def init():
9698
else:
9799
model = None
98100

99-
send("init", "done")
101+
asyncio.run(send("init", "done"))
100102

101103

102104
def decodeBase64Image(imageStr: str, name: str) -> PIL.Image:
@@ -213,7 +215,7 @@ def sendStatus():
213215
# }
214216
# }
215217
normalized_model_id = hf_model_id or model_id
216-
download_model(
218+
await download_model(
217219
model_id=model_id,
218220
model_url=model_url,
219221
model_revision=model_revision,
@@ -426,7 +428,8 @@ def sendStatus():
426428
normalized_model_id = model_dir
427429

428430
torch.set_grad_enabled(True)
429-
result = result | TrainDreamBooth(
431+
result = result | await asyncio.to_thread(
432+
TrainDreamBooth,
430433
normalized_model_id,
431434
pipeline,
432435
model_inputs,

api/download.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from convert_to_diffusers import main as convert_to_diffusers
1313
from download_checkpoint import main as download_checkpoint
1414
from status import status
15+
import asyncio
1516

1617
USE_DREAMBOOTH = os.environ.get("USE_DREAMBOOTH")
1718
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
@@ -23,11 +24,11 @@
2324

2425

2526
# i.e. don't run during build
26-
def send(type: str, status: str, payload: dict = {}, send_opts: dict = {}):
27+
async def send(type: str, status: str, payload: dict = {}, send_opts: dict = {}):
2728
if RUNTIME_DOWNLOADS:
2829
from send import send as _send
2930

30-
_send(type, status, payload, send_opts)
31+
await _send(type, status, payload, send_opts)
3132

3233

3334
def normalize_model_id(model_id: str, model_revision):
@@ -37,7 +38,7 @@ def normalize_model_id(model_id: str, model_revision):
3738
return normalized_model_id
3839

3940

40-
def download_model(
41+
async def download_model(
4142
model_url=None,
4243
model_id=None,
4344
model_revision=None,
@@ -109,14 +110,14 @@ def download_model(
109110
# This would be quicker to just model.to(device) afterwards, but
110111
# this conveniently logs all the timings (and doesn't happen often)
111112
print("download")
112-
send("download", "start", {}, send_opts)
113+
await send("download", "start", {}, send_opts)
113114
model = loadModel(
114115
hf_model_id,
115116
False,
116117
precision=model_precision,
117118
revision=model_revision,
118119
) # download
119-
send("download", "done", {}, send_opts)
120+
await send("download", "done", {}, send_opts)
120121

121122
print("load")
122123
model = loadModel(
@@ -127,19 +128,19 @@ def download_model(
127128
model.save_pretrained(dir, safe_serialization=True)
128129

129130
# This is all duped from train_dreambooth, need to refactor TODO XXX
130-
send("compress", "start", {}, send_opts)
131+
await send("compress", "start", {}, send_opts)
131132
subprocess.run(
132133
f"tar cvf - -C {dir} . | zstd -o {model_file}",
133134
shell=True,
134135
check=True, # TODO, rather don't raise and return an error in JSON
135136
)
136137

137-
send("compress", "done", {}, send_opts)
138+
await send("compress", "done", {}, send_opts)
138139
subprocess.run(["ls", "-l", model_file])
139140

140-
send("upload", "start", {}, send_opts)
141+
await send("upload", "start", {}, send_opts)
141142
upload_result = storage.upload_file(model_file, filename)
142-
send("upload", "done", {}, send_opts)
143+
await send("upload", "done", {}, send_opts)
143144
print(upload_result)
144145
os.remove(model_file)
145146

@@ -185,12 +186,14 @@ def download_model(
185186

186187

187188
if __name__ == "__main__":
188-
download_model(
189-
model_url=os.environ.get("MODEL_URL"),
190-
model_id=os.environ.get("MODEL_ID"),
191-
hf_model_id=os.environ.get("HF_MODEL_ID"),
192-
model_revision=os.environ.get("MODEL_REVISION"),
193-
model_precision=os.environ.get("MODEL_PRECISION"),
194-
checkpoint_url=os.environ.get("CHECKPOINT_URL"),
195-
checkpoint_config_url=os.environ.get("CHECKPOINT_CONFIG_URL"),
189+
asyncio.run(
190+
download_model(
191+
model_url=os.environ.get("MODEL_URL"),
192+
model_id=os.environ.get("MODEL_ID"),
193+
hf_model_id=os.environ.get("HF_MODEL_ID"),
194+
model_revision=os.environ.get("MODEL_REVISION"),
195+
model_precision=os.environ.get("MODEL_PRECISION"),
196+
checkpoint_url=os.environ.get("CHECKPOINT_URL"),
197+
checkpoint_config_url=os.environ.get("CHECKPOINT_CONFIG_URL"),
198+
)
196199
)

api/train_dreambooth.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,22 @@
5858
from transformers import AutoTokenizer, PretrainedConfig
5959

6060
# DDA
61-
from send import send, get_now
61+
from send import send as _send
6262
from utils import Storage
6363
import subprocess
6464
import re
6565
import shutil
66+
import asyncio
6667

6768
# Our original code in docker-diffusers-api:
6869

6970
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
7071

7172

73+
def send(type: str, status: str, payload: dict = {}, send_opts: dict = {}):
74+
asyncio.run((_send(type, status, payload, send_opts)))
75+
76+
7277
def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs, send_opts):
7378
# required inputs: instance_images instance_prompt
7479

0 commit comments

Comments
 (0)