Skip to content

Commit 2099a8d

Browse files
authored
fix: look for 'sagemaker.tensorflow.estimator' module in v2 migration tool (#1550)
1 parent 3014421 commit 2099a8d

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def node_should_be_modified(self, node):
5454
5555
- ``TensorFlow``
5656
- ``sagemaker.tensorflow.TensorFlow``
57+
- ``sagemaker.tensorflow.estimator.TensorFlow``
5758
5859
Legacy mode is enabled if (1) ``script_mode`` is ``False``, ``None``, or not specified,
5960
and (2) if ``py_version`` is ``py2`` or not specified.
@@ -68,27 +69,35 @@ def node_should_be_modified(self, node):
6869
return self._is_tf_constructor(node) and self._is_legacy_mode(node)
6970

7071
def _is_tf_constructor(self, node):
71-
"""Checks if the ``ast.Call`` node represents a call of the form
72-
``TensorFlow`` or ``sagemaker.tensorflow.TensorFlow``.
72+
"""Checks if the ``ast.Call`` node represents a call of the form ``TensorFlow``,
73+
``sagemaker.tensorflow.TensorFlow``, or ``sagemaker.tensorflow.estimator.TensorFlow``.
7374
"""
7475
# Check for TensorFlow()
7576
if isinstance(node.func, ast.Name):
7677
return node.func.id == "TensorFlow"
7778

79+
# Check for something.that.ends.with.TensorFlow()
80+
if not (isinstance(node.func, ast.Attribute) and node.func.attr == "TensorFlow"):
81+
return False
82+
83+
# Check for sagemaker.tensorflow.estimator.TensorFlow()
84+
if isinstance(node.func.value, ast.Attribute) and node.func.value.attr == "estimator":
85+
return self._is_in_tensorflow_module(node.func.value)
86+
7887
# Check for sagemaker.tensorflow.TensorFlow()
79-
ends_with_tensorflow_constructor = (
80-
isinstance(node.func, ast.Attribute) and node.func.attr == "TensorFlow"
81-
)
88+
return self._is_in_tensorflow_module(node.func)
8289

83-
is_in_tensorflow_module = (
84-
isinstance(node.func.value, ast.Attribute)
85-
and node.func.value.attr == "tensorflow"
86-
and isinstance(node.func.value.value, ast.Name)
87-
and node.func.value.value.id == "sagemaker"
90+
def _is_in_tensorflow_module(self, node):
91+
"""Checks if the node is an ``ast.Attribute`` that represents the
92+
``sagemaker.tensorflow`` module.
93+
"""
94+
return (
95+
isinstance(node.value, ast.Attribute)
96+
and node.value.attr == "tensorflow"
97+
and isinstance(node.value.value, ast.Name)
98+
and node.value.value.id == "sagemaker"
8899
)
89100

90-
return ends_with_tensorflow_constructor and is_in_tensorflow_module
91-
92101
def _is_legacy_mode(self, node):
93102
"""Checks if the ``ast.Call`` node's keywords signal using legacy mode."""
94103
script_mode = False

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ def test_node_should_be_modified_tf_constructor_legacy_mode():
4242
"sagemaker.tensorflow.TensorFlow(script_mode=None)",
4343
"sagemaker.tensorflow.TensorFlow(py_version='py2')",
4444
"sagemaker.tensorflow.TensorFlow()",
45+
"sagemaker.tensorflow.estimator.TensorFlow(script_mode=False)",
46+
"sagemaker.tensorflow.estimator.TensorFlow(script_mode=None)",
47+
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py2')",
48+
"sagemaker.tensorflow.estimator.TensorFlow()",
4549
)
4650

4751
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
@@ -61,6 +65,10 @@ def test_node_should_be_modified_tf_constructor_script_mode():
6165
"sagemaker.tensorflow.TensorFlow(py_version='py3')",
6266
"sagemaker.tensorflow.TensorFlow(py_version='py37')",
6367
"sagemaker.tensorflow.TensorFlow(py_version='py3', script_mode=False)",
68+
"sagemaker.tensorflow.estimator.TensorFlow(script_mode=True)",
69+
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py3')",
70+
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py37')",
71+
"sagemaker.tensorflow.estimator.TensorFlow(py_version='py3', script_mode=False)",
6472
)
6573

6674
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()

0 commit comments

Comments
 (0)