Commit ccc07bf

remove bug workaround fixed in huggingface/diffusers#12702, use split attention
update merge with attention selection, add FLASH_SPLIT
Parent: 83125e9

18 files changed: +61 -32 lines

modules/model/QwenModel.py

Lines changed: 2 additions & 2 deletions
@@ -177,8 +177,8 @@ def encode_text(
         #pad to 16 because attention processors and/or torch.compile can have issues with uneven sequence lengths, but only pad if an attention mask has to be used anyway:
         #TODO the second condition could trigger https://github.com/pytorch/pytorch/issues/165506 again, but try like this because no attention mask
         #is preferable: https://github.com/Nerogar/OneTrainer/pull/1109
-        if max_seq_length % 16 > 0 and (seq_lengths != max_seq_length).any():
-            max_seq_length += (16 - max_seq_length % 16)
+        if max_seq_length % 64 > 0 and (seq_lengths != max_seq_length).any():
+            max_seq_length += (64 - max_seq_length % 64)

         text_encoder_output = text_encoder_output[:, :max_seq_length, :]
         bool_attention_mask = tokens_mask[:, :max_seq_length].bool()
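
For context, a minimal standalone sketch of the padding rule this hunk changes from 16 to 64: the masked sequence length is rounded up to the next multiple only when the batch actually needs an attention mask. Names here are illustrative, not the model's own variables.

import torch

def pad_seq_length(seq_lengths: torch.Tensor, multiple: int = 64) -> int:
    # longest real sequence in the batch (illustrative; the model derives this elsewhere)
    max_seq_length = int(seq_lengths.max())
    # pad only when lengths differ (an attention mask is needed anyway)
    # and the length is not already a multiple of `multiple`
    if max_seq_length % multiple > 0 and (seq_lengths != max_seq_length).any():
        max_seq_length += multiple - max_seq_length % multiple
    return max_seq_length

print(pad_seq_length(torch.tensor([37, 80, 80])))  # 128: ragged batch, 80 rounded up to the next multiple of 64
print(pad_seq_length(torch.tensor([80, 80, 80])))  # 80: uniform lengths, no mask needed, no padding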

modules/modelSampler/QwenSampler.py

Lines changed: 1 addition & 14 deletions
@@ -97,9 +97,6 @@ def __sample_base(
         if "generator" in set(inspect.signature(noise_scheduler.step).parameters.keys()):
             extra_step_kwargs["generator"] = generator #TODO purpose?

-        #txt_seq_lens = text_attention_mask.sum(dim=1).tolist()
-        txt_seq_lens = [text_attention_mask.shape[1]] * text_attention_mask.shape[0]
-
         #FIXME list of lists is not according to type hint, but according to diffusers code
         #https://github.com/huggingface/diffusers/issues/12295
         img_shapes = [[(
@@ -110,25 +107,15 @@ def __sample_base(

         self.model.transformer_to(self.train_device)

-        #FIXME bug workaround for https://github.com/huggingface/diffusers/issues/12294
-        image_seq_len = latent_image.shape[1]
-        image_attention_mask=torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=latent_image.device)
-        attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
-        attention_mask_2d = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
-
         for i, timestep in enumerate(tqdm(timesteps, desc="sampling")):
             latent_model_input = torch.cat([latent_image] * batch_size)
             expanded_timestep = timestep.expand(batch_size)
             noise_pred = transformer(
                 hidden_states=latent_model_input.to(dtype=self.model.train_dtype.torch_dtype()),
                 timestep=expanded_timestep / 1000,
                 encoder_hidden_states=combined_prompt_embedding.to(dtype=self.model.train_dtype.torch_dtype()),
-                encoder_hidden_states_mask=text_attention_mask,
-                txt_seq_lens=txt_seq_lens,
+                encoder_hidden_states_mask=text_attention_mask if not torch.all(text_attention_mask) else None,
                 img_shapes=img_shapes,
-                attention_kwargs = {
-                    "attention_mask": attention_mask_2d,
-                },
                 return_dict=True
             ).sample
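
For reference, a rough sketch of what the deleted workaround used to build by hand before diffusers#12702: a joint boolean mask over the concatenated text and image tokens, broadcast to the 2D query-by-key mask that was passed through attention_kwargs. Shapes below are made up for illustration only.

import torch

batch_size, txt_len, img_len = 2, 6, 10

text_attention_mask = torch.ones(batch_size, txt_len, dtype=torch.bool)
text_attention_mask[1, 10 - img_len:] = False  # second prompt is shorter and padded

# image tokens are never padded, so their mask is all ones
image_attention_mask = torch.ones(batch_size, img_len, dtype=torch.bool)

# joint 1D mask over [text | image] tokens, then an outer product per sample
# to get the 2D (query x key) mask the old code passed via attention_kwargs
attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
attention_mask_2d = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]

print(attention_mask.shape)     # torch.Size([2, 16])
print(attention_mask_2d.shape)  # torch.Size([2, 1, 16, 16])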

modules/modelSetup/BaseChromaSetup.py

Lines changed: 2 additions & 0 deletions
@@ -86,6 +86,8 @@ def setup_optimizations(
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.transformer, self.train_device, model.train_dtype, config)

+        self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)
+
     def _setup_embeddings(
         self,
         model: ChromaModel,

modules/modelSetup/BaseFluxSetup.py

Lines changed: 2 additions & 0 deletions
@@ -89,6 +89,8 @@ def setup_optimizations(
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.transformer, self.train_device, model.train_dtype, config)

+        self._set_attention_backend(model.transformer, config.attention_mechanism, mask=False)
+
     def _setup_embeddings(
         self,
         model: FluxModel,

modules/modelSetup/BaseHiDreamSetup.py

Lines changed: 2 additions & 0 deletions
@@ -105,6 +105,8 @@ def setup_optimizations(
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype, config)

+        self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)
+
     def _setup_embeddings(
         self,
         model: HiDreamModel,

modules/modelSetup/BaseHunyuanVideoSetup.py

Lines changed: 2 additions & 0 deletions
@@ -91,6 +91,8 @@ def setup_optimizations(
         quantize_layers(model.transformer, self.train_device, model.transformer_train_dtype, config)

         model.vae.enable_tiling()
+        self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)
+

     def _setup_embeddings(
         self,

modules/modelSetup/BaseModelSetup.py

Lines changed: 19 additions & 0 deletions
@@ -3,6 +3,7 @@

 from modules.model.BaseModel import BaseModel
 from modules.util.config.TrainConfig import TrainConfig, TrainEmbeddingConfig, TrainModelPartConfig
+from modules.util.enum.AttentionMechanism import AttentionMechanism
 from modules.util.enum.TrainingMethod import TrainingMethod
 from modules.util.ModuleFilter import ModuleFilter
 from modules.util.NamedParameterGroup import NamedParameterGroup, NamedParameterGroupCollection
@@ -235,3 +236,21 @@ def _setup_model_part_requires_grad(
         if unique_name in self.frozen_parameters:
             for param in self.frozen_parameters[unique_name]:
                 param.requires_grad_(False)
+
+    @staticmethod
+    def _set_attention_backend(component, attn: AttentionMechanism, mask: bool=False, varlen: bool=False):
+        match attn:
+            case AttentionMechanism.SDP:
+                component.set_attention_backend("native")
+            case AttentionMechanism.FLASH:
+                if mask or varlen:
+                    print("Warning: FLASH attention might fail for this model, depending on other configuration (batch size > 1, etc.)")
+                component.set_attention_backend("flash")
+            case AttentionMechanism.SPLIT:
+                component.set_attention_backend("native_split")
+            case AttentionMechanism.FLASH_SPLIT:
+                component.set_attention_backend("flash_split")
+                if mask and not varlen:
+                    print("Warning: FLASH attention might fail for this model, depending on other configuration (batch size > 1, etc.)")
+            case _:
+                raise NotImplementedError(f"attention mechanism {str(attn)} not implemented")
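
A hedged usage sketch of the new helper. DummyTransformer is a hypothetical stand-in exposing only the set_attention_backend method that diffusers components provide, so the dispatch can be seen without loading a model; it assumes OneTrainer is importable and that the class in this file is named BaseModelSetup, with the AttentionMechanism members referenced above (SDP, FLASH, SPLIT, FLASH_SPLIT).

from modules.modelSetup.BaseModelSetup import BaseModelSetup  # assumes the repo layout shown in this commit
from modules.util.enum.AttentionMechanism import AttentionMechanism

class DummyTransformer:
    # minimal stand-in: only the method _set_attention_backend calls
    def set_attention_backend(self, name: str):
        print(f"backend set to: {name}")

BaseModelSetup._set_attention_backend(DummyTransformer(), AttentionMechanism.SPLIT)
# backend set to: native_split

BaseModelSetup._set_attention_backend(DummyTransformer(), AttentionMechanism.FLASH, mask=True)
# Warning: FLASH attention might fail for this model, ...
# backend set to: flash

BaseModelSetup._set_attention_backend(DummyTransformer(), AttentionMechanism.FLASH_SPLIT, mask=True, varlen=True)
# backend set to: flash_split  (no warning, since varlen=True)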

modules/modelSetup/BasePixArtAlphaSetup.py

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@ def setup_optimizations(
         quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
+        self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)

     def _setup_embeddings(
         self,

modules/modelSetup/BaseQwenSetup.py

Lines changed: 2 additions & 15 deletions
@@ -80,6 +80,7 @@ def setup_optimizations(
         quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
+        self._set_attention_backend(model.transformer, config.attention_mechanism, varlen=True)

     def predict(
         self,
@@ -131,11 +132,6 @@ def predict(
         latent_input = scaled_noisy_latent_image
         packed_latent_input = model.pack_latents(latent_input)

-        #FIXME this is the only case that the transformer accepts:
-        #see https://github.com/huggingface/diffusers/issues/12344
-        #actual text sequence lengths can be shorter,but they might be padded and masked
-        txt_seq_lens = [text_encoder_output.shape[1]] * text_encoder_output.shape[0]
-
         #FIXME list of lists is not according to type hint, but according to diffusers code:
         #https://github.com/huggingface/diffusers/issues/12295
         img_shapes = [[(
@@ -144,21 +140,12 @@ def predict(
             latent_input.shape[-1] // 2)
         ]] * latent_input.shape[0]

-        #FIXME bug workaround for https://github.com/huggingface/diffusers/issues/12294
-        image_attention_mask=torch.ones((packed_latent_input.shape[0], packed_latent_input.shape[1]), dtype=torch.bool, device=latent_image.device)
-        attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
-        attention_mask_2d = attention_mask[:, None, None, :] if not torch.all(text_attention_mask) else None
-
         packed_predicted_flow = model.transformer(
             hidden_states=packed_latent_input.to(dtype=model.train_dtype.torch_dtype()),
             timestep=timestep / 1000,
             encoder_hidden_states=text_encoder_output.to(dtype=model.train_dtype.torch_dtype()),
-            encoder_hidden_states_mask=text_attention_mask,
-            txt_seq_lens=txt_seq_lens,
+            encoder_hidden_states_mask=text_attention_mask if not torch.all(text_attention_mask) else None,
             img_shapes=img_shapes,
-            attention_kwargs = {
-                "attention_mask": attention_mask_2d,
-            },
             return_dict=True,
         ).sample
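
Both updated call sites (the sampler above and predict here) now rely on the same expression instead of the removed workaround: pass the text mask only when it actually masks something, otherwise pass None so the masking path is skipped entirely. A minimal sketch of that expression, with illustrative tensor sizes:

import torch

def mask_or_none(text_attention_mask: torch.Tensor):
    # all tokens real -> no mask needed, let the attention backend skip masking entirely
    return text_attention_mask if not torch.all(text_attention_mask) else None

full = torch.ones(2, 16, dtype=torch.bool)
ragged = full.clone()
ragged[1, 10:] = False

print(mask_or_none(full))          # None: encoder_hidden_states_mask is omitted
print(mask_or_none(ragged).shape)  # torch.Size([2, 16]): the mask is passed through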

modules/modelSetup/BaseSanaSetup.py

Lines changed: 1 addition & 0 deletions
@@ -97,6 +97,7 @@ def setup_optimizations(
         quantize_layers(model.text_encoder, self.train_device, model.text_encoder_train_dtype, config)
         quantize_layers(model.vae, self.train_device, model.train_dtype, config)
         quantize_layers(model.transformer, self.train_device, model.train_dtype, config)
+        self._set_attention_backend(model.transformer, config.attention_mechanism, mask=True)

     def _setup_embeddings(
         self,
