Commit 052dbb4
Streamlined Rearrange in SpatialAttentionBlock (#8130)

The Rearrange code failed dynamo export in the 24.09 container: pytorch/pytorch#137629. While we still can't use dynamo export with TRT in 23.09, I also noticed that my workaround improved end-to-end runtime by about 1 second for a 100-second run.

### Description

Replaced einops Rearrange with reshape/transpose.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).

---------

Signed-off-by: Boris Fomitchev <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
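As a sanity check (not part of the commit), the sketch below compares the removed einops `Rearrange` pair against the new reshape/transpose pair on a random 3D input; the tensor sizes are illustrative.

```python
import torch
from einops.layers.torch import Rearrange

# Illustrative 3D input: batch=2, channels=4, spatial size 3x5x7.
x = torch.randn(2, 4, 3, 5, 7)
shape = x.shape

# Old path (removed by this commit): einops Rearrange layers.
flat_old = Rearrange("b c h w d -> b (h w d) c")(x)
back_old = Rearrange("b (h w d) c -> b c h w d", h=3, w=5, d=7)(flat_old)

# New path (introduced by this commit): reshape + transpose.
flat_new = x.reshape(*shape[:2], -1).transpose(1, 2)
back_new = flat_new.transpose(1, 2).reshape(shape)

# Both paths flatten the spatial dims in the same row-major order,
# so the results match elementwise.
assert torch.equal(flat_old, flat_new)
assert torch.equal(back_old, back_new) and torch.equal(back_new, x)
```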
1 parent 35b3894 commit 052dbb4

File tree

1 file changed: +3 −20 lines changed

monai/networks/blocks/spatialattention.py

Lines changed: 3 additions & 20 deletions

```diff
@@ -17,9 +17,6 @@
 import torch.nn as nn
 
 from monai.networks.blocks import SABlock
-from monai.utils import optional_import
-
-Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 
 
 class SpatialAttentionBlock(nn.Module):
@@ -74,24 +71,10 @@ def __init__(
 
     def forward(self, x: torch.Tensor):
         residual = x
-
-        if self.spatial_dims == 1:
-            h = x.shape[2]
-            rearrange_input = Rearrange("b c h -> b h c")
-            rearrange_output = Rearrange("b h c -> b c h", h=h)
-        if self.spatial_dims == 2:
-            h, w = x.shape[2], x.shape[3]
-            rearrange_input = Rearrange("b c h w -> b (h w) c")
-            rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w)
-        else:
-            h, w, d = x.shape[2], x.shape[3], x.shape[4]
-            rearrange_input = Rearrange("b c h w d -> b (h w d) c")
-            rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d)
-
+        shape = x.shape
         x = self.norm(x)
-        x = rearrange_input(x)  # B x (x_dim * y_dim [ * z_dim]) x C
-
+        x = x.reshape(*shape[:2], -1).transpose(1, 2)  # "b c h w d -> b (h w d) c"
         x = self.attn(x)
-        x = rearrange_output(x)  # B x x C x x_dim * y_dim * [z_dim]
+        x = x.transpose(1, 2).reshape(shape)  # "b (h w d) c -> b c h w d"
         x = x + residual
         return x
```
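For context on the export motivation, here is a minimal sketch, assuming PyTorch 2.x with `torch.export`; `FlattenAttend` is a hypothetical toy module mirroring the new `forward()` wiring, not MONAI's actual `SpatialAttentionBlock`. It shows the reshape/transpose pattern tracing through export without any einops layers in the graph.

```python
import torch
from torch import nn
from torch.export import export


class FlattenAttend(nn.Module):
    """Hypothetical toy module mirroring the commit's new forward() wiring."""

    def __init__(self, channels: int = 4) -> None:
        super().__init__()
        self.norm = nn.GroupNorm(1, channels)
        self.attn = nn.MultiheadAttention(channels, num_heads=1, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        shape = x.shape
        x = self.norm(x)
        x = x.reshape(*shape[:2], -1).transpose(1, 2)  # "b c h w d -> b (h w d) c"
        x, _ = self.attn(x, x, x)
        x = x.transpose(1, 2).reshape(shape)  # "b (h w d) c -> b c h w d"
        return x + residual


# Export the toy module on an illustrative 3D input (batch=2, channels=4).
exported = export(FlattenAttend(), (torch.randn(2, 4, 3, 5, 7),))
print(exported)
```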
