Skip to content

Commit c680249

Browse files
sayakpaulorpatashnik
authored andcommitted
[ControlNet SDXL training] fixes in the training script (huggingface#4223)
* fix: huggingface#4206 * add: sdxl controlnet training smoketest. * remove unnecessary token inits. * add: licensing to model card. * include SDXL licensing in the model card and make public visibility default * debugging * debugging * disable local file download. * fix: training test. * fix: ckpt prefix.
1 parent 206d1bb commit c680249

File tree

3 files changed

+29
-8
lines changed

3 files changed

+29
-8
lines changed

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
124124
for _ in range(args.num_validation_images):
125125
with torch.autocast("cuda"):
126126
image = pipeline(
127-
validation_prompt, validation_image, num_inference_steps=20, generator=generator
127+
prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
128128
).images[0]
129129
images.append(image)
130130

@@ -178,7 +178,7 @@ def import_model_class_from_model_name_or_path(
178178
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
179179
):
180180
text_encoder_config = PretrainedConfig.from_pretrained(
181-
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
181+
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
182182
)
183183
model_class = text_encoder_config.architectures[0]
184184

@@ -226,6 +226,12 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
226226
227227
These are controlnet weights trained on {base_model} with new type of conditioning.
228228
{img_str}
229+
"""
230+
model_card += """
231+
232+
## License
233+
234+
[SDXL 0.9 Research License](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9/blob/main/LICENSE.md)
229235
"""
230236
with open(os.path.join(repo_folder, "README.md"), "w") as f:
231237
f.write(yaml + model_card)
@@ -798,10 +804,7 @@ def main(args):
798804

799805
if args.push_to_hub:
800806
repo_id = create_repo(
801-
repo_id=args.hub_model_id or Path(args.output_dir).name,
802-
exist_ok=True,
803-
token=args.hub_token,
804-
private=True,
807+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
805808
).repo_id
806809

807810
# Load the tokenizers
@@ -839,7 +842,7 @@ def main(args):
839842
revision=args.revision,
840843
)
841844
unet = UNet2DConditionModel.from_pretrained(
842-
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_auth_token=True
845+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
843846
)
844847

845848
if args.controlnet_model_name_or_path:

examples/test_examples.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,25 @@ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_check
12961296
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
12971297
)
12981298

1299+
def test_controlnet_sdxl(self):
1300+
with tempfile.TemporaryDirectory() as tmpdir:
1301+
test_args = f"""
1302+
examples/controlnet/train_controlnet_sdxl.py
1303+
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
1304+
--dataset_name=hf-internal-testing/fill10
1305+
--output_dir={tmpdir}
1306+
--resolution=64
1307+
--train_batch_size=1
1308+
--gradient_accumulation_steps=1
1309+
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
1310+
--max_train_steps=9
1311+
--checkpointing_steps=2
1312+
""".split()
1313+
1314+
run_command(self._launch_args + test_args)
1315+
1316+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.bin")))
1317+
12991318
def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
13001319
with tempfile.TemporaryDirectory() as tmpdir:
13011320
test_args = f"""

src/diffusers/models/controlnet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,6 @@ def forward(
751751
sample = self.conv_in(sample)
752752

753753
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
754-
755754
sample = sample + controlnet_cond
756755

757756
# 3. down

0 commit comments

Comments
 (0)