diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index c0b52291fc9b..d52e610ca52d 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -525,6 +525,11 @@ def parse_args(input_args=None): " or the same number of `--validation_prompt`s and `--validation_image`s" ) + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." + ) + return args @@ -607,6 +612,7 @@ def tokenize_captions(examples, is_train=True): image_transforms = transforms.Compose( [ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -615,6 +621,7 @@ def tokenize_captions(examples, is_train=True): conditioning_image_transforms = transforms.Compose( [ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), transforms.ToTensor(), ] )