Skip to content

Use real weight and image for classification model test and relaxing precision requirement for general model tests #7130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
dcdc8db
Relaxing test_models precision, revert #6380
YosuaMichael Jan 25, 2023
4e1cb65
Check test using cuda 11.7 instead
YosuaMichael Jan 25, 2023
edcb727
Use real weight and image for classification model and adjust precision
YosuaMichael Jan 26, 2023
ff950d7
Switch back to cuda 11.6
YosuaMichael Jan 26, 2023
c32b4ae
Relaxing fx test tolerance to 5e-5
YosuaMichael Jan 26, 2023
0e9fc37
Relaxing detection test and use float64 for flaky detection models
YosuaMichael Jan 26, 2023
ef6e11c
Merge branch 'main' into test/relaxing-precision
YosuaMichael Jan 26, 2023
e384ca0
Fix linter issue
YosuaMichael Jan 26, 2023
8d33c56
Fix to use real image for classification model
YosuaMichael Jan 27, 2023
a2ec9c1
Merge branch 'test/relaxing-precision' of github.com:YosuaMichael/vis…
YosuaMichael Jan 27, 2023
d991bc4
Mark maskrcnn_resnet50_fpn_v2 as flaky detection model
YosuaMichael Jan 27, 2023
7873c09
Fix vitc test and try not using pretrained weight for vit_h and reduc…
YosuaMichael Jan 27, 2023
ad83ef6
Merge branch 'main' into test/relaxing-precision
YosuaMichael Jan 27, 2023
ad99e28
Remove _get_image comment that list the model usage and change num_ex…
YosuaMichael Jan 27, 2023
5047b49
Merge branch 'test/relaxing-precision' of github.com:YosuaMichael/vis…
YosuaMichael Jan 27, 2023
7661ab4
Merge branch 'main' into test/relaxing-precision
YosuaMichael Jan 27, 2023
b6e83c1
Simplify slow model input_shape assignment
YosuaMichael Jan 27, 2023
60675be
Fix with ufmt format
YosuaMichael Jan 27, 2023
ab81f0e
Merge branch 'main' into test/relaxing-precision
YosuaMichael Jan 27, 2023
0b5bcec
Merge branch 'main' into test/relaxing-precision
NicolasHug Feb 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified test/expect/ModelTester.test_alexnet_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_convnext_base_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_convnext_large_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_convnext_small_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_convnext_tiny_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_densenet121_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_densenet161_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_densenet169_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_densenet201_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_b0_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_b1_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_b2_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_b3_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_b4_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_b5_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_b6_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_b7_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_v2_l_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_v2_m_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_efficientnet_v2_s_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_googlenet_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_inception_v3_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_maxvit_t_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_mnasnet0_5_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_mnasnet0_75_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_mnasnet1_0_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_mnasnet1_3_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_mobilenet_v2_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_mobilenet_v3_large_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_mobilenet_v3_small_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_x_16gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_x_1_6gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_x_32gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_x_3_2gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_x_400mf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_x_800mf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_x_8gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_y_128gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_y_16gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_y_1_6gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_y_32gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_y_3_2gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_y_400mf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_y_800mf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_regnet_y_8gf_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_resnet101_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_resnet152_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_resnet18_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_resnet34_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_resnet50_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_resnext101_32x8d_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_resnext101_64x4d_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_resnext50_32x4d_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_shufflenet_v2_x0_5_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_shufflenet_v2_x1_0_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_shufflenet_v2_x1_5_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_shufflenet_v2_x2_0_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_squeezenet1_0_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_squeezenet1_1_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_swin_b_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_swin_s_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_swin_t_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_swin_v2_b_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_swin_v2_s_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_swin_v2_t_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_vgg11_bn_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_vgg11_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_vgg13_bn_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_vgg13_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_vgg16_bn_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_vgg16_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_vgg19_bn_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_vgg19_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_vit_b_16_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_vit_b_32_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_vit_h_14_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_vit_l_16_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_vit_l_32_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_wide_resnet101_2_expect.pkl
Binary file not shown.
Binary file modified test/expect/ModelTester.test_wide_resnet50_2_expect.pkl
Binary file not shown.
99 changes: 59 additions & 40 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from common_utils import cpu_and_gpu, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed
from PIL import Image
from torchvision import models, transforms
from torchvision.models import get_model_builder, list_models
from torchvision.models import get_model_builder, get_model_weights, get_weight, list_models


ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
Expand All @@ -29,20 +29,8 @@ def list_model_fns(module):
return [get_model_builder(name) for name in list_models(module)]


