The cause of this issue is very large models, more than 60GB in size: diffusers tries to put all of them into GPU VRAM.
There are a couple of ways to fix it.
The first is to add this line of code to your script:
pipe.enable_sequential_cpu_offload()
You will now be able to run your scripts, but it will be fairly slow.
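As a rough guide to choosing an offloading mode, here is a minimal sketch that picks one from memory sizes. It assumes the usual diffusers behavior: `enable_model_cpu_offload()` moves one whole component (text encoder, transformer, VAE) to the GPU at a time, while `enable_sequential_cpu_offload()` moves individual submodules, which needs the least VRAM and is the slowest. The function name `choose_offload`, the component sizes, and the headroom value are hypothetical and only for illustration.

```python
def choose_offload(component_sizes_gb, vram_gb, headroom_gb=2.0):
    """Return the name of the pipeline method to call, or None if no offload is needed.

    component_sizes_gb: dict of component name -> weight size in GB (caller-supplied estimates).
    vram_gb: total GPU memory in GB.
    headroom_gb: hypothetical reserve for activations and CUDA overhead.
    """
    usable = vram_gb - headroom_gb
    # Everything fits on the GPU at once: no offloading required.
    if sum(component_sizes_gb.values()) <= usable:
        return None
    # The largest single component fits: whole-component offload is enough.
    if max(component_sizes_gb.values()) <= usable:
        return "enable_model_cpu_offload"
    # Even one component is too large: fall back to submodule-level offload.
    return "enable_sequential_cpu_offload"


# Illustrative sizes only, not measured values.
sizes = {"text_encoder_2": 9.5, "transformer": 24.0, "vae": 0.2}
for vram in (80, 32, 16):
    print(vram, choose_offload(sizes, vram_gb=vram))
```

With a real pipeline, the returned name could be applied with `getattr(pipe, name)()` when it is not `None`.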
The second way is to quantize your models. Here are code examples for different use cases with different models:
# Example 1: generate images with FLUX.1-dev using a pre-quantized NF4 transformer
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
nf4_id = "sayakpaul/flux.1-dev-nf4-with-bnb-integration"

# Load the 4-bit (NF4) transformer; this requires bitsandbytes to be installed
model_nf4 = FluxTransformer2DModel.from_pretrained(nf4_id, torch_dtype=torch.bfloat16)
print(model_nf4.dtype)
print(model_nf4.config.quantization_config)

# Build the pipeline around the quantized transformer and offload idle components to CPU
pipe = FluxPipeline.from_pretrained(model_id, transformer=model_nf4, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "A mystic cat with a sign that says hello world!"
image = pipe(prompt, guidance_scale=3.5, num_inference_steps=50, generator=torch.manual_seed(0)).images[0]
image.save("flux-nf4-dev-loaded.png")
# Example 2: upscale images with jasperai/Flux.1-dev-Controlnet-Upscaler
import torch
from diffusers.utils import load_image
from diffusers import FluxControlNetModel, BitsAndBytesConfig, FluxTransformer2DModel
from diffusers.pipelines import FluxControlNetPipeline

# Quantize the ControlNet to 4-bit NF4 while loading it
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
controlnet = FluxControlNetModel.from_pretrained(
    "jasperai/Flux.1-dev-Controlnet-Upscaler",
    quantization_config=nf4_config,
)

# Reuse the pre-quantized NF4 transformer from the first example
model_id = "black-forest-labs/FLUX.1-dev"
nf4_id = "sayakpaul/flux.1-dev-nf4-with-bnb-integration"
model_nf4 = FluxTransformer2DModel.from_pretrained(nf4_id, torch_dtype=torch.float16)

pipe = FluxControlNetPipeline.from_pretrained(
    model_id,
    transformer=model_nf4,
    torch_dtype=torch.float16,
    controlnet=controlnet,
)
pipe.enable_model_cpu_offload()

# Path to the input image
control_image = load_image("image.jpg")

image = pipe(
    prompt="",
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
    num_inference_steps=28,
    guidance_scale=3.5,
    height=control_image.size[1],  # PIL size is (width, height)
    width=control_image.size[0],
).images[0]
image.save("upscaled_img_quanted.png")
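To see why quantization helps, here is a back-of-the-envelope estimate of weight storage at different precisions. This is only a sketch: the parameter count of 12 billion is an assumption (roughly the size of the FLUX.1-dev transformer), the helper `weight_memory_gb` is hypothetical, and the estimate ignores quantization metadata, activations, and the other pipeline components.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate weight storage in GB (10**9 bytes), ignoring quantization metadata."""
    return num_params * bits_per_param / 8 / 1e9


params = 12e9  # assumed parameter count, for illustration only
for name, bits in [("float32", 32), ("bfloat16", 16), ("int8", 8), ("nf4", 4)]:
    print(f"{name:>8}: {weight_memory_gb(params, bits):.1f} GB")
```

Under these assumptions, going from bfloat16 to NF4 cuts the weight footprint of that one component to a quarter.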
Thanks to @sayakpaul for these solutions.
budaLi