```diff
@@ -822,24 +822,36 @@ def test_detection_model_validation(model_fn):
 @pytest.mark.parametrize("model_fn", get_models_from_module(models.video))
 @pytest.mark.parametrize("dev", cpu_and_gpu())
 def test_video_model(model_fn, dev):
+    set_rng_seed(0)
     # the default input shape is
     # bs * num_channels * clip_len * h * w
-    input_shape = (1, 3, 4, 112, 112)
+    defaults = {
+        "input_shape": (1, 3, 4, 112, 112),
+        "num_classes": 50,
+    }
     model_name = model_fn.__name__
+    kwargs = {**defaults, **_model_params.get(model_name, {})}
+    num_classes = kwargs.get("num_classes")
+    input_shape = kwargs.pop("input_shape")
     # test both basicblock and Bottleneck
-    model = model_fn(num_classes=50)
+    model = model_fn(**kwargs)
     model.eval().to(device=dev)
     # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
     x = torch.rand(input_shape).to(device=dev)
     out = model(x)
+    _assert_expected(out.cpu(), model_name, prec=0.1)
+    assert out.shape[-1] == num_classes
     _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
     _check_fx_compatible(model, x, eager_out=out)
-    assert out.shape[-1] == 50
+    assert out.shape[-1] == num_classes
 
     if dev == "cuda":
         with torch.cuda.amp.autocast():
             out = model(x)
-            assert out.shape[-1] == 50
+            # See autocast_flaky_numerics comment at top of file.
+            if model_name not in autocast_flaky_numerics:
+                _assert_expected(out.cpu(), model_name, prec=0.1)
+            assert out.shape[-1] == num_classes
 
     _check_input_backprop(model, x)
```
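The `{**defaults, **overrides}` merge is the load-bearing change: in a dict literal, later entries win, so per-model entries in `_model_params` shadow the shared defaults. A standalone sketch of that pattern (the `_model_params` contents below are hypothetical, chosen only for illustration; the real table lives in the test module):

```python
# Shared defaults applied to every video model under test.
defaults = {
    "input_shape": (1, 3, 4, 112, 112),
    "num_classes": 50,
}

# Hypothetical per-model overrides standing in for the real _model_params.
_model_params = {
    "mvit_v1_b": {"input_shape": (1, 3, 16, 224, 224)},
}

def resolve(model_name):
    # Later unpacked dicts win, so overrides shadow the defaults.
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    # input_shape drives the test input and must not reach the model
    # constructor, hence pop(); num_classes is both a constructor kwarg
    # and an assertion target, hence get().
    input_shape = kwargs.pop("input_shape")
    return kwargs, input_shape

print(resolve("r3d_18"))     # ({'num_classes': 50}, (1, 3, 4, 112, 112))
print(resolve("mvit_v1_b"))  # ({'num_classes': 50}, (1, 3, 16, 224, 224))
```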
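The other new check, `_assert_expected(out.cpu(), model_name, prec=0.1)`, compares the model output against a stored snapshot within tolerance `prec`, so the test now catches numeric regressions and not just shape mismatches. The real helper lives in the test suite's shared utilities; the following is only a hypothetical minimal sketch of the snapshot idea, with made-up file handling:

```python
import os
import torch

def assert_expected_sketch(output, name, prec, expected_dir="expected"):
    """Hypothetical stand-in for the suite's _assert_expected helper.

    Compares `output` against a tensor snapshot stored on disk, creating the
    snapshot on first run. The real helper's storage layout and regeneration
    workflow differ; this only illustrates the snapshot-testing idea.
    """
    path = os.path.join(expected_dir, f"{name}.pt")
    if not os.path.exists(path):
        os.makedirs(expected_dir, exist_ok=True)
        torch.save(output, path)  # first run: record the snapshot
        return
    expected = torch.load(path)
    # A loose tolerance (the diff uses prec=0.1) absorbs nondeterminism
    # across devices and library versions while catching real regressions.
    torch.testing.assert_close(output, expected, atol=prec, rtol=prec)
```

This also explains the `autocast_flaky_numerics` escape hatch in the diff: under mixed precision, some models drift past any reasonable tolerance, so the snapshot comparison is skipped for them while the shape assertion still runs.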