def _get_image(input_shape, real_image, device):
"""This routine loads a real or random image based on `real_image` argument.
Currently, the real image is utilized for the following list of models:
- `retinanet_resnet50_fpn`,
- `retinanet_resnet50_fpn_v2`,
- `keypointrcnn_resnet50_fpn`,
- `fasterrcnn_resnet50_fpn`,
- `fasterrcnn_resnet50_fpn_v2`,
- `fcos_resnet50_fpn`,
- `maskrcnn_resnet50_fpn`,
- `maskrcnn_resnet50_fpn_v2`,
in `test_classification_model` and `test_detection_model`.
To do so, a keyword argument `real_image` was added to the abovelisted models in `_model_params`
"""
def _get_image(input_shape, real_image, device, weights=None, dtype=None):
"""This routine loads a real or random image based on `real_image` argument."""
if real_image:
# TODO: Maybe unify file discovery logic with test_image.py
GRACE_HOPPER = os.path.join(
Expand All @@ -51,19 +39,26 @@ def _get_image(input_shape, real_image, device):

img = Image.open(GRACE_HOPPER)

original_width, original_height = img.size

# make the image square
img = img.crop((0, 0, original_width, original_width))
img = img.resize(input_shape[1:3])
if weights is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we just pass the weights all the time? What's the reason for having them in only some cases but not all?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In some cases the weight are really restrictive, for instance if we use vit_h_14, it will only accept the image_size of the size of the min_size of the weight: https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py#L321 and in this case we can't do the test with lower resolution with the weight.

Also as of now, we dont use real weight for detection model test.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

, it will only accept the image_size of the size of the min_size of the weight: https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py#L321 and in this case we can't do the test with lower resolution with the weight

But isn't that a good thing? i.e. if we go below the min_size limit, wouldn't we expect the model to output garbage? And if not, why is the limit not lower?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For test purpose, we might want to use smaller image even if the output is garbage but we can still check for the consistentcy (what we did so far with random image and random weight). And in this case if we set weight=None then it will basically behave like before, the get_image will assume that the test dont use real weight but rather initialized with random weight.

original_width, original_height = img.size
# make the image square
img = img.crop((0, 0, original_width, original_width))
img = img.resize(input_shape[-2:])

convert_tensor = transforms.ToTensor()
image = convert_tensor(img)
convert_tensor = transforms.ToTensor()
image = convert_tensor(img)
else:
H, W = input_shape[-2:]
min_side = min(H, W)
preprocess = weights.transforms(resize_size=min_side, crop_size=min_side)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need to pass parameters to the weights.transforms() , they will handle the size properly.

Copy link
Contributor Author

@YosuaMichael YosuaMichael Jan 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need this if we want to control the size when the test happened, otherwise we will rely on the default size on the weight transforms (In some big model, we would like to use smaller image size for the test to speed up runtime).

Note: For test purpose, I think it is okay not to use the preferred image size that will yield the best accuracy for the model.

image = preprocess(img)
if len(input_shape) > len(image.size()):
image = image.unsqueeze(0)
assert tuple(image.size()) == input_shape
return image.to(device=device)
return image.to(device=device, dtype=dtype)

# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
return torch.rand(input_shape).to(device=device)
return torch.rand(input_shape).to(device=device, dtype=dtype)


@pytest.fixture
Expand Down Expand Up @@ -195,7 +190,7 @@ def _check_fx_compatible(model, inputs, eager_out=None):
eager_out = model(inputs)
with torch.no_grad(), freeze_rng_state():
fx_out = model_fx(inputs)
torch.testing.assert_close(eager_out, fx_out)
torch.testing.assert_close(eager_out, fx_out, atol=5e-5, rtol=5e-5)


def _check_input_backprop(model, inputs):
Expand Down Expand Up @@ -278,11 +273,15 @@ def _check_input_backprop(model, inputs):
# tests under test_quantized_classification_model will be skipped for the following models.
quantized_flaky_models = ("inception_v3", "resnet50")

# The tests for the following detection models are flaky due to precision of float32
# we will do the test in float64 for these models
detection_flaky_models = ("keypointrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn_v2")


# The following contains configuration parameters for all models which are used by
# the _test_*_model methods.
_model_params = {
"inception_v3": {"input_shape": (1, 3, 299, 299), "init_weights": True},
"inception_v3": {"input_shape": (1, 3, 299, 299)},
"retinanet_resnet50_fpn": {
"num_classes": 20,
"score_thresh": 0.01,
Expand Down Expand Up @@ -354,6 +353,7 @@ def _check_input_backprop(model, inputs):
"vit_h_14": {
"image_size": 56,
"input_shape": (1, 3, 56, 56),
"weight_name": None,
},
"mvit_v1_b": {
"input_shape": (1, 3, 16, 224, 224),
Expand All @@ -364,7 +364,8 @@ def _check_input_backprop(model, inputs):
"s3d": {
"input_shape": (1, 3, 16, 224, 224),
},
"googlenet": {"init_weights": True},
"regnet_y_128gf": {"weight_name": "IMAGENET1K_SWAG_LINEAR_V1"},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we just get the actual weights from the model name, using the helpers from https://pytorch.org/vision/main/models.html#model-registration-mechanism ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can, I actually use the helper to get the actual weight in here.

I think I prefer this design where we dont need to specify the weight_enum for the weight_name (since it can be retrieve from the model_name). Also, it is easier to say that the default value that we use is IMAGENET1K_V1 for the test.

"vitc_b_16": {"weight_name": None},
}
# speeding up slow models:
slow_models = [
Expand All @@ -390,7 +391,7 @@ def _check_input_backprop(model, inputs):
"swin_v2_b",
]
for m in slow_models:
_model_params[m] = {"input_shape": (1, 3, 64, 64)}
_model_params[m] = dict(_model_params.get(m, dict()), **{"input_shape": (1, 3, 64, 64)})


# skip big models to reduce memory usage on CI test. We can exclude combinations of (platform-system, device).
Expand Down Expand Up @@ -648,6 +649,7 @@ def test_generalizedrcnn_transform_repr():


def vitc_b_16(**kwargs: Any):
kwargs.pop("weights", None)
return models.VisionTransformer(
image_size=224,
patch_size=16,
Expand All @@ -671,33 +673,46 @@ def test_vitc_models(model_fn, dev):
def test_classification_model(model_fn, dev):
set_rng_seed(0)
defaults = {
"num_classes": 50,
"num_classes": 1000,
"input_shape": (1, 3, 224, 224),
"num_classes_to_check": 50,
"real_image": True,
}
model_name = model_fn.__name__
if SKIP_BIG_MODEL and is_skippable(model_name, dev):
pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
kwargs = {**defaults, **_model_params.get(model_name, {})}
num_classes = kwargs.get("num_classes")
num_classes_to_check = kwargs.pop("num_classes_to_check")
input_shape = kwargs.pop("input_shape")
real_image = kwargs.pop("real_image", False)
weight_name = kwargs.pop("weight_name", "IMAGENET1K_V1")
weight = None
if weight_name is not None:
weight_enum = get_model_weights(model_name)
weight = get_weight(f"{weight_enum.__name__}.{weight_name}")

model = model_fn(**kwargs)
model = model_fn(weights=weight, **kwargs)
model.eval().to(device=dev)
x = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
out = model(x)
_assert_expected(out.cpu(), model_name, prec=1e-3)
x = _get_image(input_shape=input_shape, real_image=real_image, device=dev, weights=weight)
with torch.no_grad(), freeze_rng_state():
out = model(x)
expect_out = out[:, :num_classes_to_check]
_assert_expected(expect_out.cpu(), model_name, prec=3e-2)
assert out.shape[-1] == num_classes
assert expect_out.shape[-1] == num_classes_to_check
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
_check_fx_compatible(model, x, eager_out=out)

if dev == "cuda":
with torch.cuda.amp.autocast():
with torch.cuda.amp.autocast(), torch.no_grad(), freeze_rng_state():
model.to(x.device)
out = model(x)
expect_out = out[:, :num_classes_to_check]
# See autocast_flaky_numerics comment at top of file.
if model_name not in autocast_flaky_numerics:
_assert_expected(out.cpu(), model_name, prec=0.1)
assert out.shape[-1] == 50
_assert_expected(expect_out.cpu(), model_name, prec=0.1)
assert expect_out.shape[-1] == num_classes_to_check

_check_input_backprop(model, x)

Expand Down Expand Up @@ -777,13 +792,17 @@ def test_detection_model(model_fn, dev):
"input_shape": (3, 300, 300),
}
model_name = model_fn.__name__
if model_name in detection_flaky_models:
dtype = torch.float64
else:
dtype = torch.get_default_dtype()
kwargs = {**defaults, **_model_params.get(model_name, {})}
input_shape = kwargs.pop("input_shape")
real_image = kwargs.pop("real_image", False)

model = model_fn(**kwargs)
model.eval().to(device=dev)
x = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
model.eval().to(device=dev, dtype=dtype)
x = _get_image(input_shape=input_shape, real_image=real_image, device=dev, dtype=dtype)
model_input = [x]
with torch.no_grad(), freeze_rng_state():
out = model(model_input)
Expand Down Expand Up @@ -818,7 +837,7 @@ def compute_mean_std(tensor):
return {"mean": mean, "std": std}

output = map_nested_tensor_object(out, tensor_map_fn=compact)
prec = 0.01
prec = 3e-2
try:
# We first try to assert the entire output if possible. This is not
# only the best way to assert results but also handles the cases
Expand Down Expand Up @@ -917,7 +936,7 @@ def test_video_model(model_fn, dev):
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
out = model(x)
_assert_expected(out.cpu(), model_name, prec=1e-5)
_assert_expected(out.cpu(), model_name, prec=3e-3)
assert out.shape[-1] == num_classes
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
_check_fx_compatible(model, x, eager_out=out)
Expand Down