Skip to content

Commit baf1c35

Browse files
authored
fix: look for 'sagemaker.<framework>.<estimator/model>' module in v2 migration tool (#1551)
1 parent 2099a8d commit baf1c35

File tree

2 files changed

+57
-40
lines changed

2 files changed

+57
-40
lines changed

src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
# TODO: check for sagemaker.tensorflow.serving.Model
3030
FRAMEWORK_CLASSES = FRAMEWORKS + ["{}Model".format(fw) for fw in FRAMEWORKS]
3131
FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORKS]
32+
FRAMEWORK_SUBMODULES = ("model", "estimator")
3233

3334

3435
class FrameworkVersionEnforcer(Modifier):
@@ -68,19 +69,30 @@ def _is_framework_constructor(self, node):
6869
if isinstance(node.func, ast.Name):
6970
return node.func.id in FRAMEWORK_CLASSES
7071

71-
# Check for sagemaker.<framework>.<Framework> call
72-
ends_with_framework_constructor = (
73-
isinstance(node.func, ast.Attribute) and node.func.attr in FRAMEWORK_CLASSES
74-
)
72+
# Check for something.that.ends.with.<framework>.<Framework> call
73+
if not (isinstance(node.func, ast.Attribute) and node.func.attr in FRAMEWORK_CLASSES):
74+
return False
7575

76-
is_in_framework_module = (
76+
# Check for sagemaker.<frameworks>.<estimator/model>.<Framework> call
77+
if (
7778
isinstance(node.func.value, ast.Attribute)
78-
and node.func.value.attr in FRAMEWORK_MODULES
79-
and isinstance(node.func.value.value, ast.Name)
80-
and node.func.value.value.id == "sagemaker"
81-
)
79+
and node.func.value.attr in FRAMEWORK_SUBMODULES
80+
):
81+
return self._is_in_framework_module(node.func.value)
8282

83-
return ends_with_framework_constructor and is_in_framework_module
83+
# Check for sagemaker.<framework>.<Framework> call
84+
return self._is_in_framework_module(node.func)
85+
86+
def _is_in_framework_module(self, node):
87+
"""Checks if the node is an ``ast.Attribute`` that represents a
88+
``sagemaker.<framework>`` module.
89+
"""
90+
return (
91+
isinstance(node.value, ast.Attribute)
92+
and node.value.attr in FRAMEWORK_MODULES
93+
and isinstance(node.value.value, ast.Name)
94+
and node.value.value.id == "sagemaker"
95+
)
8496

8597
def _fw_version_in_keywords(self, node):
8698
"""Checks if the ``ast.Call`` node's keywords contain ``framework_version``."""

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,34 @@ def test_node_should_be_modified_fw_constructor_no_fw_version():
3232
fw_constructors = (
3333
"TensorFlow()",
3434
"sagemaker.tensorflow.TensorFlow()",
35+
"sagemaker.tensorflow.estimator.TensorFlow()",
3536
"TensorFlowModel()",
3637
"sagemaker.tensorflow.TensorFlowModel()",
38+
"sagemaker.tensorflow.model.TensorFlowModel()",
3739
"MXNet()",
3840
"sagemaker.mxnet.MXNet()",
41+
"sagemaker.mxnet.estimator.MXNet()",
3942
"MXNetModel()",
4043
"sagemaker.mxnet.MXNetModel()",
44+
"sagemaker.mxnet.model.MXNetModel()",
4145
"Chainer()",
4246
"sagemaker.chainer.Chainer()",
47+
"sagemaker.chainer.estimator.Chainer()",
4348
"ChainerModel()",
4449
"sagemaker.chainer.ChainerModel()",
50+
"sagemaker.chainer.model.ChainerModel()",
4551
"PyTorch()",
4652
"sagemaker.pytorch.PyTorch()",
53+
"sagemaker.pytorch.estimator.PyTorch()",
4754
"PyTorchModel()",
4855
"sagemaker.pytorch.PyTorchModel()",
56+
"sagemaker.pytorch.model.PyTorchModel()",
4957
"SKLearn()",
5058
"sagemaker.sklearn.SKLearn()",
59+
"sagemaker.sklearn.estimator.SKLearn()",
5160
"SKLearnModel()",
5261
"sagemaker.sklearn.SKLearnModel()",
62+
"sagemaker.sklearn.model.SKLearnModel()",
5363
)
5464

