diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index c0fa031b9faf..21046f61f3ad 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -262,6 +262,11 @@ class QwenDoubleStreamAttnProcessor2_0:
     """
     Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This
     processor implements joint attention computation where text and image streams are processed together.
+
+    Args:
+        encoder_hidden_states_mask (`torch.BoolTensor`, *optional*):
+            Boolean mask for the text stream with shape `[batch_size, text_seq_len]`. `True` marks tokens that
+            should be attended to; `False` marks padding tokens to mask out. Integer masks are cast to boolean.
     """
 
     _attention_backend = None
@@ -278,7 +283,7 @@ def __call__(
         attn: Attention,
         hidden_states: torch.FloatTensor,  # Image stream
         encoder_hidden_states: torch.FloatTensor = None,  # Text stream
-        encoder_hidden_states_mask: torch.FloatTensor = None,
+        encoder_hidden_states_mask: Optional[torch.BoolTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
@@ -330,6 +335,32 @@ def __call__(
         joint_key = torch.cat([txt_key, img_key], dim=1)
         joint_value = torch.cat([txt_value, img_value], dim=1)
 
+        # Convert encoder_hidden_states_mask to a 2D attention mask if provided.
+        if encoder_hidden_states_mask is not None and attention_mask is None:
+            batch_size = hidden_states.shape[0]
+            image_seq_len = hidden_states.shape[1]
+            text_seq_len = encoder_hidden_states.shape[1]
+
+            if encoder_hidden_states_mask.shape[0] != batch_size:
+                raise ValueError(
+                    f"encoder_hidden_states_mask batch size ({encoder_hidden_states_mask.shape[0]}) "
+                    f"must match hidden_states batch size ({batch_size})"
+                )
+            if encoder_hidden_states_mask.shape[1] != text_seq_len:
+                raise ValueError(
+                    f"encoder_hidden_states_mask sequence length ({encoder_hidden_states_mask.shape[1]}) "
+                    f"must match encoder_hidden_states sequence length ({text_seq_len})"
+                )
+
+            text_attention_mask = encoder_hidden_states_mask.bool()
+            image_attention_mask = torch.ones(
+                (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device
+            )
+
+            joint_attention_mask_1d = torch.cat([text_attention_mask, image_attention_mask], dim=1)
+            # Broadcastable shape for SDPA: [batch_size, 1, 1, text_seq_len + image_seq_len]
+            attention_mask = joint_attention_mask_1d[:, None, None, :]
+
         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
             joint_query,
@@ -630,7 +661,15 @@ def forward(
             else self.time_text_embed(timestep, guidance, hidden_states)
         )
 
-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        # Use the padded sequence length for RoPE when a mask is present.
+        # The attention mask will handle excluding padding tokens.
+        if encoder_hidden_states_mask is not None:
+            txt_seq_lens_for_rope = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
+        else:
+            txt_seq_lens_for_rope = (
+                txt_seq_lens if txt_seq_lens is not None else [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
+            )
+        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens_for_rope, device=hidden_states.device)
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
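
For review, a small self-contained sketch of what the new mask-conversion block in `QwenDoubleStreamAttnProcessor2_0` does. It is not part of the patch: it uses plain `torch.nn.functional.scaled_dot_product_attention` in place of `dispatch_attention_fn`, and the tensor sizes are made-up toy values; only the mask handling mirrors the added code.

```python
import torch
import torch.nn.functional as F

# Toy sizes, not tied to any real checkpoint.
batch_size, text_seq_len, image_seq_len, heads, head_dim = 2, 7, 16, 4, 8

# 1D text mask as a pipeline would supply it (1 = real token, 0 = padding).
encoder_hidden_states_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]])

# Mirror of the processor's conversion: boolean text mask + all-ones image mask,
# concatenated in [text, image] order and reshaped so SDPA can broadcast it.
text_attention_mask = encoder_hidden_states_mask.bool()
image_attention_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool)
joint_attention_mask_1d = torch.cat([text_attention_mask, image_attention_mask], dim=1)
attention_mask = joint_attention_mask_1d[:, None, None, :]  # [B, 1, 1, text + image]

# Joint attention over the concatenated sequence; False key positions get zero weight.
seq_len = text_seq_len + image_seq_len
q = torch.randn(batch_size, heads, seq_len, head_dim)
k = torch.randn(batch_size, heads, seq_len, head_dim)
v = torch.randn(batch_size, heads, seq_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)

# Perturbing only the padded key/value slots leaves the output unchanged,
# which is the property the new padding-isolation test checks end to end.
pad = (~joint_attention_mask_1d)[:, None, :, None].float()
out_perturbed = F.scaled_dot_product_attention(
    q, k + 10.0 * pad * torch.randn_like(k), v + 10.0 * pad * torch.randn_like(v), attn_mask=attention_mask
)
torch.testing.assert_close(out, out_perturbed)
```

The same broadcastable `[B, 1, 1, S]` boolean layout is what the processor hands to `dispatch_attention_fn` as `attention_mask`.
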
diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
index b24fa90503ef..352037aa0534 100644
--- a/tests/models/transformers/test_models_transformer_qwenimage.py
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -91,6 +91,124 @@ def test_gradient_checkpointing_is_applied(self):
         expected_set = {"QwenImageTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
+    def test_attention_mask_with_padding(self):
+        """Test that encoder_hidden_states_mask properly handles padded sequences."""
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device).eval()
+
+        batch_size = 2
+        height = width = 4
+        num_latent_channels = embedding_dim = 16
+        text_seq_len = 7
+        vae_scale_factor = 4
+
+        # Create inputs with padding
+        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, text_seq_len, embedding_dim)).to(torch_device)
+
+        # First sample: 5 real tokens, 2 padding
+        # Second sample: 3 real tokens, 4 padding
+        encoder_hidden_states_mask = torch.tensor(
+            [[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]], dtype=torch.long
+        ).to(torch_device)
+
+        # Zero out padding in embeddings
+        encoder_hidden_states = encoder_hidden_states * encoder_hidden_states_mask.unsqueeze(-1).float()
+
+        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+        orig_height = height * 2 * vae_scale_factor
+        orig_width = width * 2 * vae_scale_factor
+        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
+        txt_seq_lens = encoder_hidden_states_mask.sum(dim=1).tolist()
+
+        inputs_with_mask = {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "encoder_hidden_states_mask": encoder_hidden_states_mask,
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+            "txt_seq_lens": txt_seq_lens,
+        }
+
+        # Run with proper mask
+        with torch.no_grad():
+            output_with_mask = model(**inputs_with_mask).sample
+
+        # Run with all-ones mask (treating padding as real tokens)
+        inputs_without_mask = {
+            "hidden_states": hidden_states.clone(),
+            "encoder_hidden_states": encoder_hidden_states.clone(),
+            "encoder_hidden_states_mask": torch.ones_like(encoder_hidden_states_mask),
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+            "txt_seq_lens": [text_seq_len] * batch_size,
+        }
+
+        with torch.no_grad():
+            output_without_mask = model(**inputs_without_mask).sample
+
+        # Outputs should differ when mask is applied correctly
+        diff = (output_with_mask - output_without_mask).abs().mean().item()
+        assert diff > 1e-5, f"Mask appears to be ignored (diff={diff})"
+
+    def test_attention_mask_padding_isolation(self):
+        """Test that changing padding content doesn't affect output when mask is used."""
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device).eval()
+
+        batch_size = 2
+        height = width = 4
+        num_latent_channels = embedding_dim = 16
+        text_seq_len = 7
+        vae_scale_factor = 4
+
+        # Create inputs
+        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, text_seq_len, embedding_dim)).to(torch_device)
+        encoder_hidden_states_mask = torch.tensor(
+            [[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]], dtype=torch.long
+        ).to(torch_device)
+
+        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+        orig_height = height * 2 * vae_scale_factor
+        orig_width = width * 2 * vae_scale_factor
+        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
+        txt_seq_lens = encoder_hidden_states_mask.sum(dim=1).tolist()
+
+        inputs1 = {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "encoder_hidden_states_mask": encoder_hidden_states_mask,
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+            "txt_seq_lens": txt_seq_lens,
+        }
+
+        with torch.no_grad():
+            output1 = model(**inputs1).sample
+
+        # Modify padding content with large noise
+        encoder_hidden_states2 = encoder_hidden_states.clone()
+        mask = encoder_hidden_states_mask.unsqueeze(-1).float()
+        noise = torch.randn_like(encoder_hidden_states2) * 10.0
+        encoder_hidden_states2 = encoder_hidden_states2 + noise * (1 - mask)
+
+        inputs2 = {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states2,
+            "encoder_hidden_states_mask": encoder_hidden_states_mask,
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+            "txt_seq_lens": txt_seq_lens,
+        }
+
+        with torch.no_grad():
+            output2 = model(**inputs2).sample
+
+        # Outputs should be nearly identical (padding is masked out)
+        diff = (output1 - output2).abs().mean().item()
+        assert diff < 1e-4, f"Padding content affected output (diff={diff})"
+
+
 class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = QwenImageTransformer2DModel
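
One more reviewer-side note on the `forward()` change, as a runnable sketch rather than a definitive reference. The true per-sample lengths derived from the mask (here `[5, 3]`) no longer drive RoPE: per the patch's own comment, the rotary table is built for the full padded length so it lines up with the padded text tensor, and the attention mask is what excludes the padding tokens. The tensors below are toy values mirroring the new tests.

```python
import torch

# Toy inputs mirroring the tests: batch of 2, text padded to length 7.
encoder_hidden_states = torch.randn(2, 7, 16)
encoder_hidden_states_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]])
txt_seq_lens = encoder_hidden_states_mask.sum(dim=1).tolist()  # [5, 3], the unpadded lengths

# Same branch logic as the patched forward(): with a mask present, RoPE covers every
# padded text position; without a mask, the caller-provided txt_seq_lens (or the
# padded length) is used as before.
if encoder_hidden_states_mask is not None:
    txt_seq_lens_for_rope = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
else:
    txt_seq_lens_for_rope = (
        txt_seq_lens if txt_seq_lens is not None else [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
    )

assert txt_seq_lens_for_rope == [7, 7]
```
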