1
+ from __future__ import absolute_import
2
+
1
3
import ast
2
4
from sagemaker .tools .compatibility .v2 .ast_transformer import ASTTransformer
3
5
import pasta
@@ -10,20 +12,15 @@ def test_code_needs_transform():
10
12
m = MXNet()
11
13
sagemaker.mxnet.MXNet()
12
14
"""
15
+
13
16
transformer_class = ASTTransformer ()
14
- rewrite = transformer_class .visit (
15
- ast .parse (
16
- simple
17
- )
18
- )
17
+ rewrite = transformer_class .visit (ast .parse (simple ))
19
18
expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0')
20
19
sagemaker.tensorflow.TensorFlow(framework_version='1.11.0')
21
20
m = MXNet(framework_version='1.2.0')
22
21
sagemaker.mxnet.MXNet(framework_version='1.2.0')\n """
23
22
24
- assert pasta .dump (
25
- rewrite
26
- ) == expected
23
+ assert pasta .dump (rewrite ) == expected
27
24
28
25
29
26
def test_code_does_not_need_transform ():
@@ -32,16 +29,10 @@ def test_code_does_not_need_transform():
32
29
m = MXNet(framework_version='1.2.0')
33
30
sagemaker.mxnet.MXNet(framework_version='1.2.0')\n """
34
31
transformer_class = ASTTransformer ()
35
- rewrite = transformer_class .visit (
36
- ast .parse (
37
- simple
38
- )
39
- )
32
+ rewrite = transformer_class .visit (ast .parse (simple ))
40
33
expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0')
41
34
sagemaker.tensorflow.TensorFlow(framework_version='1.11.0')
42
35
m = MXNet(framework_version='1.2.0')
43
36
sagemaker.mxnet.MXNet(framework_version='1.2.0')\n """
44
37
45
- assert pasta .dump (
46
- rewrite
47
- ) == expected
38
+ assert pasta .dump (rewrite ) == expected
0 commit comments