Hi! In the function `shrinkMaskStrict`, lines 77 and 78 seem to have no effect. What are these two lines of code meant to do?
radial-attention/radial_attn/attn_mask.py
Lines 67 to 79 in 72788d4
```python
def shrinkMaskStrict(mask, block_size=128):
    seqlen = mask.shape[0]
    block_num = seqlen // block_size
    mask = mask[:block_num * block_size, :block_num * block_size].view(block_num, block_size, block_num, block_size)
    col_densities = mask.sum(dim = 1) / block_size
    # we want the minimum non-zero column density in the block
    non_zero_densities = col_densities > 0
    high_density_cols = col_densities > 1/3
    frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9)
    block_mask = frac_high_density_cols > 0.6
    block_mask[0:0] = True
    block_mask[-1:-1] = True
    return block_mask
```
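To illustrate the concern, here is a minimal sketch in plain Python of which rows the two slices on lines 77 and 78 select. The helper `rows_selected` and the value `block_num = 4` are hypothetical and used only for illustration. Basic slicing on a tensor follows the same start/stop rules, so an assignment through an empty slice should change nothing. The two alternative slices at the end are only a guess at what might have been intended, not something confirmed by the repository.

```python
def rows_selected(s, num_rows):
    """Return the row indices that slice `s` picks out of `num_rows` rows."""
    start, stop, step = s.indices(num_rows)
    return list(range(start, stop, step))


block_num = 4  # hypothetical number of blocks, for illustration only

# The slices used on lines 77 and 78: start == stop, so both are empty.
first = rows_selected(slice(0, 0), block_num)     # corresponds to [0:0]
last = rows_selected(slice(-1, -1), block_num)    # corresponds to [-1:-1]

# Possible intended alternatives (an assumption): the first and last block row.
first_row = rows_selected(slice(0, 1), block_num)      # corresponds to [0:1]
last_row = rows_selected(slice(-1, None), block_num)   # corresponds to [-1:]

print("as written:", first, last)
print("possible intent:", first_row, last_row)
```

If the intent was to always keep the first and last block, something like `block_mask[0:1]` and `block_mask[-1:]` (or indexing specific entries) would be needed instead. Which of these the authors meant is not clear from the code.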