
Commit c94c0bd

Add the comment for axis order change
1 parent: c5c149c

File tree: 1 file changed, +2 −0


examples/qualcomm/oss_scripts/llama/model/static_llama.py

@@ -106,6 +106,8 @@ def forward_sha(
         v_caches: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         bsz, seq_len, _ = hidden_states.shape
+        # In the HTP backend, the input axis order for the convolution operation is
+        # more efficient with [1, 1, seq_len, dim] compared to [1, seq_len, 1, dim].
         hidden_states = torch.reshape(
             hidden_states, (bsz, seq_len, 1, self.dim)
         ).transpose(1, 3)
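
For reference, a minimal standalone sketch of the reshape/transpose that the new comment documents. The tensor sizes (bsz=1, seq_len=16, dim=64) are illustrative assumptions, not values from the commit; only the shape arithmetic mirrors the two lines above.

    import torch

    # Illustrative sizes (assumed for this sketch only).
    bsz, seq_len, dim = 1, 16, 64
    hidden_states = torch.randn(bsz, seq_len, dim)

    # Same transform as in forward_sha: insert a singleton axis, then swap
    # axes 1 and 3 so the sequence axis ends up last.
    reshaped = torch.reshape(hidden_states, (bsz, seq_len, 1, dim)).transpose(1, 3)

    print(tuple(reshaped.shape))  # (1, 64, 1, 16) == (bsz, dim, 1, seq_len)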
