Skip to content

Test some flaky detection models on float64 instead of float32 #7204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def list_model_fns(module):
return [get_model_builder(name) for name in list_models(module)]


def _get_image(input_shape, real_image, device):
def _get_image(input_shape, real_image, device, dtype=None):
"""This routine loads a real or random image based on `real_image` argument.
Currently, the real image is utilized for the following list of models:
- `retinanet_resnet50_fpn`,
Expand Down Expand Up @@ -60,10 +60,10 @@ def _get_image(input_shape, real_image, device):
convert_tensor = transforms.ToTensor()
image = convert_tensor(img)
assert tuple(image.size()) == input_shape
return image.to(device=device)
return image.to(device=device, dtype=dtype)

# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
return torch.rand(input_shape).to(device=device)
return torch.rand(input_shape).to(device=device, dtype=dtype)


@pytest.fixture
Expand Down Expand Up @@ -278,6 +278,11 @@ def _check_input_backprop(model, inputs):
# tests under test_quantized_classification_model will be skipped for the following models.
quantized_flaky_models = ("inception_v3", "resnet50")

# The tests for the following detection models are flaky.
# We run those tests on float64 to avoid floating point errors.
# FIXME: we shouldn't have to do that :'/
detection_flaky_models = ("keypointrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn_v2")


# The following contains configuration parameters for all models which are used by
# the _test_*_model methods.
Expand Down Expand Up @@ -777,13 +782,17 @@ def test_detection_model(model_fn, dev):
"input_shape": (3, 300, 300),
}
model_name = model_fn.__name__
if model_name in detection_flaky_models:
dtype = torch.float64
else:
dtype = torch.get_default_dtype()
kwargs = {**defaults, **_model_params.get(model_name, {})}
input_shape = kwargs.pop("input_shape")
real_image = kwargs.pop("real_image", False)

model = model_fn(**kwargs)
model.eval().to(device=dev)
x = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
model.eval().to(device=dev, dtype=dtype)
x = _get_image(input_shape=input_shape, real_image=real_image, device=dev, dtype=dtype)
model_input = [x]
with torch.no_grad(), freeze_rng_state():
out = model(model_input)
Expand Down