diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index d662de8078a..3178a81b52c 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -151,7 +151,7 @@ def mobilenet_backbone( # find the index of the layer from which we wont freeze assert 0 <= trainable_layers <= num_stages - freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] + freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] for b in backbone[:freeze_before]: for parameter in b.parameters(): diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index 6695167d462..b6e620dba4a 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -510,7 +510,7 @@ def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained # find the index of the layer from which we wont freeze assert 0 <= trainable_layers <= num_stages - freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] + freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] for b in backbone[:freeze_before]: for parameter in b.parameters(): diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 30e4383ed5d..26378d7038d 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -147,7 +147,7 @@ def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, t # find the index of the layer from which we wont freeze assert 0 <= trainable_layers <= num_stages - freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] + freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] for b in backbone[:freeze_before]: for parameter in b.parameters():