[Fix] Fix bc breaking from mmdet detr-refactor branch #825

Open · wants to merge 1 commit into base: 1.x
@@ -27,51 +27,44 @@
out_channels=256,
kernel_size=1,
act_cfg=None),
encoder=dict(
num_layers=6,
layer_cfg=dict( # DetrTransformerEncoderLayer
self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.1,
batch_first=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type='ReLU', inplace=True)))),
decoder=dict(
num_layers=6,
layer_cfg=dict( # DetrTransformerDecoderLayer
self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.1,
batch_first=True),
cross_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.1,
batch_first=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type='ReLU', inplace=True))),
return_intermediate=False),
num_queries=1,
positional_encoding=dict(num_feats=128, normalize=True),
head=dict(
type='StarkHead',
num_querys=1,
transformer=dict(
type='StarkTransformer',
encoder=dict(
type='mmdet.DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
attn_drop=0.1,
dropout_layer=dict(type='Dropout', drop_prob=0.1))
],
ffn_cfgs=dict(
feedforward_channels=2048,
embed_dims=256,
ffn_drop=0.1),
operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
decoder=dict(
type='mmdet.DetrTransformerDecoder',
return_intermediate=False,
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
attn_drop=0.1,
dropout_layer=dict(type='Dropout', drop_prob=0.1)),
ffn_cfgs=dict(
feedforward_channels=2048,
embed_dims=256,
ffn_drop=0.1),
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm'))),
),
positional_encoding=dict(
type='mmdet.SinePositionalEncoding', num_feats=128,
normalize=True),
bbox_head=dict(
type='CornerPredictorHead',
inplanes=256,
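
For reference, the refactored mmdet 3.x layer classes consume this flat config directly; the updated stark.py further down in this diff builds them via DetrTransformerEncoder(**encoder_cfg) and DetrTransformerDecoder(**decoder_cfg). A minimal construction sketch, assuming mmdet 3.x (the detr-refactor API) is installed:

from mmdet.models.layers import DetrTransformerEncoder

encoder = DetrTransformerEncoder(
    num_layers=6,
    layer_cfg=dict(
        self_attn_cfg=dict(
            embed_dims=256, num_heads=8, dropout=0.1, batch_first=True),
        ffn_cfg=dict(
            embed_dims=256,
            feedforward_channels=2048,
            num_fcs=2,
            ffn_drop=0.1,
            act_cfg=dict(type='ReLU', inplace=True))))
# embed_dims is propagated outwards from self_attn_cfg, which is what the
# new Stark._init_layers relies on (self.encoder.embed_dims == 256).
print(encoder.embed_dims)  # 256
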
@@ -7,6 +7,7 @@
checkpoint='torchvision://resnet101')),
init_cfg=dict(
type='Pretrained',
checkpoint='https://download.openmmlab.com/mmdetection/v2.0/'
'mask2former/mask2former_r101_lsj_8x2_50e_coco/'
'mask2former_r101_lsj_8x2_50e_coco_20220426_100250-c50b6fa6.pth'))
checkpoint='https://download.openmmlab.com/mmdetection/v3.0/'
'mask2former/mask2former_r101_8xb2-lsj-50e_coco-panoptic'
'/mask2former_r101_8xb2-lsj-50e_coco-'
'panoptic_20220329_225104-c74d4d71.pth'))
@@ -7,6 +7,7 @@
checkpoint='torchvision://resnet101')),
init_cfg=dict(
type='Pretrained',
checkpoint='https://download.openmmlab.com/mmdetection/v2.0/'
'mask2former/mask2former_r101_lsj_8x2_50e_coco/'
'mask2former_r101_lsj_8x2_50e_coco_20220426_100250-c50b6fa6.pth'))
checkpoint='https://download.openmmlab.com/mmdetection/v3.0/'
'mask2former/mask2former_r101_8xb2-lsj-50e_coco-panoptic'
'/mask2former_r101_8xb2-lsj-50e_'
'coco-panoptic_20220329_225104-c74d4d71.pth'))
60 changes: 22 additions & 38 deletions configs/vis/mask2former/mask2former_r50_8xb2-8e_youtubevis2019.py
@@ -39,64 +39,47 @@
num_outs=3,
norm_cfg=dict(type='GN', num_groups=32),
act_cfg=dict(type='ReLU'),
encoder=dict(
type='mmdet.DetrTransformerEncoder',
encoder=dict( # DeformableDetrTransformerEncoder
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiScaleDeformableAttention',
layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
self_attn_cfg=dict( # MultiScaleDeformableAttention
embed_dims=256,
num_heads=8,
num_levels=3,
num_points=4,
im2col_step=128,
dropout=0.0,
batch_first=False,
norm_cfg=None,
init_cfg=None),
ffn_cfgs=dict(
type='FFN',
batch_first=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
ffn_drop=0.0,
act_cfg=dict(type='ReLU', inplace=True)),
operation_order=('self_attn', 'norm', 'ffn', 'norm')),
init_cfg=None),
positional_encoding=dict(
type='mmdet.SinePositionalEncoding',
num_feats=128,
normalize=True),
act_cfg=dict(type='ReLU', inplace=True)))),
positional_encoding=dict(num_feats=128, normalize=True),
init_cfg=None),
enforce_decoder_input_project=False,
positional_encoding=dict(
type='SinePositionalEncoding3D', num_feats=128, normalize=True),
transformer_decoder=dict(
type='mmdet.DetrTransformerDecoder',
return_intermediate=True,
num_layers=9,
transformerlayers=dict(
type='mmdet.DetrTransformerDecoderLayer',
attn_cfgs=dict(
type='MultiheadAttention',
layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
attn_drop=0.0,
proj_drop=0.0,
dropout_layer=None,
batch_first=False),
ffn_cfgs=dict(
dropout=0.0,
batch_first=True),
cross_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.0,
batch_first=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.0,
dropout_layer=None,
add_identity=True),
feedforward_channels=2048,
operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
'ffn', 'norm')),
act_cfg=dict(type='ReLU', inplace=True))),
init_cfg=None),
loss_cls=dict(
type='mmdet.CrossEntropyLoss',
@@ -138,9 +121,10 @@
sampler=dict(type='mmdet.MaskPseudoSampler'))),
init_cfg=dict(
type='Pretrained',
checkpoint='https://download.openmmlab.com/mmdetection/v2.0/'
'mask2former/mask2former_r50_lsj_8x2_50e_coco/'
'mask2former_r50_lsj_8x2_50e_coco_20220506_191028-8e96e88b.pth'))
checkpoint='https://download.openmmlab.com/mmdetection/v3.0/'
'mask2former/mask2former_r50_8xb2-lsj-50e_coco-panoptic/'
'mask2former_r50_8xb2-lsj-50e_'
'coco-panoptic_20230114_094547-7add5fa8.pth'))

# optimizer
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
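
The same pattern applies to the pixel decoder encoder and the transformer decoder above: the registry-built BaseTransformerLayer config (type / attn_cfgs / ffn_cfgs / operation_order) collapses into a plain layer_cfg, and the operation order is now fixed inside the refactored layer classes (self_attn -> norm -> ffn -> norm for encoder layers). A rough, purely illustrative helper for the encoder-side key mapping (hypothetical, not part of this PR; it keeps whatever keys were present rather than pruning defaults the new config dropped, e.g. im2col_step, norm_cfg, init_cfg, and it does not flip batch_first):

def old_encoder_cfg_to_new(old: dict) -> dict:
    """Translate an old-style (mmdet 2.x) DETR encoder config into the
    refactored layer_cfg style, keeping only the renamed keys."""
    layers = old['transformerlayers']
    attn = layers['attn_cfgs']
    if isinstance(attn, (list, tuple)):  # old configs sometimes use a list
        attn = attn[0]
    attn = {k: v for k, v in attn.items() if k != 'type'}
    ffn = {k: v for k, v in layers['ffn_cfgs'].items() if k != 'type'}
    # 'operation_order' disappears: the refactored encoder layer hard-codes
    # self_attn -> norm -> ffn -> norm.
    return dict(
        num_layers=old['num_layers'],
        layer_cfg=dict(self_attn_cfg=attn, ffn_cfg=ffn))
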
@@ -29,10 +29,11 @@
init_cfg=dict(
type='Pretrained',
checkpoint= # noqa: E251
'https://download.openmmlab.com/mmdetection/v2.0/mask2former/'
'mask2former_swin-l-p4-w12-384-in21k_lsj_16x1_100e_coco-panoptic/'
'mask2former_swin-l-p4-w12-384-in21k_lsj_16x1_100e_coco-panoptic_'
'20220407_104949-d4919c44.pth'))
'https://download.openmmlab.com/mmdetection/v3.0'
'/mask2former/mask2former_swin-l-p4-w12-384-'
'in21k_16xb1-lsj-100e_coco-panoptic'
'/mask2former_swin-l-p4-w12-384-in21k_16xb1-lsj-100e'
'_coco-panoptic_20220407_104949-82f8d28d.pth'))

# set all layers in backbone to lr_mult=0.1
# set all norm layers, position_embeding,
117 changes: 114 additions & 3 deletions mmtrack/models/sot/stark.py
@@ -5,8 +5,11 @@

import torch
import torch.nn.functional as F
from mmdet.models.layers import (DetrTransformerDecoder,
DetrTransformerEncoder,
SinePositionalEncoding)
from mmdet.structures.bbox.transforms import bbox_xyxy_to_cxcywh
from torch import Tensor
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.conv import _ConvNd

@@ -43,7 +46,11 @@ class Stark(BaseSingleObjectTracker):
def __init__(self,
backbone: dict,
neck: Optional[dict] = None,
head: Optional[dict] = None,
encoder: OptConfigType = None,
decoder: OptConfigType = None,
head: OptConfigType = None,
positional_encoding: OptConfigType = None,
num_queries: int = 100,
pretrains: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
@@ -71,6 +78,29 @@ def __init__(self,
if frozen_modules is not None:
self.freeze_module(frozen_modules)

self.encoder = encoder
self.decoder = decoder
self.positional_encoding = positional_encoding
self.num_queries = num_queries
self._init_layers()

def _init_layers(self) -> None:
"""Initialize layers except for backbone, neck and bbox_head."""
self.positional_encoding = SinePositionalEncoding(
**self.positional_encoding)
self.encoder = DetrTransformerEncoder(**self.encoder)
self.decoder = DetrTransformerDecoder(**self.decoder)
self.embed_dims = self.encoder.embed_dims
# NOTE The embed_dims is typically passed from the inside out.
# For example in DETR, the embed_dims is passed as
# self_attn -> the first encoder layer -> encoder -> detector.
self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)

num_feats = self.positional_encoding.num_feats
assert num_feats * 2 == self.embed_dims, \
'embed_dims should be exactly 2 times of num_feats. ' \
f'Found {self.embed_dims} and {num_feats}.'
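# Hedged sanity check for the assertion above (illustrative comment, not
# part of the patch): SinePositionalEncoding emits num_feats sine/cosine
# channels per spatial axis, so its output has 2 * num_feats channels and
# must match embed_dims. With the config above:
#     pos_enc = SinePositionalEncoding(num_feats=128, normalize=True)
#     mask = torch.zeros(2, 20, 20, dtype=torch.bool)  # (bs, h, w)
#     pos_enc(mask).shape  # torch.Size([2, 256, 20, 20])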

def init_weights(self):
"""Initialize the weights of modules in single object tracker."""
# We don't use the `init_weights()` function in BaseModule, since it
@@ -87,6 +117,11 @@ def init_weights(self):
if self.with_head:
self.head.init_weights()

for coder in self.encoder, self.decoder:
for p in coder.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
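# The loop above mirrors DETR-style initialization: every multi-dimensional
# (weight) parameter of the encoder and decoder is Xavier-uniform
# initialized, while biases and norm parameters keep their defaults.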

def extract_feat(self, img: Tensor) -> Tensor:
"""Extract the features of the input image.

@@ -300,6 +335,82 @@ def loss(self, inputs: dict, data_samples: List[TrackDataSample],
x_dict = dict(feat=x_feat, mask=search_padding_mask[:, 0])
head_inputs.append(x_dict)

losses = self.head.loss(head_inputs, data_samples)
outs_dec, enc_mem = self.forward_transformer(head_inputs)
losses = self.head.loss(head_inputs, outs_dec, enc_mem, data_samples)

return losses

def forward_transformer(self, inputs):
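"""Run the transformer encoder and decoder on the merged template and
search inputs.

Args:
    inputs (list[dict]): Template and search inputs; each dict holds
        'feat' (multi-level features, only the first level is used)
        and 'mask' (padding mask).

Returns:
    tuple(Tensor, Tensor): outs_dec with shape (1, bs, num_query, c)
        and enc_mem with shape (feats_flatten_len, bs, c).
"""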
# 1. preprocess inputs for transformer
all_inputs = []
for input in inputs:
feat = input['feat'][0]
feat_size = feat.shape[-2:]
mask = F.interpolate(
input['mask'][None].float(), size=feat_size).to(torch.bool)[0]
pos_embed = self.positional_encoding(mask)
all_inputs.append(dict(feat=feat, mask=mask, pos_embed=pos_embed))
all_inputs = self.head._merge_template_search(all_inputs)

# 2. forward transformer head
# outs_dec is in (1, bs, num_query, c) shape
# enc_mem is in (feats_flatten_len, bs, c) shape
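# NOTE (assumption: StarkHead._merge_template_search keeps its mmtrack 0.x
# contract): the merged 'feat'/'pos_embed' are (feats_flatten_len, bs, c)
# and 'mask' is (bs, feats_flatten_len), so the permutes below hand
# batch-first (bs, feats_flatten_len, c) tensors to the batch_first=True
# encoder and decoder.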
outs_dec, enc_mem = self.transformer(
all_inputs['feat'].permute(1, 0, 2), all_inputs['mask'],
self.query_embedding.weight,
all_inputs['pos_embed'].permute(1, 0, 2))
return outs_dec, enc_mem

def transformer(self, x: Tensor, mask: Tensor, query_embed: Tensor,
pos_embed: Tensor) -> Tuple[Tensor, Tensor]:
"""Forward function for `StarkTransformer`.

The difference from the transformer module in `MMCV` is the input shape:
the sizes of the template and search feature maps are different, so we
must flatten and concatenate them outside this module, whereas `MMCV`
flattens the input features inside the transformer module.

Args:
x (Tensor): Input features with shape (bs, feats_flatten_len, c),
where c = embed_dims. The refactored encoder and decoder are
batch-first, so the batch dimension comes first.
mask (Tensor): The key_padding_mask used for encoder and decoder,
with shape (bs, feats_flatten_len).
query_embed (Tensor): The query embedding for decoder, with shape
(num_query, c).
pos_embed (Tensor): The positional encoding for encoder and
decoder, with shape (bs, feats_flatten_len, c).

Here, 'feats_flatten_len' = z_feat_h*z_feat_w*2 + \
x_feat_h*x_feat_w.
'z_feat_h' and 'z_feat_w' denote the height and width of the
template features respectively.
'x_feat_h' and 'x_feat_w' denote the height and width of search
features respectively.
Returns:
tuple[Tensor, Tensor]: results of decoder containing the following
tensors.
- out_dec: Output from decoder. If return_intermediate \
is True, output has shape [num_dec_layers, bs, \
num_query, embed_dims], else has shape [1, bs, \
num_query, embed_dims].
Here, return_intermediate=False.
- enc_mem: Output results from encoder, with shape \
(feats_flatten_len, bs, embed_dims).
"""
bs, _, _ = x.shape
query_embed = query_embed.unsqueeze(0).repeat(
bs, 1, 1)  # (num_query, embed_dims) -> (bs, num_query, embed_dims)

enc_mem = self.encoder(
query=x, query_pos=pos_embed, key_padding_mask=mask)
target = torch.zeros_like(query_embed)
# out_dec: (1, bs, num_query, embed_dims), since return_intermediate=False
out_dec = self.decoder(
query=target,
key=enc_mem,
value=enc_mem,
key_pos=pos_embed,
query_pos=query_embed,
key_padding_mask=mask)
enc_mem = enc_mem.permute(1, 0, 2)
return out_dec, enc_mem
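
A minimal shape check for the new transformer() method (hypothetical snippet; it assumes a Stark instance named `model` built from the SOT config at the top of this diff, i.e. embed_dims=256 and num_queries=1):

import torch

feats_flatten_len, bs = 400, 2
x = torch.randn(bs, feats_flatten_len, 256)                  # batch-first features
mask = torch.zeros(bs, feats_flatten_len, dtype=torch.bool)  # no padded positions
pos_embed = torch.randn(bs, feats_flatten_len, 256)
outs_dec, enc_mem = model.transformer(
    x, mask, model.query_embedding.weight, pos_embed)
print(outs_dec.shape)  # (1, 2, 1, 256): return_intermediate=False
print(enc_mem.shape)   # (400, 2, 256)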