From 72480ec913e64f8734bc697cc34caf5d058eff0c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 18 Mar 2022 12:14:25 +0100 Subject: [PATCH 1/2] xfail mobilnet norm layer test --- test/test_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_models.py b/test/test_models.py index 209f27209bf..9da41c64c43 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -400,6 +400,7 @@ def test_mobilenet_v2_residual_setting(): assert out.shape[-1] == 1000 +@pytest.mark.xfail(reason="See https://github.com/pytorch/vision/issues/5642") @pytest.mark.parametrize("model_fn", [models.mobilenet_v2, models.mobilenet_v3_large, models.mobilenet_v3_small]) def test_mobilenet_norm_layer(model_fn): model = model_fn() From b0018fc790e8f0513650519afdad5d29b22cbd37 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 18 Mar 2022 12:20:04 +0100 Subject: [PATCH 2/2] fix test --- test/test_models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 9da41c64c43..d657475bafb 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -400,14 +400,13 @@ def test_mobilenet_v2_residual_setting(): assert out.shape[-1] == 1000 -@pytest.mark.xfail(reason="See https://github.com/pytorch/vision/issues/5642") @pytest.mark.parametrize("model_fn", [models.mobilenet_v2, models.mobilenet_v3_large, models.mobilenet_v3_small]) def test_mobilenet_norm_layer(model_fn): model = model_fn() assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules()) def get_gn(num_channels): - return nn.GroupNorm(32, num_channels) + return nn.GroupNorm(1, num_channels) model = model_fn(norm_layer=get_gn) assert not (any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))