Commit 8b3346c

add kwargs to _scaled_dot_product_attention__tensorrt (#2332)
1 parent 8a59bae commit 8b3346c

File tree

1 file changed: 6 additions, 7 deletions

mmdeploy/pytorch/functions/multi_head_attention_forward.py
@@ -45,12 +45,11 @@ def symbolic(g, q, k, v, mask):
 @FUNCTION_REWRITER.register_rewriter(
     func_name='torch.nn.functional._scaled_dot_product_attention',
     backend=Backend.TENSORRT.value)
-def _scaled_dot_product_attention__tensorrt(
-        q: Tensor,
-        k: Tensor,
-        v: Tensor,
-        attn_mask: Optional[Tensor] = None,
-        dropout_p: float = 0.0,
-) -> Tuple[Tensor, Tensor]:
+def _scaled_dot_product_attention__tensorrt(q: Tensor,
+                                            k: Tensor,
+                                            v: Tensor,
+                                            attn_mask: Optional[Tensor] = None,
+                                            dropout_p: float = 0.0,
+                                            **kwargs) -> Tuple[Tensor, Tensor]:
     """Rewrite for custom ops."""
     return ScaledDotProductAttentionTRT.apply(q, k, v, attn_mask)
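A minimal sketch of why the `**kwargs` parameter matters here: a rewriter with a fixed signature raises `TypeError` if the rewritten call site passes any keyword argument it does not declare, while a `**kwargs` catch-all absorbs them. The function names below and the `is_causal` keyword are illustrative assumptions, not mmdeploy code.

```python
# Hypothetical stand-ins for a rewriter function, before and after the fix.

def strict_rewrite(q, k, v, attn_mask=None, dropout_p=0.0):
    # Fixed signature: any undeclared keyword argument raises TypeError.
    return (q, k, v, attn_mask)


def tolerant_rewrite(q, k, v, attn_mask=None, dropout_p=0.0, **kwargs):
    # **kwargs silently absorbs extra keywords (e.g. a hypothetical
    # is_causal flag added by a newer caller) without changing behavior.
    return (q, k, v, attn_mask)


try:
    strict_rewrite(1, 2, 3, is_causal=True)
    strict_ok = True
except TypeError:
    strict_ok = False  # the strict version rejects the extra keyword

result = tolerant_rewrite(1, 2, 3, is_causal=True)  # accepted, flag ignored
```

The trade-off of the catch-all is that misspelled keyword arguments are also silently ignored instead of raising, so it is best reserved for adapter layers like this one that must stay compatible with evolving upstream signatures.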

0 commit comments
