Commit 4249bd2

[compile] workaround for compile error FakeTensor no op for builtin.- (#2891)

1 parent: 10c31c0

File tree: 1 file changed (+8 −4 lines)

torchtune/models/llama4/_position_embeddings.py (8 additions, 4 deletions)
```diff
@@ -179,10 +179,14 @@ def forward(
         # tensor has shape [b, s, n_h, h_d // 2, 2]
         x_out = torch.stack(
             [
-                xshaped[..., 0] * rope_cache[..., 0]
-                - xshaped[..., 1] * rope_cache[..., 1],
-                xshaped[..., 1] * rope_cache[..., 0]
-                + xshaped[..., 0] * rope_cache[..., 1],
+                torch.sub(
+                    xshaped[..., 0] * rope_cache[..., 0],
+                    xshaped[..., 1] * rope_cache[..., 1],
+                ),
+                torch.add(
+                    xshaped[..., 1] * rope_cache[..., 0],
+                    xshaped[..., 0] * rope_cache[..., 1],
+                ),
             ],
             -1,
         )
```
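The two stacked expressions in the hunk are the real and imaginary parts of a complex multiply, (x0 + i·x1) · (c0 + i·c1), which is how rotary position embeddings rotate each feature pair; the patch only swaps the Python `-`/`+` operators for `torch.sub`/`torch.add` without changing the math. A minimal scalar sketch of that identity (plain floats standing in for the tensors; `rotate_pair` is an illustrative helper, not torchtune code):

```python
import cmath
import math

def rotate_pair(x0, x1, c0, c1):
    """Rotate the pair (x0, x1) by the cached angle (c0, c1) = (cos, sin).

    Mirrors the two stacked expressions in the diff: the patch computes
    `real` via torch.sub and `imag` via torch.add instead of `-` / `+`.
    """
    real = x0 * c0 - x1 * c1
    imag = x1 * c0 + x0 * c1
    return real, imag

# Rotating (1, 0) by 90 degrees should give approximately (0, 1).
theta = math.pi / 2
out = rotate_pair(1.0, 0.0, math.cos(theta), math.sin(theta))

# Cross-check against Python's built-in complex arithmetic.
ref = (1.0 + 0.0j) * cmath.exp(1j * theta)
```

Because the rewrite is purely syntactic, outputs are bit-identical to the original expressions; it only sidesteps the FakeTensor tracing error that `torch.compile` hit on the builtin operator form.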
