Skip to content

Commit 61763fa

Browse files
szagoruykofmassa
authored andcommitted
fix for loading models with num_batches_tracked in frozen bn (#1728)
1 parent be6dd47 commit 61763fa

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

torchvision/ops/misc.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,16 @@ def __init__(self, n):
130130
self.register_buffer("running_mean", torch.zeros(n))
131131
self.register_buffer("running_var", torch.ones(n))
132132

133+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
134+
missing_keys, unexpected_keys, error_msgs):
135+
num_batches_tracked_key = prefix + 'num_batches_tracked'
136+
if num_batches_tracked_key in state_dict:
137+
del state_dict[num_batches_tracked_key]
138+
139+
super(FrozenBatchNorm2d, self)._load_from_state_dict(
140+
state_dict, prefix, local_metadata, strict,
141+
missing_keys, unexpected_keys, error_msgs)
142+
133143
def forward(self, x):
134144
# move reshapes to the beginning
135145
# to make it fuser-friendly

0 commit comments

Comments
 (0)