Skip to content

Commit 153764a

Browse files
committed
add prompt option '--f' for filename
1 parent 589c2aa commit 153764a

File tree

2 files changed

+43
-15
lines changed

2 files changed

+43
-15
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
179179

180180
- Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
181181

182+
- Added a prompt option `--f` to `gen_img.py` to specify the file name when saving.
183+
182184
- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。
183185
- optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。
184186
- `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。
@@ -219,6 +221,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します
219221

220222
- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
221223

224+
- `gen_img.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。
222225

223226
### Apr 7, 2024 / 2024-04-07: v0.8.7
224227

gen_img.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,6 +1435,7 @@ class BatchDataBase(NamedTuple):
14351435
clip_prompt: str
14361436
guide_image: Any
14371437
raw_prompt: str
1438+
file_name: Optional[str]
14381439

14391440

14401441
class BatchDataExt(NamedTuple):
@@ -2316,7 +2317,7 @@ def scale_and_round(x):
23162317
# このバッチの情報を取り出す
23172318
(
23182319
return_latents,
2319-
(step_first, _, _, _, init_image, mask_image, _, guide_image, _),
2320+
(step_first, _, _, _, init_image, mask_image, _, guide_image, _, _),
23202321
(
23212322
width,
23222323
height,
@@ -2339,6 +2340,7 @@ def scale_and_round(x):
23392340
prompts = []
23402341
negative_prompts = []
23412342
raw_prompts = []
2343+
filenames = []
23422344
start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
23432345
noises = [
23442346
torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
@@ -2371,14 +2373,15 @@ def scale_and_round(x):
23712373
all_guide_images_are_same = True
23722374
for i, (
23732375
_,
2374-
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt),
2376+
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt, filename),
23752377
_,
23762378
) in enumerate(batch):
23772379
prompts.append(prompt)
23782380
negative_prompts.append(negative_prompt)
23792381
seeds.append(seed)
23802382
clip_prompts.append(clip_prompt)
23812383
raw_prompts.append(raw_prompt)
2384+
filenames.append(filename)
23822385

23832386
if init_image is not None:
23842387
init_images.append(init_image)
@@ -2478,8 +2481,8 @@ def scale_and_round(x):
24782481
# save image
24792482
highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
24802483
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
2481-
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate(
2482-
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts)
2484+
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt, filename) in enumerate(
2485+
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts, filenames)
24832486
):
24842487
if highres_fix:
24852488
seed -= 1 # record original seed
@@ -2505,17 +2508,23 @@ def scale_and_round(x):
25052508
metadata.add_text("crop-top", str(crop_top))
25062509
metadata.add_text("crop-left", str(crop_left))
25072510

2508-
if args.use_original_file_name and init_images is not None:
2509-
if type(init_images) is list:
2510-
fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
2511-
else:
2512-
fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
2513-
elif args.sequential_file_name:
2514-
fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
2511+
if filename is not None:
2512+
fln = filename
25152513
else:
2516-
fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"
2514+
if args.use_original_file_name and init_images is not None:
2515+
if type(init_images) is list:
2516+
fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
2517+
else:
2518+
fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
2519+
elif args.sequential_file_name:
2520+
fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
2521+
else:
2522+
fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"
25172523

2518-
image.save(os.path.join(args.outdir, fln), pnginfo=metadata)
2524+
if fln.endswith(".webp"):
2525+
image.save(os.path.join(args.outdir, fln), pnginfo=metadata, quality=100) # lossy
2526+
else:
2527+
image.save(os.path.join(args.outdir, fln), pnginfo=metadata)
25192528

25202529
if not args.no_preview and not highres_1st and args.interactive:
25212530
try:
@@ -2562,6 +2571,7 @@ def scale_and_round(x):
25622571
# repeat prompt
25632572
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
25642573
raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]
2574+
filename = None
25652575

25662576
if pi == 0 or len(raw_prompts) > 1:
25672577
# parse prompt: if prompt is not changed, skip parsing
@@ -2783,6 +2793,12 @@ def scale_and_round(x):
27832793
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
27842794
continue
27852795

2796+
m = re.match(r"f (.+)", parg, re.IGNORECASE)
2797+
if m: # filename
2798+
filename = m.group(1)
2799+
logger.info(f"filename: {filename}")
2800+
continue
2801+
27862802
except ValueError as ex:
27872803
logger.error(f"Exception in parsing / 解析エラー: {parg}")
27882804
logger.error(f"{ex}")
@@ -2873,7 +2889,16 @@ def scale_and_round(x):
28732889
b1 = BatchData(
28742890
False,
28752891
BatchDataBase(
2876-
global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt
2892+
global_step,
2893+
prompt,
2894+
negative_prompt,
2895+
seed,
2896+
init_image,
2897+
mask_image,
2898+
clip_prompt,
2899+
guide_image,
2900+
raw_prompt,
2901+
filename,
28772902
),
28782903
BatchDataExt(
28792904
width,
@@ -2916,7 +2941,7 @@ def setup_parser() -> argparse.ArgumentParser:
29162941
parser = argparse.ArgumentParser()
29172942

29182943
add_logging_arguments(parser)
2919-
2944+
29202945
parser.add_argument(
29212946
"--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む"
29222947
)

0 commit comments

Comments
 (0)