Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mmdeploy/codebase/mmcls/models/classifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ def base_classifier__forward(
if self.head is not None:
output = self.head(output)

from mmcls.models.heads import MultiLabelClsHead
from mmcls.models.heads import ConformerHead, MultiLabelClsHead
if isinstance(self.head, MultiLabelClsHead):
output = torch.sigmoid(output)
elif isinstance(self.head, ConformerHead):
output = F.softmax(torch.add(output[0], output[1]), dim=1)
else:
output = F.softmax(output, dim=1)
return output