@@ -21,9 +21,10 @@
 import paddle.nn.functional as F
 from ppdet.core.workspace import register
 import pycocotools.mask as mask_util
-from ..initializer import linear_init_
+from ..initializer import linear_init_, constant_
+from ..transformers.utils import inverse_sigmoid
 
-__all__ = ['DETRHead']
+__all__ = ['DETRHead', 'DeformableDETRHead']
 
 
 class MLP(nn.Layer):
@@ -275,3 +276,83 @@ def forward(self, out_transformer, body_feats, inputs=None):
                 gt_mask=gt_mask)
         else:
             return (outputs_bbox[-1], outputs_logit[-1], outputs_seg)
+
+
+@register
+class DeformableDETRHead(nn.Layer):
+    __shared__ = ['num_classes', 'hidden_dim']
+    __inject__ = ['loss']
+
+    def __init__(self,
+                 num_classes=80,
+                 hidden_dim=512,
+                 nhead=8,
+                 num_mlp_layers=3,
+                 loss='DETRLoss'):
+        super(DeformableDETRHead, self).__init__()
+        self.num_classes = num_classes
+        self.hidden_dim = hidden_dim
+        self.nhead = nhead
+        self.loss = loss
+
+        self.score_head = nn.Linear(hidden_dim, self.num_classes)
+        self.bbox_head = MLP(hidden_dim,
+                             hidden_dim,
+                             output_dim=4,
+                             num_layers=num_mlp_layers)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        linear_init_(self.score_head)
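+        # Bias of -4.595 == -log((1 - 0.01) / 0.01): classification scores
+        # start from a foreground prior of 0.01 (focal-loss-style init).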
+        constant_(self.score_head.bias, -4.595)
+        constant_(self.bbox_head.layers[-1].weight)
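+        # Zero the box-regression bias except w/h, which start at
+        # sigmoid(-2.0) ~= 0.12 of the image so initial boxes are small.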
+        bias = paddle.zeros_like(self.bbox_head.layers[-1].bias)
+        bias[2:] = -2.0
+        self.bbox_head.layers[-1].bias.set_value(bias)
+
+    @classmethod
+    def from_config(cls, cfg, hidden_dim, nhead, input_shape):
+        return {'hidden_dim': hidden_dim, 'nhead': nhead}
+
+    def forward(self, out_transformer, body_feats, inputs=None):
+        r"""
+        Args:
+            out_transformer (Tuple): (feats: [num_levels, batch_size,
+                                          num_queries, hidden_dim],
+                                      memory: [batch_size,
+                                          \sum_{l=0}^{L-1} H_l \cdot W_l,
+                                          hidden_dim],
+                                      reference_points: [batch_size,
+                                          num_queries, 2])
+            body_feats (List(Tensor)): list[[B, C, H, W]]
+            inputs (dict): dict(inputs)
+        """
+        feats, memory, reference_points = out_transformer
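+        # Lift normalized (0, 1) reference points into logit space; the
+        # unsqueeze(0) broadcasts them over the num_levels axis of feats.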
+        reference_points = inverse_sigmoid(reference_points.unsqueeze(0))
+        outputs_bbox = self.bbox_head(feats)
+
+        # Equivalent to "outputs_bbox[:, :, :, :2] += reference_points", but
+        # in-place slice assignment produces wrong gradients in Paddle, so
+        # the tensor is rebuilt with concat instead.
+        outputs_bbox = paddle.concat(
+            [
+                outputs_bbox[:, :, :, :2] + reference_points,
+                outputs_bbox[:, :, :, 2:]
+            ],
+            axis=-1)
+
+        outputs_bbox = F.sigmoid(outputs_bbox)
+        outputs_logit = self.score_head(feats)
+
+        if self.training:
+            assert inputs is not None
+            assert 'gt_bbox' in inputs and 'gt_class' in inputs
+
+            return self.loss(outputs_bbox, outputs_logit, inputs['gt_bbox'],
+                             inputs['gt_class'])
+        else:
+            return (outputs_bbox[-1], outputs_logit[-1], None)
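
Note: `inverse_sigmoid` is imported from `..transformers.utils` but its body is not part of this diff. As a sketch, it should be the usual clamped logit; the exact clipping and epsilon handling in ppdet may differ:

```python
import paddle


def inverse_sigmoid(x, eps=1e-6):
    # Clamp into [0, 1], then take the logit; the eps clips guard
    # against log(0) at the boundaries.
    x = x.clip(min=0., max=1.)
    return paddle.log(x.clip(min=eps) / (1 - x).clip(min=eps))
```

With this, the center decoding in `forward` composes to `sigmoid(delta_xy + inverse_sigmoid(ref))`: the bbox head predicts a residual in logit space around each reference point, and the final sigmoid maps the sum back to normalized image coordinates.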
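A minimal smoke test of the head in isolation, assuming the file lives at `ppdet/modeling/heads/detr_head.py` (in a real pipeline the head is built from YAML via the registry and `loss` is an injected `DETRLoss` instance rather than the default string):

```python
import paddle
from ppdet.modeling.heads.detr_head import DeformableDETRHead

head = DeformableDETRHead(num_classes=80, hidden_dim=256, nhead=8)
head.eval()  # the eval branch never touches self.loss, so the string default is fine

# Fake transformer outputs: 6 decoder levels, batch 2, 100 queries.
feats = paddle.rand([6, 2, 100, 256])
memory = paddle.rand([2, 10000, 256])  # unused by forward; shape is arbitrary
reference_points = paddle.rand([2, 100, 2])

boxes, logits, masks = head((feats, memory, reference_points), body_feats=None)
print(boxes.shape)   # [2, 100, 4]  normalized cxcywh from the last level
print(logits.shape)  # [2, 100, 80] unnormalized class scores
print(masks)         # None for the deformable head
```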