@@ -2867,9 +2867,6 @@ class ORTModelForImageClassificationIntegrationTest(ORTModelTestMixin):
28672867 ORTMODEL_CLASS = ORTModelForImageClassification
28682868 TASK = "image-classification"
28692869
2870- ATOL = 2e-3 # 0.02 difference in logits
2871- RTOL = 1e-2 # 1% difference in logits
2872-
28732870 def _get_model_ids (self , model_arch ):
28742871 model_ids = MODEL_NAMES [model_arch ]
28752872 if isinstance (model_ids , dict ):
@@ -3040,8 +3037,16 @@ def test_compare_to_io_binding(self, model_arch):
30403037 onnx_outputs = onnx_model (** inputs )
30413038 io_outputs = io_model (** inputs )
30423039
3040+ print ("shape of logits" , io_outputs .logits .shape )
3041+
30433042 self .assertTrue ("logits" in io_outputs )
30443043 self .assertIsInstance (io_outputs .logits , torch .Tensor )
3044+ self .assertEqual (io_outputs .logits .shape , onnx_outputs .logits .shape )
3045+
3046+ if io_outputs .logits .shape [1 ] > 100 :
3047+ # we compare only the top 100 classes (biggest 100 values in order)
3048+ io_outputs .logits = torch .topk (io_outputs .logits , 100 , dim = 1 ).values
3049+ onnx_outputs .logits = torch .topk (onnx_outputs .logits , 100 , dim = 1 ).values
30453050
30463051 # compare tensor outputs
30473052 torch .testing .assert_close (onnx_outputs .logits , io_outputs .logits , atol = self .ATOL , rtol = self .RTOL )
0 commit comments