T2I-Adapters implementation does not support all official adapter models #6275

Closed
vladmandic opened this issue Dec 21, 2023 · 7 comments
Labels
bug Something isn't working

Comments

@vladmandic
Contributor

Describe the bug

While testing adapters for SD15, the canny and sketch models were found to be incompatible with the current T2IAdapter implementation in diffusers.

Looking at the models themselves, it seems those models actually have a different configuration.
All working models have the same shape:

T2IAdapter(
  (adapter): FullAdapter(
    (unshuffle): PixelUnshuffle(downscale_factor=8)
    (conv_in): Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

while the canny and sketch models have a different shape:

T2IAdapter(
  (adapter): FullAdapter(
    (unshuffle): PixelUnshuffle(downscale_factor=8)
    (conv_in): Conv2d(64, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
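
The conv_in input width follows directly from PixelUnshuffle(downscale_factor=8), which folds each 8x8 spatial block into channels: a 3-channel RGB condition gives 3 * 64 = 192 channels, while a 1-channel grayscale condition gives 64. A minimal sketch illustrating this:

import torch

# PixelUnshuffle(8) folds each 8x8 spatial block into channels:
# (C, H, W) -> (C * 64, H / 8, W / 8)
unshuffle = torch.nn.PixelUnshuffle(downscale_factor=8)

rgb = torch.randn(1, 3, 512, 512)   # 3-channel (RGB) condition
gray = torch.randn(1, 1, 512, 512)  # 1-channel (grayscale) condition

print(unshuffle(rgb).shape)   # torch.Size([1, 192, 64, 64]) -> matches conv_in(192, 320)
print(unshuffle(gray).shape)  # torch.Size([1, 64, 64, 64])  -> matches conv_in(64, 320)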

Reproduction

import torch
import diffusers
from PIL import Image
from rich import print # pylint: disable=redefined-builtin

model_id = "models/stable-diffusion/lyriel-v16.safetensors"
print(f'torch=={torch.__version__} diffusers=={diffusers.__version__}')

adapters = [
    'TencentARC/t2iadapter_seg_sd14v1',
    'TencentARC/t2iadapter_zoedepth_sd15v1',
    'TencentARC/t2iadapter_openpose_sd14v1',
    'TencentARC/t2iadapter_keypose_sd14v1',
    'TencentARC/t2iadapter_color_sd14v1',
    'TencentARC/t2iadapter_depth_sd14v1',
    'TencentARC/t2iadapter_depth_sd15v2',
    'TencentARC/t2iadapter_canny_sd14v1',
    'TencentARC/t2iadapter_canny_sd15v2',
    'TencentARC/t2iadapter_sketch_sd14v1',
    'TencentARC/t2iadapter_sketch_sd15v2',
]
seeds = [42]

print(f'loading: {model_id}')
base = diffusers.StableDiffusionPipeline.from_single_file(model_id, variant="fp16", cache_dir='/mnt/d/Models/Diffusers').to('cuda')
image = Image.new('RGB', (512, 512), color='red')
print('loaded')

for adapter_id in adapters:
    adapter = diffusers.T2IAdapter.from_pretrained(adapter_id, cache_dir='/mnt/d/Models/control/adapters')
    pipe = diffusers.StableDiffusionAdapterPipeline(
        vae=base.vae,
        text_encoder=base.text_encoder,
        tokenizer=base.tokenizer,
        unet=base.unet,
        scheduler=base.scheduler,
        requires_safety_checker=False,
        safety_checker=None,
        feature_extractor=None,
        adapter=adapter,
    ).to('cuda')
    try:
        output = pipe(
            prompt='test',
            num_inference_steps=5,
            image=image,
            output_type='pil',
        )
        print(f'adapter: {adapter_id} {"ok" if output.images is not None else "failed"}')
    except Exception as e:
        print(f'adapter error: {adapter_id} {e}')

Logs

adapter: TencentARC/t2iadapter_seg_sd14v1 ok
adapter: TencentARC/t2iadapter_zoedepth_sd15v1 ok
adapter: TencentARC/t2iadapter_openpose_sd14v1 ok
adapter: TencentARC/t2iadapter_keypose_sd14v1 ok
adapter: TencentARC/t2iadapter_color_sd14v1 ok
adapter: TencentARC/t2iadapter_depth_sd14v1 ok
adapter: TencentARC/t2iadapter_depth_sd15v2 ok
adapter error: TencentARC/t2iadapter_canny_sd14v1 Given groups=1, weight of size [320, 64, 3, 3], expected input[1, 192, 64, 64] to have 64 channels, but got 192 channels instead
adapter error: TencentARC/t2iadapter_canny_sd15v2 Given groups=1, weight of size [320, 64, 3, 3], expected input[1, 192, 64, 64] to have 64 channels, but got 192 channels instead
adapter error: TencentARC/t2iadapter_sketch_sd14v1 Given groups=1, weight of size [320, 64, 3, 3], expected input[1, 192, 64, 64] to have 64 channels, but got 192 channels instead
adapter error: TencentARC/t2iadapter_sketch_sd15v2 Given groups=1, weight of size [320, 64, 3, 3], expected input[1, 192, 64, 64] to have 64 channels, but got 192 channels instead

System Info

torch==2.1.2+cu121 diffusers==0.25.0.dev0

Who can help?

@sayakpaul @yiyixuxu @DN6 @patrickvonplaten

@vladmandic vladmandic added the bug Something isn't working label Dec 21, 2023
@sayakpaul
Member

Could you confirm if an earlier version of diffusers was working fine?

Cc: @williamberman

@vladmandic
Contributor Author

Nope, I tried a few older versions of diffusers and it's failing as well.

But #3932 converted the original models to diffusers and added explicit tests for all of those models, so how did those tests pass?
Since they are included in the official tests, I'd assume they are supported.
(https://github.com/huggingface/diffusers/blob/main/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py)

@sayakpaul
Member

@williamberman could you check here?

@sayakpaul
Member

I am able to run "TencentARC/t2iadapter_canny_sd14v1" (I didn't check the others) following the code provided in https://huggingface.co/TencentARC/t2iadapter_canny_sd14v1:

Prepare the condition image:

from diffusers.utils import load_image
from PIL import Image
import numpy as np
import cv2

image = load_image("https://huggingface.co/TencentARC/t2iadapter_canny_sd14v1/resolve/main/images/canny_input.png")
image = np.array(image)

low_threshold = 100
high_threshold = 200

image = cv2.Canny(image, low_threshold, high_threshold)
image = Image.fromarray(image)

Load pipeline:

import torch
from diffusers import StableDiffusionAdapterPipeline, T2IAdapter

adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_canny_sd14v1").to(torch.float16)
pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
    adapter=adapter,
).to("cuda")

Fire inference:

out_image = pipe(
    "a rabbit wearing glasses",
    image=image,
    generator=torch.manual_seed(0),
).images[0]

I would suggest closely following the official code from the model cards, retrying, and then coming back here.

@vladmandic
Contributor Author

vladmandic commented Dec 25, 2023

OK, I've narrowed it down: the t2iadapter canny and sketch models are extremely sensitive to input and will raise an exception as I've described if the input image is not correctly preprocessed.

Both your example and mine work (there is nothing wrong with how the pipeline is created) IF:

image = load_image("https://huggingface.co/TencentARC/t2iadapter_canny_sd14v1/resolve/main/images/canny_input.png")
image = np.array(image)
image = cv2.Canny(image, threshold1=100, threshold2=200)
image = Image.fromarray(image)

However, both your example and mine will FAIL IF:

image = Image.new('RGB', (512, 512), color='red')

So what's the difference? One thing that comes to mind is that cv2.Canny returns a BGR image, not RGB (all cv2 operations are BGR). Using an RGB image as input to the t2iadapter canny or sketch models causes problems,
and PIL images SHOULD always be RGB, so the fact that t2iadapter only works with BGR images is a problem.

To confirm, just add

image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

to your example and it will FAIL.

No other model exhibits this issue (only canny and sketch), and it is a problem for production: how can we ensure that the image is exactly what the model expects? It's the user's choice what to use as input, and using the wrong input causes an exception.
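
One way to check up front what a given adapter expects is to read in_channels from its config; a minimal sketch (in_channels is a T2IAdapter config field in diffusers):

import diffusers

adapter = diffusers.T2IAdapter.from_pretrained('TencentARC/t2iadapter_canny_sd14v1')
# Number of channels the condition image must have before PixelUnshuffle:
# 1 for the canny/sketch adapters, 3 for the adapters that take RGB input.
print(adapter.config.in_channels)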

@sayakpaul
Member

The problem is the following.

When we do:

image = load_image("https://huggingface.co/TencentARC/t2iadapter_canny_sd14v1/resolve/main/images/canny_input.png")
image = np.array(image)
print(image.shape)

image = cv2.Canny(image, threshold1=100, threshold2=200)
image = Image.fromarray(image)
print(np.array(image).shape)

We get:

(651, 628, 3)
(651, 628)

It's a grayscale image.

But with:

image = Image.new('RGB', (512, 512), color='red')
print(np.array(image).shape)

It evaluates to:

(512, 512, 3)

So, if you do:

  image = Image.new('RGB', (512, 512), color='red')
+ image = image.convert('L')

It works as expected.
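
A quick sketch confirming the shape after the conversion:

import numpy as np
from PIL import Image

image = Image.new('RGB', (512, 512), color='red').convert('L')
print(np.array(image).shape)  # (512, 512): single channel, so PixelUnshuffle(8) produces the expected 64 channels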

@vladmandic
Contributor Author

Right. I was checking

image = cv2.Canny(image, threshold1=100, threshold2=200)
image = Image.fromarray(image)
print(image)

and it showed RGB; I guess that's a PIL reporting problem.
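
For reference, image.mode is the unambiguous check here (a quick sketch):

import numpy as np
from PIL import Image

canny_like = Image.fromarray(np.zeros((512, 512), dtype=np.uint8))
print(canny_like.mode)  # 'L': a 2D uint8 array becomes a single-channel image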

It's still a bit messy, since all other models (including the controlnet canny and sketch models) work with RGB; only the t2iadapter canny and sketch models work with L. But OK, that's more of an enhancement request than an actual issue, so I'm closing this as works-as-designed.
