
Commit e8aeb80

[transformer] add Deformable DETR base code (PaddlePaddle#3718)
1 parent 283f5ac commit e8aeb80

File tree: 6 files changed, +667 -12 lines changed

ppdet/modeling/heads/detr_head.py

Lines changed: 77 additions & 2 deletions
@@ -21,9 +21,10 @@
 import paddle.nn.functional as F
 from ppdet.core.workspace import register
 import pycocotools.mask as mask_util
-from ..initializer import linear_init_
+from ..initializer import linear_init_, constant_
+from ..transformers.utils import inverse_sigmoid
 
-__all__ = ['DETRHead']
+__all__ = ['DETRHead', 'DeformableDETRHead']
 
 
 class MLP(nn.Layer):
@@ -275,3 +276,77 @@ def forward(self, out_transformer, body_feats, inputs=None):
                 gt_mask=gt_mask)
         else:
             return (outputs_bbox[-1], outputs_logit[-1], outputs_seg)
+
+
+@register
+class DeformableDETRHead(nn.Layer):
+    __shared__ = ['num_classes', 'hidden_dim']
+    __inject__ = ['loss']
+
+    def __init__(self,
+                 num_classes=80,
+                 hidden_dim=512,
+                 nhead=8,
+                 num_mlp_layers=3,
+                 loss='DETRLoss'):
+        super(DeformableDETRHead, self).__init__()
+        self.num_classes = num_classes
+        self.hidden_dim = hidden_dim
+        self.nhead = nhead
+        self.loss = loss
+
+        self.score_head = nn.Linear(hidden_dim, self.num_classes)
+        self.bbox_head = MLP(hidden_dim,
+                             hidden_dim,
+                             output_dim=4,
+                             num_layers=num_mlp_layers)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        linear_init_(self.score_head)
+        constant_(self.score_head.bias, -4.595)
+        constant_(self.bbox_head.layers[-1].weight)
+        bias = paddle.zeros_like(self.bbox_head.layers[-1].bias)
+        bias[2:] = -2.0
+        self.bbox_head.layers[-1].bias.set_value(bias)
+
+    @classmethod
+    def from_config(cls, cfg, hidden_dim, nhead, input_shape):
+        return {'hidden_dim': hidden_dim, 'nhead': nhead}
+
+    def forward(self, out_transformer, body_feats, inputs=None):
+        r"""
+        Args:
+            out_transformer (Tuple): (feats: [num_levels, batch_size,
+                                        num_queries, hidden_dim],
+                                      memory: [batch_size,
+                                        \sum_{l=0}^{L-1} H_l \cdot W_l, hidden_dim],
+                                      reference_points: [batch_size, num_queries, 2])
+            body_feats (List(Tensor)): list[[B, C, H, W]]
+            inputs (dict): dict(inputs)
+        """
+        feats, memory, reference_points = out_transformer
+        reference_points = inverse_sigmoid(reference_points.unsqueeze(0))
+        outputs_bbox = self.bbox_head(feats)
+
+        # It's equivalent to "outputs_bbox[:, :, :, :2] += reference_points",
+        # but the gradient is wrong in paddle.
+        outputs_bbox = paddle.concat(
+            [
+                outputs_bbox[:, :, :, :2] + reference_points,
+                outputs_bbox[:, :, :, 2:]
+            ],
+            axis=-1)
+
+        outputs_bbox = F.sigmoid(outputs_bbox)
+        outputs_logit = self.score_head(feats)
+
+        if self.training:
+            assert inputs is not None
+            assert 'gt_bbox' in inputs and 'gt_class' in inputs
+
+            return self.loss(outputs_bbox, outputs_logit, inputs['gt_bbox'],
+                             inputs['gt_class'])
+        else:
+            return (outputs_bbox[-1], outputs_logit[-1], None)
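
For context on two of the initializations and the box refinement above: the score-head bias of -4.595 is the usual focal-loss prior, roughly -log((1 - 0.01) / 0.01), and the box head adds its first two outputs to the reference points in inverse-sigmoid (logit) space before mapping back with a sigmoid. Below is a minimal sketch of that arithmetic; the clipped `inverse_sigmoid` shown here is an assumption about what `ppdet.modeling.transformers.utils.inverse_sigmoid` computes, and `eps` is an illustrative value.

```python
import math

import paddle
import paddle.nn.functional as F


def inverse_sigmoid(x, eps=1e-6):
    # Assumed implementation: the logit of x, clipped for numerical stability.
    x = x.clip(min=0., max=1.)
    return paddle.log(x.clip(min=eps) / (1 - x).clip(min=eps))


# The score-head bias of -4.595 matches a focal-loss prior probability of ~0.01,
# i.e. -log((1 - pi) / pi) with pi = 0.01, so initial class scores start low.
prior_prob = 0.01
print(-math.log((1 - prior_prob) / prior_prob))  # ~ -4.595

# Box refinement as in DeformableDETRHead.forward: the MLP's first two outputs
# act as offsets added to the reference point in logit space, and a sigmoid
# maps the result back to normalized [0, 1] coordinates.
ref = paddle.to_tensor([[0.25, 0.60]])      # reference point (cx, cy)
offset = paddle.to_tensor([[0.10, -0.05]])  # predicted center offsets
print(F.sigmoid(offset + inverse_sigmoid(ref)))
```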

ppdet/modeling/post_process.py

Lines changed: 16 additions & 5 deletions
@@ -532,12 +532,23 @@ def __call__(self, head_out, im_shape, scale_factor):
 
         scores = F.sigmoid(logits) if self.use_focal_loss else F.softmax(
             logits)[:, :, :-1]
-        scores, labels = scores.max(-1), scores.argmax(-1)
 
-        if scores.shape[1] > self.num_top_queries:
-            scores, index = paddle.topk(scores, self.num_top_queries, axis=-1)
-            labels = paddle.stack(
-                [paddle.gather(l, i) for l, i in zip(labels, index)])
+        if not self.use_focal_loss:
+            scores, labels = scores.max(-1), scores.argmax(-1)
+            if scores.shape[1] > self.num_top_queries:
+                scores, index = paddle.topk(
+                    scores, self.num_top_queries, axis=-1)
+                labels = paddle.stack(
+                    [paddle.gather(l, i) for l, i in zip(labels, index)])
+                bbox_pred = paddle.stack(
+                    [paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
+        else:
+            scores, index = paddle.topk(
+                scores.reshape([logits.shape[0], -1]),
+                self.num_top_queries,
+                axis=-1)
+            labels = index % logits.shape[2]
+            index = index // logits.shape[2]
         bbox_pred = paddle.stack(
             [paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
 
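
The focal-loss branch added above selects the top queries jointly over classes: scores are flattened to shape [batch, num_queries * num_classes], a single top-k is taken, and the class label and query index are recovered with modulo and integer division. A small standalone illustration of that indexing (shapes and variable names here are toy values, not from the commit):

```python
import paddle

batch_size, num_queries, num_classes, num_top_queries = 1, 4, 3, 2
scores = paddle.rand([batch_size, num_queries, num_classes])

# Flatten the (query, class) axes and take a single joint top-k.
topk_scores, topk_index = paddle.topk(
    scores.reshape([batch_size, -1]), num_top_queries, axis=-1)

# Recover which class and which query each selected score came from.
labels = topk_index % num_classes        # class id
query_index = topk_index // num_classes  # query id, used to gather boxes
print(topk_scores, labels, query_index)
```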

ppdet/modeling/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -16,8 +16,10 @@
 from . import utils
 from . import matchers
 from . import position_encoding
+from . import deformable_transformer
 
 from .detr_transformer import *
 from .utils import *
 from .matchers import *
 from .position_encoding import *
+from .deformable_transformer import *
