Skip to content

Commit cf3b2ca

Browse files
committed
Fixing vgg test
1 parent 04a0803 commit cf3b2ca

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchvision/models/vgg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ
9797

9898
def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
9999
if weights is not None:
100-
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
100+
if weights.meta["categories"] is not None:
101+
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
101102
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
102103
if weights is not None:
103104
model.load_state_dict(weights.get_state_dict(progress=progress))

0 commit comments

Comments
 (0)