Commit c13dbd5

fix attention mask pad check (#3531)
1 parent bde2cb5 · commit c13dbd5

File tree

1 file changed: +1 -6 lines changed


src/diffusers/models/attention_processor.py

Lines changed: 1 addition & 6 deletions
@@ -381,12 +381,7 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None,
             return attention_mask
 
         current_length: int = attention_mask.shape[-1]
-        if current_length > target_length:
-            # we *could* trim the mask with:
-            # attention_mask = attention_mask[:,:target_length]
-            # but this is weird enough that it's more likely to be a mistake than a shortcut
-            raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
-        elif current_length < target_length:
+        if current_length != target_length:
             if attention_mask.device.type == "mps":
                 # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                 # Instead, we can manually construct the padding tensor.
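For context, a minimal runnable sketch of the padding path that this check now guards is shown below. The helper name `pad_attention_mask` and the pad amount `target_length - current_length` are assumptions for illustration, not the exact code in `attention_processor.py`; the point is that after this change any length mismatch reaches the padding branch, whereas the old code raised a `ValueError` whenever the mask was longer than `target_length`.

```python
import torch
import torch.nn.functional as F


def pad_attention_mask(attention_mask: torch.Tensor, target_length: int) -> torch.Tensor:
    # Illustrative sketch: pad the last dimension of the mask out to target_length.
    current_length = attention_mask.shape[-1]
    if current_length != target_length:  # the check introduced by this commit
        pad_amount = max(target_length - current_length, 0)
        if attention_mask.device.type == "mps":
            # MPS cannot pad by more than the input tensor's dimension,
            # so build the padding tensor manually and concatenate it.
            padding_shape = (*attention_mask.shape[:-1], pad_amount)
            padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.cat([attention_mask, padding], dim=-1)
        else:
            # Pad only the last dimension: (left, right) = (0, pad_amount).
            attention_mask = F.pad(attention_mask, (0, pad_amount), value=0.0)
    return attention_mask


# Hypothetical usage: a (batch, 1, seq_len) mask padded from 5 to 8 positions.
mask = torch.zeros(2, 1, 5)
padded = pad_attention_mask(mask, target_length=8)
print(padded.shape)  # torch.Size([2, 1, 8])
```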
