Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion monai/networks/blocks/fcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def __init__(
self.upsample_mode = upsample_mode
self.conv2d_type = conv2d_type
self.out_channels = out_channels
resnet = models.resnet50(pretrained=pretrained, progress=progress)
resnet = models.resnet50(
progress=progress, weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
)

self.conv1 = resnet.conv1
self.bn0 = resnet.bn1
Expand Down
7 changes: 3 additions & 4 deletions monai/networks/nets/milmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch
import torch.nn as nn

from monai.utils.module import optional_import
from monai.utils import optional_import

models, _ = optional_import("torchvision.models")

Expand Down Expand Up @@ -48,7 +48,6 @@ class MILModel(nn.Module):
Defaults to ``None`` (necessary only when using a custom backbone)
trans_blocks: number of the blocks in `TransformEncoder` layer.
trans_dropout: dropout rate in `TransformEncoder` layer.

"""

def __init__(
Expand All @@ -74,7 +73,7 @@ def __init__(
self.transformer: nn.Module | None = None

if backbone is None:
net = models.resnet50(pretrained=pretrained)
net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None)
nfc = net.fc.in_features # save the number of final features
net.fc = torch.nn.Identity() # remove final linear layer

Expand All @@ -99,7 +98,7 @@ def hook(module, input, output):
torch_model = getattr(models, backbone, None)
if torch_model is None:
raise ValueError("Unknown torch vision model" + str(backbone))
net = torch_model(pretrained=pretrained)
net = torch_model(weights="DEFAULT" if pretrained else None)

if getattr(net, "fc", None) is not None:
nfc = net.fc.in_features # save the number of final features
Expand Down
11 changes: 5 additions & 6 deletions monai/networks/nets/torchvision_fc.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,11 @@ def __init__(
weights=None,
**kwargs,
):
if weights is not None:
model = getattr(models, model_name)(weights=weights, **kwargs)
elif pretrained:
model = getattr(models, model_name)(weights="DEFAULT", **kwargs)
else:
model = getattr(models, model_name)(weights=None, **kwargs)
# if pretrained is False, weights is a weight tensor or None for no pretrained loading
if pretrained and weights is None:
weights = "DEFAULT"

model = getattr(models, model_name)(weights=weights, **kwargs)

super().__init__(
model=model,
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pep8-naming
pycodestyle
pyflakes
black>=25.1.0
isort>=5.1, <6.0
isort>=5.1, !=6.0.0
ruff
pytype>=2020.6.1, <=2024.4.11; platform_system != "Windows"
types-setuptools
Expand Down
2 changes: 1 addition & 1 deletion tests/networks/nets/test_densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_pretrain_consistency(self, model, input_param, input_shape):
net = model(**input_param).to(device)
with eval_mode(net):
result = net.features.forward(example)
torchvision_net = torchvision.models.densenet121(pretrained=True).to(device)
torchvision_net = torchvision.models.densenet121(weights="DEFAULT").to(device)
with eval_mode(torchvision_net):
expected_result = torchvision_net.features.forward(example)
self.assertTrue(torch.all(result == expected_result))
Expand Down
10 changes: 5 additions & 5 deletions tests/networks/nets/test_milmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@
TEST_CASE_MILMODEL.append(test_case)

# torchvision backbone
TEST_CASE_MILMODEL.append(
[{"num_classes": 5, "backbone": "resnet18", "pretrained": False}, (2, 2, 3, 512, 512), (2, 5)]
)
TEST_CASE_MILMODEL.append([{"num_classes": 5, "backbone": "resnet18", "pretrained": True}, (2, 2, 3, 512, 512), (2, 5)])
for pretrained in [True, False]:
TEST_CASE_MILMODEL.append(
[{"num_classes": 5, "backbone": "resnet18", "pretrained": pretrained}, (2, 2, 3, 512, 512), (2, 5)]
)

# custom backbone
backbone = models.densenet121(pretrained=False)
backbone = models.densenet121()
backbone_nfeatures = backbone.classifier.in_features
backbone.classifier = torch.nn.Identity()
TEST_CASE_MILMODEL.append(
Expand Down
Loading