From 56fb02b016cdf3dcf1686d5cfaebbc1fb009b58e Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 7 Jun 2023 20:31:08 -0700 Subject: [PATCH] fix: Refactor assertions in E2E tests for Dynamo - Add unittest assertion module to streamline error messaging and reporting --- .../dynamo/test/test_dynamo_backend.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py index 9f2ecf1432..3c2ec01419 100644 --- a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py +++ b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py @@ -1,6 +1,7 @@ import torch import timm import pytest +import unittest import torch_tensorrt as torchtrt import torchvision.models as models @@ -12,6 +13,8 @@ cosine_similarity, ) +assertions = unittest.TestCase() + @pytest.mark.unit def test_resnet18(ir): @@ -31,9 +34,9 @@ def test_resnet18(ir): trt_mod = torchtrt.compile(model, **compile_spec) cos_sim = cosine_similarity(model(input), trt_mod(input)) - assert ( + assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) # Clean up model env @@ -61,9 +64,9 @@ def test_mobilenet_v2(ir): trt_mod = torchtrt.compile(model, **compile_spec) cos_sim = cosine_similarity(model(input), trt_mod(input)) - assert ( + assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) # Clean up model env @@ -91,9 +94,9 @@ def test_efficientnet_b0(ir): trt_mod = torchtrt.compile(model, **compile_spec) cos_sim = cosine_similarity(model(input), trt_mod(input)) - assert ( + assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) # Clean up model env @@ -134,9 +137,9 @@ def test_bert_base_uncased(ir): for key in model_outputs.keys(): out, trt_out = model_outputs[key], trt_model_outputs[key] cos_sim = cosine_similarity(out, trt_out) - assert ( + assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) # Clean up model env @@ -164,9 +167,9 @@ def test_resnet18_half(ir): trt_mod = torchtrt.compile(model, **compile_spec) cos_sim = cosine_similarity(model(input), trt_mod(input)) - assert ( + assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + msg=f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) # Clean up model env