Skip to content

Commit 7e6886f

Browse files
controlnet training resize inputs to multiple of 8 (#3135)
controlnet training: center crop input images to a multiple of 8. The pipeline code resizes inputs to multiples of 8. Not doing this resizing in the training script causes the encoded image to have different height/width dimensions than the encoded conditioning image (which uses a separate encoder that's part of the controlnet model). We resize and center crop the inputs to make sure they're the same size (as well as all other images in the batch). We also check that the initial resolution is a multiple of 8.
1 parent a4c91be commit 7e6886f

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

examples/controlnet/train_controlnet.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -525,6 +525,11 @@ def parse_args(input_args=None):
525525
" or the same number of `--validation_prompt`s and `--validation_image`s"
526526
)
527527

528+
if args.resolution % 8 != 0:
529+
raise ValueError(
530+
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
531+
)
532+
528533
return args
529534

530535

@@ -607,6 +612,7 @@ def tokenize_captions(examples, is_train=True):
607612
image_transforms = transforms.Compose(
608613
[
609614
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
615+
transforms.CenterCrop(args.resolution),
610616
transforms.ToTensor(),
611617
transforms.Normalize([0.5], [0.5]),
612618
]
@@ -615,6 +621,7 @@ def tokenize_captions(examples, is_train=True):
615621
conditioning_image_transforms = transforms.Compose(
616622
[
617623
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
624+
transforms.CenterCrop(args.resolution),
618625
transforms.ToTensor(),
619626
]
620627
)

0 commit comments

Comments
 (0)