Skip to content

Commit 94ce638

Browse files
change: Address review
1 parent 414ee30 commit 94ce638

File tree

3 files changed

+13
-14
lines changed

3 files changed

+13
-14
lines changed

src/transformers/models/omdet_turbo/configuration_omdet_turbo.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ def __init__(
149149
# Init timm backbone with hardcoded values for BC
150150
timm_default_kwargs = {
151151
"out_indices": [1, 2, 3],
152-
"timm_model_kwargs": {"img_size": image_size, "always_partition": True},
152+
"img_size": image_size,
153+
"always_partition": True,
153154
}
154155
backbone_config, kwargs = consolidate_backbone_kwargs_to_config(
155156
backbone_config=backbone_config,
@@ -160,16 +161,11 @@ def __init__(
160161
**kwargs,
161162
)
162163

163-
# BC: hub configs with backbone_kwargs pass img_size/always_partition as direct attrs
164-
if getattr(backbone_config, "model_type", None) == "timm_backbone" and not getattr(
165-
backbone_config, "timm_model_kwargs", None
166-
):
167-
timm_extra = {}
164+
timm_kwargs = {}
165+
if getattr(backbone_config, "model_type", None) == "timm_backbone":
168166
for attr in ("img_size", "always_partition"):
169167
if hasattr(backbone_config, attr):
170-
timm_extra[attr] = getattr(backbone_config, attr)
171-
if timm_extra:
172-
backbone_config.timm_model_kwargs = timm_extra
168+
timm_kwargs[attr] = getattr(backbone_config, attr)
173169

174170
if text_config is None:
175171
logger.info("`text_config` is `None`. Initializing the config with the default `clip_text_model`")
@@ -222,8 +218,14 @@ def __init__(
222218
self.eval_size = eval_size
223219
self.learn_initial_query = learn_initial_query
224220
self.cache_size = cache_size
221+
self.timm_kwargs = timm_kwargs
225222

226223
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
227224

225+
def to_dict(self):
226+
output = super().to_dict()
227+
output.pop("timm_kwargs", None)
228+
return output
229+
228230

229231
__all__ = ["OmDetTurboConfig"]

src/transformers/models/omdet_turbo/modeling_omdet_turbo.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
from ... import initialization as init
2828
from ...activations import ACT2CLS, ACT2FN
29-
from ...backbone_utils import load_backbone
3029
from ...integrations import use_kernel_forward_from_hub
3130
from ...masking_utils import create_bidirectional_mask
3231
from ...modeling_layers import GradientCheckpointingLayer
@@ -39,7 +38,7 @@
3938
logging,
4039
torch_compilable_check,
4140
)
42-
from ..auto import AutoModel
41+
from ..auto import AutoBackbone, AutoModel
4342
from .configuration_omdet_turbo import OmDetTurboConfig
4443

4544

@@ -279,7 +278,7 @@ class OmDetTurboVisionBackbone(nn.Module):
279278
def __init__(self, config: OmDetTurboConfig):
280279
super().__init__()
281280
self.apply_layernorm_after_vision_backbone = config.apply_layernorm_after_vision_backbone
282-
self.vision_backbone = load_backbone(config)
281+
self.vision_backbone = AutoBackbone.from_config(config.backbone_config, **getattr(config, "timm_kwargs", {}))
283282
self.layer_norms = nn.ModuleList(
284283
[nn.LayerNorm(in_channel_dim, eps=config.layer_norm_eps) for in_channel_dim in config.encoder_in_channels]
285284
)

src/transformers/models/timm_backbone/modeling_timm_backbone.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def __init__(self, config, **kwargs):
4949
out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,)
5050
pretrained = kwargs.pop("pretrained", False)
5151
in_chans = kwargs.pop("in_chans", config.num_channels)
52-
timm_model_kwargs = getattr(config, "timm_model_kwargs", {})
5352

5453
backbone = timm.create_model(
5554
config.backbone,
@@ -59,7 +58,6 @@ def __init__(self, config, **kwargs):
5958
in_chans=in_chans,
6059
out_indices=out_indices,
6160
output_stride=config.output_stride,
62-
**timm_model_kwargs,
6361
**kwargs,
6462
)
6563

0 commit comments

Comments
 (0)