|
1 | 1 | import ast
|
2 |
| -import unittest |
3 |
| - |
4 |
| -from tests.unit.v2.utils import get_sample_file |
5 |
| -from tools.compatibility.v2.ast_transformer import ASTTransformer |
| 2 | +from sagemaker.tools.compatibility.v2.ast_transformer import ASTTransformer |
6 | 3 | import pasta
|
7 | 4 |
|
8 | 5 |
|
9 |
| -class TransformerTest(unittest.TestCase): |
10 |
| - def setUp(self) -> None: |
11 |
| - self.transformer_class = ASTTransformer() |
12 |
| - |
13 |
| - def test_simple_transform(self): |
14 |
| - sample = get_sample_file('simple.txt') |
15 |
| - rewrite = self.transformer_class.visit( |
16 |
| - ast.parse( |
17 |
| - sample |
18 |
| - ) |
| 6 | +def test_code_needs_transform(): |
| 7 | + simple = """ |
| 8 | +TensorFlow(entry_point="foo.py") |
| 9 | +sagemaker.tensorflow.TensorFlow() |
| 10 | +m = MXNet() |
| 11 | +sagemaker.mxnet.MXNet() |
| 12 | +""" |
| 13 | + transformer_class = ASTTransformer() |
| 14 | + rewrite = transformer_class.visit( |
| 15 | + ast.parse( |
| 16 | + simple |
19 | 17 | )
|
20 |
| - |
21 |
| - expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0') |
| 18 | + ) |
| 19 | + expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0') |
22 | 20 | sagemaker.tensorflow.TensorFlow(framework_version='1.11.0')
|
23 | 21 | m = MXNet(framework_version='1.2.0')
|
24 | 22 | sagemaker.mxnet.MXNet(framework_version='1.2.0')\n"""
|
25 | 23 |
|
26 |
| - self.assertEqual(pasta.dump(rewrite), expected) |
| 24 | + assert pasta.dump( |
| 25 | + rewrite |
| 26 | + ) == expected |
27 | 27 |
|
28 | 28 |
|
29 |
| -if __name__ == '__main__': |
30 |
| - unittest.main() |
31 |
| - |
| 29 | +def test_code_does_not_need_transform(): |
| 30 | + simple = """TensorFlow(entry_point='foo.py', framework_version='1.11.0') |
| 31 | +sagemaker.tensorflow.TensorFlow(framework_version='1.11.0') |
| 32 | +m = MXNet(framework_version='1.2.0') |
| 33 | +sagemaker.mxnet.MXNet(framework_version='1.2.0')\n""" |
| 34 | + transformer_class = ASTTransformer() |
| 35 | + rewrite = transformer_class.visit( |
| 36 | + ast.parse( |
| 37 | + simple |
| 38 | + ) |
| 39 | + ) |
| 40 | + expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0') |
| 41 | +sagemaker.tensorflow.TensorFlow(framework_version='1.11.0') |
| 42 | +m = MXNet(framework_version='1.2.0') |
| 43 | +sagemaker.mxnet.MXNet(framework_version='1.2.0')\n""" |
32 | 44 |
|
| 45 | + assert pasta.dump( |
| 46 | + rewrite |
| 47 | + ) == expected |
0 commit comments