Skip to content

Commit 357a5cf

Browse files
fix: Add and forward timm_model_kwargs to timm.create_model
1 parent 852f785 commit 357a5cf

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

src/transformers/models/omdet_turbo/configuration_omdet_turbo.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,7 @@ def __init__(
149149
# Init timm backbone with hardcoded values for BC
150150
timm_default_kwargs = {
151151
"out_indices": [1, 2, 3],
152-
"img_size": image_size,
153-
"always_partition": True,
152+
"timm_model_kwargs": {"img_size": image_size, "always_partition": True},
154153
}
155154
backbone_config, kwargs = consolidate_backbone_kwargs_to_config(
156155
backbone_config=backbone_config,
@@ -161,6 +160,16 @@ def __init__(
161160
**kwargs,
162161
)
163162

163+
if getattr(backbone_config, "model_type", None) == "timm_backbone" and not getattr(
164+
backbone_config, "timm_model_kwargs", None
165+
):
166+
timm_extra = {}
167+
for attr in ("img_size", "always_partition"):
168+
if hasattr(backbone_config, attr):
169+
timm_extra[attr] = getattr(backbone_config, attr)
170+
if timm_extra:
171+
backbone_config.timm_model_kwargs = timm_extra
172+
164173
if text_config is None:
165174
logger.info("`text_config` is `None`. Initializing the config with the default `clip_text_model`")
166175
text_config = CONFIG_MAPPING["clip_text_model"]()

src/transformers/models/timm_backbone/configuration_timm_backbone.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class TimmBackboneConfig(BackboneConfigMixin, PreTrainedConfig):
3131
Whether to output only the features or also the logits.
3232
freeze_batch_norm_2d (`bool`, *optional*, defaults to `False`):
3333
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`.
34+
timm_model_kwargs (`dict`, *optional*):
35+
Additional keyword arguments to pass to `timm.create_model` (e.g. `{"img_size": 640}`).
3436
3537
Example:
3638
```python
@@ -57,6 +59,7 @@ def __init__(
5759
out_indices=None,
5860
freeze_batch_norm_2d=False,
5961
output_stride=None,
62+
timm_model_kwargs=None,
6063
**kwargs,
6164
):
6265
self.backbone = backbone
@@ -65,6 +68,7 @@ def __init__(
6568
self.out_indices = out_indices if out_indices is not None else [-1]
6669
self.output_stride = output_stride
6770
self.freeze_batch_norm_2d = freeze_batch_norm_2d
71+
self.timm_model_kwargs = timm_model_kwargs if timm_model_kwargs is not None else {}
6872

6973
super().__init__(**kwargs)
7074

src/transformers/models/timm_backbone/modeling_timm_backbone.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self, config, **kwargs):
4949
out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,)
5050
pretrained = kwargs.pop("pretrained", False)
5151
in_chans = kwargs.pop("in_chans", config.num_channels)
52+
timm_model_kwargs = getattr(config, "timm_model_kwargs", {})
5253

5354
backbone = timm.create_model(
5455
config.backbone,
@@ -58,6 +59,7 @@ def __init__(self, config, **kwargs):
5859
in_chans=in_chans,
5960
out_indices=out_indices,
6061
output_stride=config.output_stride,
62+
**timm_model_kwargs,
6163
**kwargs,
6264
)
6365

0 commit comments

Comments
 (0)