@@ -54,6 +54,7 @@ def node_should_be_modified(self, node):
54
54
55
55
- ``TensorFlow``
56
56
- ``sagemaker.tensorflow.TensorFlow``
57
+ - ``sagemaker.tensorflow.estimator.TensorFlow``
57
58
58
59
Legacy mode is enabled if (1) ``script_mode`` is ``False``, ``None``, or not specified,
59
60
and (2) if ``py_version`` is ``py2`` or not specified.
@@ -68,27 +69,35 @@ def node_should_be_modified(self, node):
68
69
return self ._is_tf_constructor (node ) and self ._is_legacy_mode (node )
69
70
70
71
def _is_tf_constructor (self , node ):
71
- """Checks if the ``ast.Call`` node represents a call of the form
72
- ``TensorFlow`` or ``sagemaker.tensorflow.TensorFlow``.
72
+ """Checks if the ``ast.Call`` node represents a call of the form ``TensorFlow``,
73
+ ``sagemaker.tensorflow. TensorFlow``, or ``sagemaker.tensorflow.estimator .TensorFlow``.
73
74
"""
74
75
# Check for TensorFlow()
75
76
if isinstance (node .func , ast .Name ):
76
77
return node .func .id == "TensorFlow"
77
78
79
+ # Check for something.that.ends.with.TensorFlow()
80
+ if not (isinstance (node .func , ast .Attribute ) and node .func .attr == "TensorFlow" ):
81
+ return False
82
+
83
+ # Check for sagemaker.tensorflow.estimator.TensorFlow()
84
+ if isinstance (node .func .value , ast .Attribute ) and node .func .value .attr == "estimator" :
85
+ return self ._is_in_tensorflow_module (node .func .value )
86
+
78
87
# Check for sagemaker.tensorflow.TensorFlow()
79
- ends_with_tensorflow_constructor = (
80
- isinstance (node .func , ast .Attribute ) and node .func .attr == "TensorFlow"
81
- )
88
+ return self ._is_in_tensorflow_module (node .func )
82
89
83
- is_in_tensorflow_module = (
84
- isinstance (node .func .value , ast .Attribute )
85
- and node .func .value .attr == "tensorflow"
86
- and isinstance (node .func .value .value , ast .Name )
87
- and node .func .value .value .id == "sagemaker"
90
+ def _is_in_tensorflow_module (self , node ):
91
+ """Checks if the node is an ``ast.Attribute`` that represents the
92
+ ``sagemaker.tensorflow`` module.
93
+ """
94
+ return (
95
+ isinstance (node .value , ast .Attribute )
96
+ and node .value .attr == "tensorflow"
97
+ and isinstance (node .value .value , ast .Name )
98
+ and node .value .value .id == "sagemaker"
88
99
)
89
100
90
- return ends_with_tensorflow_constructor and is_in_tensorflow_module
91
-
92
101
def _is_legacy_mode (self , node ):
93
102
"""Checks if the ``ast.Call`` node's keywords signal using legacy mode."""
94
103
script_mode = False
0 commit comments