Skip to content

Commit 69ce452

Browse files
authored
Validate against expected files on videos (#6077)
* Validate against expected files on videos * Plus tests for autocast
1 parent 3a2631b commit 69ce452

4 files changed

+16
-4
lines changed
939 Bytes
Binary file not shown.
939 Bytes
Binary file not shown.
939 Bytes
Binary file not shown.

test/test_models.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -822,24 +822,36 @@ def test_detection_model_validation(model_fn):
822822
@pytest.mark.parametrize("model_fn", get_models_from_module(models.video))
823823
@pytest.mark.parametrize("dev", cpu_and_gpu())
824824
def test_video_model(model_fn, dev):
825+
set_rng_seed(0)
825826
# the default input shape is
826827
# bs * num_channels * clip_len * h *w
827-
input_shape = (1, 3, 4, 112, 112)
828+
defaults = {
829+
"input_shape": (1, 3, 4, 112, 112),
830+
"num_classes": 50,
831+
}
828832
model_name = model_fn.__name__
833+
kwargs = {**defaults, **_model_params.get(model_name, {})}
834+
num_classes = kwargs.get("num_classes")
835+
input_shape = kwargs.pop("input_shape")
829836
# test both basicblock and Bottleneck
830-
model = model_fn(num_classes=50)
837+
model = model_fn(**kwargs)
831838
model.eval().to(device=dev)
832839
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
833840
x = torch.rand(input_shape).to(device=dev)
834841
out = model(x)
842+
_assert_expected(out.cpu(), model_name, prec=0.1)
843+
assert out.shape[-1] == num_classes
835844
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
836845
_check_fx_compatible(model, x, eager_out=out)
837-
assert out.shape[-1] == 50
846+
assert out.shape[-1] == num_classes
838847

839848
if dev == "cuda":
840849
with torch.cuda.amp.autocast():
841850
out = model(x)
842-
assert out.shape[-1] == 50
851+
# See autocast_flaky_numerics comment at top of file.
852+
if model_name not in autocast_flaky_numerics:
853+
_assert_expected(out.cpu(), model_name, prec=0.1)
854+
assert out.shape[-1] == num_classes
843855

844856
_check_input_backprop(model, x)
845857

0 commit comments

Comments
 (0)