
Use .expand() instead of .repeat() to avoid unnecessary memory allocation in scatter_reduce_ index construction #65

@Mulanss


🚀 Suggestion: Use .expand() Instead of .repeat() in scatter_reduce_ Index Construction

Hi team 👋,

In hi_diffusers/models/moe.py at line 153, the following line is used to construct the index tensor for scatter_reduce_:

exp_token_idx.view(-1, 1).repeat(1, x.shape[-1])

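This can be safely replaced with .expand(-1, x.shape[-1]), which yields an index tensor with the same shape and values but without the memory duplication caused by repeat(). repeat() allocates a new int64 tensor of exp_token_idx.numel() × x.shape[-1] elements and copies the indices into every column, whereas expand() returns a view over the original indices with a stride of 0 along the last dimension. Since this index tensor is only read by scatter_reduce_ and never modified after creation, .expand() achieves the same broadcasting effect with less memory and without the copy.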
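A quick way to check the equivalence claim is a small standalone comparison. This is a sketch, not code from moe.py: num_tokens, hidden_dim, expert_out and the sample indices are hypothetical stand-ins, and it assumes a PyTorch version that provides Tensor.scatter_reduce_ (1.12 or later).

```python
import torch

# Hypothetical stand-ins for the tensors in moe.py: the indices of the tokens
# routed to one expert, and that expert's outputs (one row per routed token).
num_tokens, hidden_dim = 8, 16
exp_token_idx = torch.tensor([0, 3, 3, 5, 7])  # int64, as scatter_reduce_ requires
# Integer-valued floats keep the sums exact, so results can be compared bitwise.
expert_out = torch.arange(exp_token_idx.numel() * hidden_dim, dtype=torch.float32).view(-1, hidden_dim)

idx_repeat = exp_token_idx.view(-1, 1).repeat(1, hidden_dim)   # new (5, 16) tensor
idx_expand = exp_token_idx.view(-1, 1).expand(-1, hidden_dim)  # (5, 16) view, no copy

# Same shape and values; only the memory layout differs.
assert torch.equal(idx_repeat, idx_expand)
print(idx_repeat.stride(), idx_expand.stride())  # (16, 1) (1, 0)

# scatter_reduce_ therefore produces identical results with either index.
out_repeat = torch.zeros(num_tokens, hidden_dim).scatter_reduce_(0, idx_repeat, expert_out, reduce="sum")
out_expand = torch.zeros(num_tokens, hidden_dim).scatter_reduce_(0, idx_expand, expert_out, reduce="sum")
assert torch.equal(out_repeat, out_expand)
```

The stride of 0 in the expanded index is what lets every column read the same stored index without a copy.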


✅ Suggested Change

exp_token_idx.view(-1, 1).expand(-1, x.shape[-1])

✅ Benefits

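  • 🧠 Reduces memory usage: no extra int64 buffer of exp_token_idx.numel() × x.shape[-1] elements (especially noticeable when x.shape[-1] is large)
  • 🚀 May slightly improve performance (skips the allocation and the copy)
  • 🧩 Same behavior and output as before
  • 📦 Cleaner indexing logic that states the broadcasting intent directly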
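The memory difference can be measured by looking at the storage that actually backs each index tensor. A minimal sketch, assuming PyTorch 2.0 or later (for Tensor.untyped_storage()); N and D are hypothetical sizes, not values taken from HiDream‑I1:

```python
import torch

# Hypothetical sizes: N tokens routed to an expert, hidden size D.
N, D = 1024, 4096
exp_token_idx = torch.randint(0, N, (N,))  # int64 by default

idx_repeat = exp_token_idx.view(-1, 1).repeat(1, D)
idx_expand = exp_token_idx.view(-1, 1).expand(-1, D)

def storage_bytes(t: torch.Tensor) -> int:
    # Size of the memory actually backing the tensor (requires PyTorch >= 2.0).
    return t.untyped_storage().nbytes()

print(storage_bytes(idx_repeat))  # 33554432 (N * D * 8 bytes = 32 MiB)
print(storage_bytes(idx_expand))  # 8192 (N * 8 bytes, the original indices)
# The expanded index shares memory with exp_token_idx instead of copying it.
print(idx_expand.data_ptr() == exp_token_idx.data_ptr())  # True
```

The repeat() cost grows linearly with D, while the expand() cost stays at the size of the original index vector.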

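Note: This is safe because the tensor is only used as a read-only index. expand() creates a broadcasted view without copying data, so all columns of a row share one memory location; an in-place write to the view would raise an error, so the index would need to be materialized (e.g. with .contiguous() or .clone()) if it ever had to be modified.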
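The read-only requirement matters because every column of the expanded view aliases the same element. A minimal sketch with hypothetical indices, showing that PyTorch rejects an in-place write to such a view and how to get an owned copy if one is ever needed:

```python
import torch

exp_token_idx = torch.tensor([0, 2, 1])
idx_expand = exp_token_idx.view(-1, 1).expand(-1, 4)  # 4 columns, all aliasing one element per row

# An in-place write to the expanded view is rejected, because several
# elements of the written-to tensor refer to the same memory location.
rejected = False
try:
    idx_expand.add_(1)
except RuntimeError:
    rejected = True
print("in-place write rejected:", rejected)  # in-place write rejected: True

# If the index ever has to be modified, materialize an owned copy first.
idx_owned = idx_expand.contiguous()  # allocates, like repeat() would
idx_owned.add_(1)
print(idx_owned.tolist())      # [[1, 1, 1, 1], [3, 3, 3, 3], [2, 2, 2, 2]]
print(exp_token_idx.tolist())  # [0, 2, 1] (unchanged)
```

Calling .contiguous() on the expanded view allocates the same buffer that repeat() would, so it only makes sense when the index really has to be mutated.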


Thanks for your excellent work on HiDream‑I1! 🙌

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
