Skip to content

Commit 64c66f2

Browse files
author
owahab
committed
feature: Add tests for ast_transformer
1 parent 872ba39 commit 64c66f2

File tree

4 files changed

+35
-46
lines changed

4 files changed

+35
-46
lines changed

tests/unit/v2/samples/simple.txt

Lines changed: 0 additions & 4 deletions
This file was deleted.

tests/unit/v2/test_framework_version.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

tests/unit/v2/test_transformer.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,47 @@
11
import ast
2-
import unittest
3-
4-
from tests.unit.v2.utils import get_sample_file
5-
from tools.compatibility.v2.ast_transformer import ASTTransformer
2+
from sagemaker.tools.compatibility.v2.ast_transformer import ASTTransformer
63
import pasta
74

85

9-
class TransformerTest(unittest.TestCase):
10-
def setUp(self) -> None:
11-
self.transformer_class = ASTTransformer()
12-
13-
def test_simple_transform(self):
14-
sample = get_sample_file('simple.txt')
15-
rewrite = self.transformer_class.visit(
16-
ast.parse(
17-
sample
18-
)
6+
def test_code_needs_transform():
7+
simple = """
8+
TensorFlow(entry_point="foo.py")
9+
sagemaker.tensorflow.TensorFlow()
10+
m = MXNet()
11+
sagemaker.mxnet.MXNet()
12+
"""
13+
transformer_class = ASTTransformer()
14+
rewrite = transformer_class.visit(
15+
ast.parse(
16+
simple
1917
)
20-
21-
expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0')
18+
)
19+
expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0')
2220
sagemaker.tensorflow.TensorFlow(framework_version='1.11.0')
2321
m = MXNet(framework_version='1.2.0')
2422
sagemaker.mxnet.MXNet(framework_version='1.2.0')\n"""
2523

26-
self.assertEqual(pasta.dump(rewrite), expected)
24+
assert pasta.dump(
25+
rewrite
26+
) == expected
2727

2828

29-
if __name__ == '__main__':
30-
unittest.main()
31-
29+
def test_code_does_not_need_transform():
30+
simple = """TensorFlow(entry_point='foo.py', framework_version='1.11.0')
31+
sagemaker.tensorflow.TensorFlow(framework_version='1.11.0')
32+
m = MXNet(framework_version='1.2.0')
33+
sagemaker.mxnet.MXNet(framework_version='1.2.0')\n"""
34+
transformer_class = ASTTransformer()
35+
rewrite = transformer_class.visit(
36+
ast.parse(
37+
simple
38+
)
39+
)
40+
expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0')
41+
sagemaker.tensorflow.TensorFlow(framework_version='1.11.0')
42+
m = MXNet(framework_version='1.2.0')
43+
sagemaker.mxnet.MXNet(framework_version='1.2.0')\n"""
3244

45+
assert pasta.dump(
46+
rewrite
47+
) == expected

tests/unit/v2/utils.py

Lines changed: 0 additions & 9 deletions
This file was deleted.

0 commit comments

Comments
 (0)