5565
modifier = framework_version.FrameworkVersionEnforcer()
@@ -63,24 +73,34 @@ def test_node_should_be_modified_fw_constructor_with_fw_version():
6373
fw_constructors = (
6474
"TensorFlow(framework_version='2.2')",
6575
"sagemaker.tensorflow.TensorFlow(framework_version='2.2')",
76+
"sagemaker.tensorflow.estimator.TensorFlow(framework_version='2.2')",
6677
"TensorFlowModel(framework_version='1.10')",
6778
"sagemaker.tensorflow.TensorFlowModel(framework_version='1.10')",
79+
"sagemaker.tensorflow.model.TensorFlowModel(framework_version='1.10')",
6880
"MXNet(framework_version='1.6')",
6981
"sagemaker.mxnet.MXNet(framework_version='1.6')",
82+
"sagemaker.mxnet.estimator.MXNet(framework_version='1.6')",
7083
"MXNetModel(framework_version='1.6')",
7184
"sagemaker.mxnet.MXNetModel(framework_version='1.6')",
85+
"sagemaker.mxnet.model.MXNetModel(framework_version='1.6')",
7286
"PyTorch(framework_version='1.4')",
7387
"sagemaker.pytorch.PyTorch(framework_version='1.4')",
88+
"sagemaker.pytorch.estimator.PyTorch(framework_version='1.4')",
7489
"PyTorchModel(framework_version='1.4')",
7590
"sagemaker.pytorch.PyTorchModel(framework_version='1.4')",
91+
"sagemaker.pytorch.model.PyTorchModel(framework_version='1.4')",
7692
"Chainer(framework_version='5.0')",
7793
"sagemaker.chainer.Chainer(framework_version='5.0')",
94+
"sagemaker.chainer.estimator.Chainer(framework_version='5.0')",
7895
"ChainerModel(framework_version='5.0')",
7996
"sagemaker.chainer.ChainerModel(framework_version='5.0')",
97+
"sagemaker.chainer.model.ChainerModel(framework_version='5.0')",
8098
"SKLearn(framework_version='0.20.0')",
8199
"sagemaker.sklearn.SKLearn(framework_version='0.20.0')",
100+
"sagemaker.sklearn.estimator.SKLearn(framework_version='0.20.0')",
82101
"SKLearnModel(framework_version='0.20.0')",
83102
"sagemaker.sklearn.SKLearnModel(framework_version='0.20.0')",
103+
"sagemaker.sklearn.model.SKLearnModel(framework_version='0.20.0')",
84104
)
85105

86106
modifier = framework_version.FrameworkVersionEnforcer()
@@ -97,51 +117,36 @@ def test_node_should_be_modified_random_function_call():
97117

98118

99119
def test_modify_node_tf():
100-
classes = (
101-
"TensorFlow" "sagemaker.tensorflow.TensorFlow",
102-
"TensorFlowModel",
103-
"sagemaker.tensorflow.TensorFlowModel",
104-
)
105-
_test_modify_node(classes, "1.11.0")
120+
_test_modify_node("TensorFlow", "1.11.0")
106121

107122

108123
def test_modify_node_mx():
109-
classes = ("MXNet", "sagemaker.mxnet.MXNet", "MXNetModel", "sagemaker.mxnet.MXNetModel")
110-
_test_modify_node(classes, "1.2.0")
124+
_test_modify_node("MXNet", "1.2.0")
111125

112126

113127
def test_modify_node_chainer():
114-
classes = (
115-
"Chainer",
116-
"sagemaker.chainer.Chainer",
117-
"ChainerModel",
118-
"sagemaker.chainer.ChainerModel",
119-
)
120-
_test_modify_node(classes, "4.1.0")
128+
_test_modify_node("Chainer", "4.1.0")
121129

122130

123131
def test_modify_node_pt():
124-
classes = (
125-
"PyTorch",
126-
"sagemaker.pytorch.PyTorch",
127-
"PyTorchModel",
128-
"sagemaker.pytorch.PyTorchModel",
129-
)
130-
_test_modify_node(classes, "0.4.0")
132+
_test_modify_node("PyTorch", "0.4.0")
131133

132134

133135
def test_modify_node_sklearn():
134-
classes = (
135-
"SKLearn",
136-
"sagemaker.sklearn.SKLearn",
137-
"SKLearnModel",
138-
"sagemaker.sklearn.SKLearnModel",
139-
)
140-
_test_modify_node(classes, "0.20.0")
136+
_test_modify_node("SKLearn", "0.20.0")
141137

142138

143-
def _test_modify_node(classes, default_version):
139+
def _test_modify_node(framework, default_version):
144140
modifier = framework_version.FrameworkVersionEnforcer()
141+
142+
classes = (
143+
"{}".format(framework),
144+
"sagemaker.{}.{}".format(framework.lower(), framework),
145+
"sagemaker.{}.estimator.{}".format(framework.lower(), framework),
146+
"{}Model".format(framework),
147+
"sagemaker.{}.{}Model".format(framework.lower(), framework),
148+
"sagemaker.{}.model.{}Model".format(framework.lower(), framework),
149+
)
145150
for cls in classes:
146151
node = ast_call("{}()".format(cls))
147152
modifier.modify_node(node)

0 commit comments

Comments
 (0)