Skip to content

Commit dfe15f0

Browse files
committed
controlnet training resize inputs to multiple of 8
1 parent cb63feb commit dfe15f0

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

examples/controlnet/train_controlnet.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import torch
2626
import torch.nn.functional as F
2727
import torch.utils.checkpoint
28+
import torchvision.transforms.functional as TF
2829
import transformers
2930
from accelerate import Accelerator
3031
from accelerate.logging import get_logger
@@ -604,9 +605,22 @@ def tokenize_captions(examples, is_train=True):
604605
)
605606
return inputs.input_ids
606607

608+
def resize_multiple_8(image):
609+
height = image.height
610+
width = image.width
611+
612+
# round down to nearest multiple of 8
613+
height = (height // 8) * 8
614+
width = (width // 8) * 8
615+
616+
image = TF.resize(image, (height, width), transforms.InterpolationMode.BILINEAR)
617+
618+
return image
619+
607620
image_transforms = transforms.Compose(
608621
[
609622
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
623+
resize_multiple_8,
610624
transforms.ToTensor(),
611625
transforms.Normalize([0.5], [0.5]),
612626
]
@@ -615,6 +629,7 @@ def tokenize_captions(examples, is_train=True):
615629
conditioning_image_transforms = transforms.Compose(
616630
[
617631
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
632+
resize_multiple_8,
618633
transforms.ToTensor(),
619634
]
620635
)

0 commit comments

Comments
 (0)