From a330fe01d73d6a1e265e8289e71a04b5e37579fd Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 6 Jun 2025 11:16:39 +0200 Subject: [PATCH] update --- .../transformers/transformer_hidream_image.py | 69 +++++++++++++++---- 1 file changed, 54 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py index 77902dcf5852..a4cca05a8397 100644 --- a/src/diffusers/models/transformers/transformer_hidream_image.py +++ b/src/diffusers/models/transformers/transformer_hidream_image.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -6,9 +7,8 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...models.modeling_outputs import Transformer2DModelOutput from ...models.modeling_utils import ModelMixin -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention from ..embeddings import TimestepEmbedding, Timesteps @@ -17,6 +17,29 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +@dataclass +class HiDreamImageModelOutput(BaseOutput): + sample: torch.Tensor + double_blocks_auxiliary_loss: Optional[Tuple[torch.Tensor, ...]] = None + single_blocks_auxiliary_loss: Optional[Tuple[torch.Tensor, ...]] = None + + +class AddAuxiliaryLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x, loss): + assert loss.numel() == 1 + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) + return grad_output, grad_loss + + class HiDreamImageFeedForwardSwiGLU(nn.Module): def __init__( self, @@ -332,7 +355,6 @@ def forward(self, hidden_states): else: mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts) ce = mask_ce.float().mean(0) - Pi = scores_for_aux.mean(0) fi = ce * self.n_routed_experts aux_loss = (Pi * fi).sum() * self.alpha @@ -379,11 +401,11 @@ def forward(self, x): y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype) y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) y = y.view(*orig_shape).to(dtype=wtype) - # y = AddAuxiliaryLoss.apply(y, aux_loss) + y = AddAuxiliaryLoss.apply(y, aux_loss) else: y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) y = y + self.shared_experts(identity) - return y + return y, aux_loss @torch.no_grad() def moe_infer(self, x, flat_expert_indices, flat_expert_weights): @@ -481,9 +503,10 @@ def forward( # 2. Feed-forward norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype) norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i - ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype)) + ff_output_i, aux_loss = self.ff_i(norm_hidden_states.to(dtype=wtype)) + ff_output_i = gate_mlp_i * ff_output_i hidden_states = ff_output_i + hidden_states - return hidden_states + return hidden_states, aux_loss @maybe_allow_in_graph @@ -573,11 +596,12 @@ def forward( norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype) norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t - ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states) + ff_output_i, aux_loss = self.ff_i(norm_hidden_states) + ff_output_i = gate_mlp_i * ff_output_i ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states) hidden_states = ff_output_i + hidden_states encoder_hidden_states = ff_output_t + encoder_hidden_states - return hidden_states, encoder_hidden_states + return hidden_states, encoder_hidden_states, aux_loss class HiDreamBlock(nn.Module): @@ -785,6 +809,7 @@ def forward( hidden_states_masks: Optional[torch.Tensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, + return_auxiliary_loss: bool = False, **kwargs, ): encoder_hidden_states = kwargs.get("encoder_hidden_states", None) @@ -866,15 +891,19 @@ def forward( # 2. Blocks block_id = 0 + double_blocks_aux_losses = [] + single_blocks_aux_losses = [] + initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1) initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] + for bid, block in enumerate(self.double_stream_blocks): cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] cur_encoder_hidden_states = torch.cat( [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1 ) if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func( + hidden_states, initial_encoder_hidden_states, aux_loss = self._gradient_checkpointing_func( block, hidden_states, hidden_states_masks, @@ -883,7 +912,7 @@ def forward( image_rotary_emb, ) else: - hidden_states, initial_encoder_hidden_states = block( + hidden_states, initial_encoder_hidden_states, aux_loss = block( hidden_states=hidden_states, hidden_states_masks=hidden_states_masks, encoder_hidden_states=cur_encoder_hidden_states, @@ -891,6 +920,7 @@ def forward( image_rotary_emb=image_rotary_emb, ) initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] + double_blocks_aux_losses.append(aux_loss) block_id += 1 image_tokens_seq_len = hidden_states.shape[1] @@ -908,7 +938,7 @@ def forward( cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( + hidden_states, aux_loss = self._gradient_checkpointing_func( block, hidden_states, hidden_states_masks, @@ -917,7 +947,7 @@ def forward( image_rotary_emb, ) else: - hidden_states = block( + hidden_states, aux_loss = block( hidden_states=hidden_states, hidden_states_masks=hidden_states_masks, encoder_hidden_states=None, @@ -925,6 +955,7 @@ def forward( image_rotary_emb=image_rotary_emb, ) hidden_states = hidden_states[:, :hidden_states_seq_len] + single_blocks_aux_losses.append(aux_loss) block_id += 1 hidden_states = hidden_states[:, :image_tokens_seq_len, ...] @@ -938,5 +969,13 @@ def forward( unscale_lora_layers(self, lora_scale) if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) + return_values = (output,) + if return_auxiliary_loss: + return_values += (double_blocks_aux_losses, single_blocks_aux_losses) + return return_values + + return HiDreamImageModelOutput( + sample=output, + double_blocks_auxiliary_loss=double_blocks_aux_losses, + single_blocks_auxiliary_loss=single_blocks_aux_losses, + )