
Commit 043df07

Run ruff, setup initial text to image node

1 parent f4f5c46 commit 043df07

File tree

15 files changed, +330 -155 lines

invokeai/app/invocations/flux_text_encoder.py

Lines changed: 1 addition & 6 deletions
@@ -1,18 +1,13 @@
 import torch
-
-
-from einops import repeat
-from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
 from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
-from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.flux.modules.conditioner import HFEncoder
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo


 @invocation(
invokeai/app/invocations/flux_text_to_image.py

Lines changed: 78 additions & 69 deletions
@@ -1,11 +1,6 @@
-from typing import Literal
-
 import torch
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
-from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
-from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
+from einops import rearrange, repeat
 from PIL import Image
-from transformers.models.auto import AutoModelForTextEncoding

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import (
@@ -19,20 +14,11 @@
 from invokeai.app.invocations.model import TransformerField, VAEField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
-from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.modules.autoencoder import AutoEncoder
+from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
-
-TFluxModelKeys = Literal["flux-schnell"]
-FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
-
-
-class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
-    base_class = FluxTransformer2DModel
-
-
-class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
-    auto_class = AutoModelForTextEncoding
+from invokeai.backend.util.devices import TorchDevice


 @invocation(
@@ -75,7 +61,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         assert isinstance(flux_conditioning, FLUXConditioningInfo)

         latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
-        image = self._run_vae_decoding(context, latents)
+        image = self._run_vae_decoding(context, flux_ae_path, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)

@@ -86,42 +72,79 @@ def _run_diffusion(
         t5_embeddings: torch.Tensor,
     ):
         transformer_info = context.models.load(self.transformer.transformer)
+        inference_dtype = TorchDevice.choose_torch_dtype()
+
+        # Prepare input noise.
+        # TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
+        # CPU RNG?
+        x = get_noise(
+            num_samples=1,
+            height=self.height,
+            width=self.width,
+            device=TorchDevice.choose_torch_device(),
+            dtype=inference_dtype,
+            seed=self.seed,
+        )
+
+        img, img_ids = self._prepare_latent_img_patches(x)
+
+        # HACK(ryand): Find a better way to determine if this is a schnell model or not.
+        is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else ""
+        timesteps = get_schedule(
+            num_steps=self.num_steps,
+            image_seq_len=img.shape[1],
+            shift=not is_schnell,
+        )
+
+        bs, t5_seq_len, _ = t5_embeddings.shape
+        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

         # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
         # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
         # if the cache is not empty.
-        # context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
+        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

         with transformer_info as transformer:
-            assert isinstance(transformer, FluxTransformer2DModel)
-
-            flux_pipeline_with_transformer = FluxPipeline(
-                scheduler=scheduler,
-                vae=None,
-                text_encoder=None,
-                tokenizer=None,
-                text_encoder_2=None,
-                tokenizer_2=None,
-                transformer=transformer,
+            assert isinstance(transformer, Flux)
+
+            x = denoise(
+                model=transformer,
+                img=img,
+                img_ids=img_ids,
+                txt=t5_embeddings,
+                txt_ids=txt_ids,
+                vec=clip_embeddings,
+                timesteps=timesteps,
+                guidance=self.guidance,
             )

-            t5_embeddings = t5_embeddings.to(dtype=transformer.dtype)
-            clip_embeddings = clip_embeddings.to(dtype=transformer.dtype)
+        x = unpack(x.float(), self.height, self.width)
+
+        return x

-            latents = flux_pipeline_with_transformer(
-                height=self.height,
-                width=self.width,
-                num_inference_steps=self.num_steps,
-                guidance_scale=self.guidance,
-                generator=torch.Generator().manual_seed(self.seed),
-                prompt_embeds=t5_embeddings,
-                pooled_prompt_embeds=clip_embeddings,
-                output_type="latent",
-                return_dict=False,
-            )[0]
+    def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Convert an input image in latent space to patches for diffusion.

-            assert isinstance(latents, torch.Tensor)
-            return latents
+        This implementation was extracted from:
+        https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
+
+        Returns:
+            tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
+        """
+        bs, c, h, w = latent_img.shape
+
+        # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
+        img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+        if img.shape[0] == 1 and bs > 1:
+            img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+        # Generate patch position ids.
+        img_ids = torch.zeros(h // 2, w // 2, 3)
+        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+        return img, img_ids

     def _run_vae_decoding(
         self,
@@ -130,27 +153,13 @@ def _run_vae_decoding(
     ) -> Image.Image:
         vae_info = context.models.load(self.vae.vae)
         with vae_info as vae:
-            assert isinstance(vae, AutoencoderKL)
-
-            flux_pipeline_with_vae = FluxPipeline(
-                scheduler=None,
-                vae=vae,
-                text_encoder=None,
-                tokenizer=None,
-                text_encoder_2=None,
-                tokenizer_2=None,
-                transformer=None,
-            )
+            assert isinstance(vae, AutoEncoder)
+            # TODO(ryand): Test that this works with both float16 and bfloat16.
+            with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
+                img = vae.decode(latents)

-            latents = flux_pipeline_with_vae._unpack_latents(
-                latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
-            )
-            latents = (
-                latents / flux_pipeline_with_vae.vae.config.scaling_factor
-            ) + flux_pipeline_with_vae.vae.config.shift_factor
-            latents = latents.to(dtype=vae.dtype)
-            image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
-            image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
-
-            assert isinstance(image, Image.Image)
-            return image
+        img.clamp(-1, 1)
+        img = rearrange(img[0], "c h w -> h w c")
+        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
+
+        return img_pil
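For reference, a minimal shape sketch of the patching done by _prepare_latent_img_patches above (not part of the commit; the 16-channel 16x16 latent below assumes FLUX's 8x VAE downsampling of a 128x128 px image, values are illustrative only):

import torch
from einops import rearrange, repeat

# Illustrative input: batch of 1, 16 latent channels, 16x16 latent spatial size.
latent_img = torch.zeros(1, 16, 16, 16)
bs, c, h, w = latent_img.shape

# Same rearrange as _prepare_latent_img_patches: 2x2 pixel unshuffle, then flatten H/W.
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(img.shape)  # torch.Size([1, 64, 64]) == (bs, (h // 2) * (w // 2), c * 4)

# One (0, y, x) position id per 2x2 patch, broadcast across the batch.
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
print(img_ids.shape)  # torch.Size([1, 64, 3])

The patch sequence length img.shape[1] (64 here) is what the node passes to get_schedule as image_seq_len.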

invokeai/app/invocations/model.py

Lines changed: 60 additions & 35 deletions
@@ -1,6 +1,6 @@
 import copy
 from time import sleep
-from typing import List, Optional, Literal, Dict
+from typing import Dict, List, Literal, Optional

 from pydantic import BaseModel, Field

@@ -12,10 +12,10 @@
     invocation_output,
 )
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
+from invokeai.app.services.model_records import ModelRecordChanges
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
-from invokeai.app.services.model_records import ModelRecordChanges
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType, ModelFormat
+from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType


 class ModelIdentifierField(BaseModel):
@@ -132,31 +132,22 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:

         return ModelIdentifierOutput(model=self.model)

-T5_ENCODER_OPTIONS = Literal["base", "16b_quantized", "8b_quantized"]
+
+T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
 T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
     "base": {
-        "text_encoder_repo": "black-forest-labs/FLUX.1-schnell::text_encoder_2",
-        "tokenizer_repo": "black-forest-labs/FLUX.1-schnell::tokenizer_2",
-        "text_encoder_name": "FLUX.1-schnell_text_encoder_2",
-        "tokenizer_name": "FLUX.1-schnell_tokenizer_2",
+        "repo": "invokeai/flux_dev::t5_xxl_encoder/base",
+        "name": "t5_base_encoder",
         "format": ModelFormat.T5Encoder,
     },
     "8b_quantized": {
-        "text_encoder_repo": "hf_repo1",
-        "tokenizer_repo": "hf_repo1",
-        "text_encoder_name": "hf_repo1",
-        "tokenizer_name": "hf_repo1",
-        "format": ModelFormat.T5Encoder8b,
-    },
-    "4b_quantized": {
-        "text_encoder_repo": "hf_repo2",
-        "tokenizer_repo": "hf_repo2",
-        "text_encoder_name": "hf_repo2",
-        "tokenizer_name": "hf_repo2",
-        "format": ModelFormat.T5Encoder8b,
+        "repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized",
+        "name": "t5_8b_quantized_encoder",
+        "format": ModelFormat.T5Encoder,
     },
 }

+
 @invocation_output("flux_model_loader_output")
 class FluxModelLoaderOutput(BaseInvocationOutput):
     """Flux base model loader output"""
@@ -176,7 +167,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
         ui_type=UIType.FluxMainModel,
         input=Input.Direct,
     )
-
+
     t5_encoder: T5_ENCODER_OPTIONS = InputField(description="The T5 Encoder model to use.")

     def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
@@ -189,7 +180,15 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
         tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
         clip_encoder = self._get_model(context, SubModelType.TextEncoder)
         t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
-        vae = self._install_model(context, SubModelType.VAE, "FLUX.1-schnell_ae", "black-forest-labs/FLUX.1-schnell::ae.safetensors", ModelFormat.Checkpoint, ModelType.VAE, BaseModelType.Flux)
+        vae = self._install_model(
+            context,
+            SubModelType.VAE,
+            "FLUX.1-schnell_ae",
+            "black-forest-labs/FLUX.1-schnell::ae.safetensors",
+            ModelFormat.Checkpoint,
+            ModelType.VAE,
+            BaseModelType.Flux,
+        )

         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer),
@@ -198,33 +197,59 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
             vae=VAEField(vae=vae),
         )

-    def _get_model(self, context: InvocationContext, submodel:SubModelType) -> ModelIdentifierField:
-        match(submodel):
+    def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
+        match submodel:
             case SubModelType.Transformer:
                 return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
             case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
-                return self._install_model(context, submodel, "clip-vit-large-patch14", "openai/clip-vit-large-patch14", ModelFormat.Diffusers, ModelType.CLIPEmbed, BaseModelType.Any)
-            case SubModelType.TextEncoder2:
-                return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["text_encoder_name"], T5_ENCODER_MAP[self.t5_encoder]["text_encoder_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
-            case SubModelType.Tokenizer2:
-                return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["tokenizer_name"], T5_ENCODER_MAP[self.t5_encoder]["tokenizer_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
+                return self._install_model(
+                    context,
+                    submodel,
+                    "clip-vit-large-patch14",
+                    "openai/clip-vit-large-patch14",
+                    ModelFormat.Diffusers,
+                    ModelType.CLIPEmbed,
+                    BaseModelType.Any,
+                )
+            case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
+                return self._install_model(
+                    context,
+                    submodel,
+                    T5_ENCODER_MAP[self.t5_encoder]["name"],
+                    T5_ENCODER_MAP[self.t5_encoder]["repo"],
+                    ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]),
+                    ModelType.T5Encoder,
+                    BaseModelType.Any,
+                )
             case _:
-                raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
-
-    def _install_model(self, context: InvocationContext, submodel:SubModelType, name: str, repo_id: str, format: ModelFormat, type: ModelType, base: BaseModelType):
-        if (models := context.models.search_by_attrs(name=name, base=base, type=type)):
+                raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
+
+    def _install_model(
+        self,
+        context: InvocationContext,
+        submodel: SubModelType,
+        name: str,
+        repo_id: str,
+        format: ModelFormat,
+        type: ModelType,
+        base: BaseModelType,
+    ):
+        if models := context.models.search_by_attrs(name=name, base=base, type=type):
             if len(models) != 1:
                 raise Exception(f"Multiple models detected for selected model with name {name}")
             return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
         else:
             model_path = context.models.download_and_cache_model(repo_id)
-            config = ModelRecordChanges(name = name, base = base, type=type, format=format)
+            config = ModelRecordChanges(name=name, base=base, type=type, format=format)
             model_install_job = context.models.import_local_model(model_path=model_path, config=config)
             while not model_install_job.in_terminal_state:
                 sleep(0.01)
             if not model_install_job.config_out:
                 raise Exception(f"Failed to install {name}")
-            return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(update={"submodel_type": submodel})
+            return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(
+                update={"submodel_type": submodel}
+            )
+

 @invocation(
     "main_model_loader",

invokeai/app/services/model_records/model_records_sql.py

Lines changed: 1 addition & 1 deletion
@@ -301,7 +301,7 @@ def search_by_attr(
         for row in result:
             try:
                 model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
-            except pydantic.ValidationError as e:
+            except pydantic.ValidationError:
                 # We catch this error so that the app can still run if there are invalid model configs in the database.
                 # One reason that an invalid model config might be in the database is if someone had to rollback from a
                 # newer version of the app that added a new model type.
