Skip to content

Commit eedff54

Browse files
only copare top 100 classes in image classification
1 parent 5ede25a commit eedff54

1 file changed

Lines changed: 8 additions & 3 deletions

File tree

tests/onnxruntime/test_modeling.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2867,9 +2867,6 @@ class ORTModelForImageClassificationIntegrationTest(ORTModelTestMixin):
28672867
ORTMODEL_CLASS = ORTModelForImageClassification
28682868
TASK = "image-classification"
28692869

2870-
ATOL = 2e-3 # 0.02 difference in logits
2871-
RTOL = 1e-2 # 1% difference in logits
2872-
28732870
def _get_model_ids(self, model_arch):
28742871
model_ids = MODEL_NAMES[model_arch]
28752872
if isinstance(model_ids, dict):
@@ -3040,8 +3037,16 @@ def test_compare_to_io_binding(self, model_arch):
30403037
onnx_outputs = onnx_model(**inputs)
30413038
io_outputs = io_model(**inputs)
30423039

3040+
print("shape of logits", io_outputs.logits.shape)
3041+
30433042
self.assertTrue("logits" in io_outputs)
30443043
self.assertIsInstance(io_outputs.logits, torch.Tensor)
3044+
self.assertEqual(io_outputs.logits.shape, onnx_outputs.logits.shape)
3045+
3046+
if io_outputs.logits.shape[1] > 100:
3047+
# we compare only the top 100 classes (biggest 100 values in order)
3048+
io_outputs.logits = torch.topk(io_outputs.logits, 100, dim=1).values
3049+
onnx_outputs.logits = torch.topk(onnx_outputs.logits, 100, dim=1).values
30453050

30463051
# compare tensor outputs
30473052
torch.testing.assert_close(onnx_outputs.logits, io_outputs.logits, atol=self.ATOL, rtol=self.RTOL)

0 commit comments

Comments
 (0)