Hello, and thank you for the great project.
I am trying to run `wan_t2v_inference.sh` and ran into the following issue.
In `radial_attn/attn_mask.py`, line 327 unpacks four dimensions from `query`, but `query` has only 3:
batch_size, seq_len, num_head, hidden_dim = query.shape
Before this point, the following `rearrange` is applied:
query = rearrange(query, "b s h d -> (b s) h d")
and the result is passed to `RadialAttention`, so it arrives with only 3 dimensions:
hidden_states = RadialAttention(
query=query, key=key, value=value, mask_map=self.mask_map, sparsity_type="dense", block_size=128, decay_factor=self.decay_factor, model_type="wan", pre_defined_mask=None, use_sage_attention=self.use_sage_attention
)
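The mismatch can be reproduced with shapes alone. Below is a minimal sketch in plain Python (no torch or einops), assuming hypothetical sizes b=1, s=4096, h=12, d=128; `flatten_batch_seq` is an illustrative helper that mirrors the einops pattern above and is not part of the repo.

```python
def flatten_batch_seq(shape):
    """Shape produced by rearrange(x, "b s h d -> (b s) h d")."""
    b, s, h, d = shape
    return (b * s, h, d)


query_shape = (1, 4096, 12, 128)   # hypothetical (b, s, h, d) sizes
flat_shape = flatten_batch_seq(query_shape)
print(flat_shape)                  # (4096, 12, 128)

try:
    # Same 4-way unpacking as attn_mask.py line 327, applied to the 3-D shape.
    batch_size, seq_len, num_head, hidden_dim = flat_shape
except ValueError as err:
    print(err)                     # not enough values to unpack (expected 4, got 3)
```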
If I simply set `batch_size` to 1 and unpack only three values (`seq_len, num_head, hidden_dim = query.shape`), that line runs, but I then hit further errors later on.
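For comparison, here is a hypothetical shape-tolerant unpacking (an assumption for illustration, not the repo's code) that accepts either the 4-D `(b, s, h, d)` layout or the flattened `((b s), h, d)` layout. It only recovers the sizes; the tensor itself is still 3-D, so any later code that indexes it as `(b, s, h, d)` could still fail, which may be related to the later errors.

```python
def unpack_query_shape(shape, batch_size=1):
    """Hypothetical helper: return (b, s, h, d) from a 4-D or flattened 3-D shape.

    In the flattened ((b s), h, d) layout the batch size is no longer recorded
    in the shape, so the caller has to supply it.
    """
    if len(shape) == 4:
        return tuple(shape)
    if len(shape) == 3:
        total_tokens, num_head, hidden_dim = shape
        if total_tokens % batch_size != 0:
            raise ValueError(
                f"{total_tokens} tokens not divisible by batch_size={batch_size}"
            )
        return batch_size, total_tokens // batch_size, num_head, hidden_dim
    raise ValueError(f"expected a 3-D or 4-D query shape, got rank {len(shape)}")


print(unpack_query_shape((4096, 12, 128)))                # (1, 4096, 12, 128)
print(unpack_query_shape((8192, 12, 128), batch_size=2))  # (2, 4096, 12, 128)
```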