
Commit 0c21b16

Fix conversion with rtmdet-inst, vit, conformer (#2453)
* fix
* fix scaled_dot_product_attention
1 parent 01a88be

2 files changed: +36 −6 lines changed


mmdeploy/mmcv/ops/nms.py

Lines changed: 10 additions & 6 deletions
@@ -610,19 +610,19 @@ def multiclass_nms__torchscript(boxes: Tensor,
 
     Use batched_nms from torchvision instead of custom nms.
     """
-    assert not output_index, 'output_index is not supported on this backend.'
     # TODO: simplify inference for non-batch model
     from torchvision.ops import batched_nms
     batch_size = scores.shape[0]
     num_boxes = scores.shape[1]
     num_classes = scores.shape[2]
     box_per_cls = len(boxes.shape) == 4
     scores = torch.where(scores > score_threshold, scores, scores.new_zeros(1))
-
+    pre_topk_inds = None
     # pre-topk
     if pre_top_k > 0:
         max_scores, _ = scores.max(-1)
         _, topk_inds = max_scores.topk(pre_top_k)
+        pre_topk_inds = topk_inds
         batch_inds = torch.arange(batch_size).view(-1, 1).long()
         boxes = boxes[batch_inds, topk_inds, ...]
         scores = scores[batch_inds, topk_inds, :]
@@ -646,10 +646,14 @@ def multiclass_nms__torchscript(boxes: Tensor,
 
     keeps = torch.cat(keeps)
     scores = scores.permute(0, 2, 1)
-    dets, labels = _select_nms_index(
-        scores, boxes, keeps, batch_size, keep_top_k=keep_top_k)
-
-    return dets, labels
+    return _select_nms_index(
+        scores,
+        boxes,
+        keeps,
+        batch_size,
+        keep_top_k=keep_top_k,
+        pre_inds=pre_topk_inds,
+        output_index=output_index)
 
 
 class AscendBatchNMSOp(torch.autograd.Function):
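
Dropping the assert makes `output_index` usable on the TorchScript backend, which is presumably what the rtmdet-inst conversion needs: `batched_nms` keep indices point into the top-k subset, so `pre_topk_inds` must be threaded through for `_select_nms_index` to map them back to the original box ordering. A minimal standalone sketch of that remapping (toy values, not mmdeploy internals):

    import torch

    # Five candidate boxes; pre-topk keeps the best three.
    scores = torch.tensor([0.1, 0.9, 0.8, 0.2, 0.7])
    _, pre_topk_inds = scores.topk(3)      # tensor([1, 2, 4]): original positions

    # Suppose NMS then keeps the 1st and 3rd entries of the subset.
    keeps = torch.tensor([0, 2])           # indices into the top-k SUBSET

    # Mapping through pre_topk_inds recovers indices into the original
    # boxes, which is what an output_index consumer expects.
    original_inds = pre_topk_inds[keeps]   # tensor([1, 4])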

mmdeploy/pytorch/functions/multi_head_attention_forward.py

Lines changed: 26 additions & 0 deletions
@@ -53,3 +53,29 @@ def _scaled_dot_product_attention__tensorrt(q: Tensor,
                                             **kwargs) -> Tuple[Tensor, Tensor]:
     """Rewrite for custom ops."""
     return ScaledDotProductAttentionTRT.apply(q, k, v, attn_mask)
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='torch.nn.functional.scaled_dot_product_attention',
+    backend=Backend.DEFAULT.value)
+def scaled_dot_product_attention__default(query,
+                                          key,
+                                          value,
+                                          attn_mask=None,
+                                          dropout_p=0.,
+                                          scale=None,
+                                          is_causal=False):
+    """Rewrite to export to onnx on torch>=2.0.0."""
+    scale = scale or query.size(-1)**0.5
+    if is_causal and attn_mask is not None:
+        attn_mask = torch.ones(
+            query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0)
+    if attn_mask is not None and attn_mask.dtype == torch.bool:
+        attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf'))
+
+    attn_weight = query @ key.transpose(-2, -1) / scale
+    if attn_mask is not None:
+        attn_weight += attn_mask
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, True)
+    return attn_weight @ value
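
On torch>=2.0, tracing `F.scaled_dot_product_attention` emits a fused aten op that ONNX export cannot lower, so the default-backend rewriter above decomposes it into plain matmul/softmax primitives (this is what fixes the vit and conformer conversions). A quick parity check, as a sketch assuming no mask and zero dropout, not part of the commit:

    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 4, 8, 16)   # (batch, heads, seq_len, head_dim)
    k = torch.randn(2, 4, 8, 16)
    v = torch.randn(2, 4, 8, 16)

    ref = F.scaled_dot_product_attention(q, k, v)   # fused reference (torch>=2.0)

    # Decomposed form, mirroring the default-backend rewrite above.
    attn = torch.softmax(q @ k.transpose(-2, -1) / q.size(-1) ** 0.5, dim=-1)
    out = attn @ v

    assert torch.allclose(ref, out, atol=1e-5)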
