Skip to content

Commit b943bc1

Browse files
committed
Add utils test
1 parent 64a9210 commit b943bc1

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

test/test_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import unittest
2+
from unittest.mock import patch
3+
from torchao.utils import torch_version_at_least
4+
5+
class TestTorchVersionAtLeast(unittest.TestCase):
6+
def test_torch_version_at_least(self):
7+
test_cases = [
8+
("2.5.0a0+git9f17037", "2.5.0", True),
9+
("2.5.0a0+git9f17037", "2.4.0", True),
10+
("2.5.0.dev20240708+cu121", "2.5.0", True),
11+
("2.5.0.dev20240708+cu121", "2.4.0", True),
12+
("2.5.0", "2.4.0", True),
13+
("2.5.0", "2.5.0", True),
14+
("2.4.0", "2.4.0", True),
15+
("2.4.0", "2.5.0", False),
16+
]
17+
18+
for torch_version, compare_version, expected_result in test_cases:
19+
with patch('torch.__version__', torch_version):
20+
result = torch_version_at_least(compare_version)
21+
22+
self.assertEqual(result, expected_result, f"Failed for torch.__version__={torch_version}, comparing with {compare_version}")
23+
print(f"{torch_version}: {result}")
24+
25+
if __name__ == '__main__':
26+
unittest.main()

0 commit comments

Comments
 (0)