-
Notifications
You must be signed in to change notification settings - Fork 1.2k
change: make v2 migration script add py_version when needed #1598
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""A class to ensure that ``py_version`` is defined when constructing framework classes.""" | ||
from __future__ import absolute_import | ||
|
||
import ast | ||
|
||
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier | ||
|
||
PY_VERSION_ARG = "py_version" | ||
PY_VERSION_DEFAULT = "py3" | ||
|
||
FRAMEWORK_MODEL_REQUIRES_PY_VERSION = { | ||
"Chainer": True, | ||
"MXNet": True, | ||
"PyTorch": True, | ||
"SKLearn": False, | ||
"TensorFlow": False, | ||
} | ||
|
||
FRAMEWORK_CLASSES = list(FRAMEWORK_MODEL_REQUIRES_PY_VERSION.keys()) | ||
|
||
MODEL_CLASSES = [ | ||
"{}Model".format(fw) for fw, required in FRAMEWORK_MODEL_REQUIRES_PY_VERSION.items() if required | ||
] | ||
FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORK_CLASSES] | ||
FRAMEWORK_SUBMODULES = ("model", "estimator") | ||
metrizable marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
class PyVersionEnforcer(Modifier): | ||
"""A class to ensure that ``py_version`` is defined when | ||
instantiating a framework estimator or model, where appropriate. | ||
""" | ||
|
||
def node_should_be_modified(self, node): | ||
"""Checks if the ast.Call node should be modified to include ``py_version``. | ||
|
||
If the ast.Call node instantiates a framework estimator or model, but doesn't | ||
specify the ``py_version`` parameter when required, then the node should be | ||
modified. However, if ``image_name`` for a framework estimator or ``image`` | ||
for a model is supplied to the call, then ``py_version`` is not required. | ||
|
||
This looks for the following formats: | ||
|
||
- ``PyTorch`` | ||
- ``sagemaker.pytorch.PyTorch`` | ||
|
||
where "PyTorch" can be Chainer, MXNet, PyTorch, SKLearn, or TensorFlow. | ||
|
||
Args: | ||
node (ast.Call): a node that represents a function call. For more, | ||
see https://docs.python.org/3/library/ast.html#abstract-grammar. | ||
|
||
Returns: | ||
bool: If the ``ast.Call`` is instantiating a framework class that | ||
should specify ``py_version``, but doesn't. | ||
""" | ||
if _is_named_constructor(node, FRAMEWORK_CLASSES): | ||
return _version_arg_needed(node, "image_name", PY_VERSION_ARG) | ||
|
||
if _is_named_constructor(node, MODEL_CLASSES): | ||
return _version_arg_needed(node, "image", PY_VERSION_ARG) | ||
|
||
return False | ||
|
||
def modify_node(self, node): | ||
"""Modifies the ``ast.Call`` node's keywords to include ``py_version``. | ||
|
||
Args: | ||
node (ast.Call): a node that represents the constructor of a framework class. | ||
""" | ||
node.keywords.append(ast.keyword(arg=PY_VERSION_ARG, value=ast.Str(s=PY_VERSION_DEFAULT))) | ||
|
||
|
||
def _is_named_constructor(node, names): | ||
"""Checks if the ``ast.Call`` node represents a call to particular named constructors. | ||
|
||
Forms that qualify are either <Framework> or sagemaker.<framework>.<Framework> | ||
where <Framework> belongs to the list of names passed in. | ||
""" | ||
# Check for call from particular names of constructors | ||
if isinstance(node.func, ast.Name): | ||
return node.func.id in names | ||
|
||
# Check for something.that.ends.with.<framework>.<Framework> call for Framework in names | ||
if not (isinstance(node.func, ast.Attribute) and node.func.attr in names): | ||
return False | ||
|
||
# Check for sagemaker.<frameworks>.<estimator/model>.<Framework> call | ||
if isinstance(node.func.value, ast.Attribute) and node.func.value.attr in FRAMEWORK_SUBMODULES: | ||
return _is_in_framework_module(node.func.value) | ||
|
||
# Check for sagemaker.<framework>.<Framework> call | ||
return _is_in_framework_module(node.func) | ||
|
||
|
||
def _is_in_framework_module(node): | ||
"""Checks if node is an ``ast.Attribute`` representing a ``sagemaker.<framework>`` module.""" | ||
return ( | ||
isinstance(node.value, ast.Attribute) | ||
and node.value.attr in FRAMEWORK_MODULES | ||
and isinstance(node.value.value, ast.Name) | ||
and node.value.value.id == "sagemaker" | ||
) | ||
|
||
|
||
def _version_arg_needed(node, image_arg, version_arg): | ||
"""Determines if image_arg or version_arg was supplied""" | ||
return not (_arg_supplied(node, image_arg) or _arg_supplied(node, version_arg)) | ||
|
||
|
||
def _arg_supplied(node, arg): | ||
"""Checks if the ``ast.Call`` node's keywords contain ``arg``.""" | ||
return any(kw.arg == arg and kw.value for kw in node.keywords) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import sys | ||
|
||
import pasta | ||
import pytest | ||
|
||
from sagemaker.cli.compatibility.v2.modifiers import py_version | ||
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def skip_if_py2(): | ||
# Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed. | ||
if sys.version_info.major < 3: | ||
pytest.skip("v2 migration script doesn't support Python 2.") | ||
|
||
|
||
@pytest.fixture | ||
def constructor_framework_templates(): | ||
return ( | ||
"TensorFlow({})", | ||
"sagemaker.tensorflow.TensorFlow({})", | ||
"sagemaker.tensorflow.estimator.TensorFlow({})", | ||
"MXNet({})", | ||
"sagemaker.mxnet.MXNet({})", | ||
"sagemaker.mxnet.estimator.MXNet({})", | ||
"Chainer({})", | ||
"sagemaker.chainer.Chainer({})", | ||
"sagemaker.chainer.estimator.Chainer({})", | ||
"PyTorch({})", | ||
"sagemaker.pytorch.PyTorch({})", | ||
"sagemaker.pytorch.estimator.PyTorch({})", | ||
"SKLearn({})", | ||
"sagemaker.sklearn.SKLearn({})", | ||
"sagemaker.sklearn.estimator.SKLearn({})", | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def constructor_model_templates(): | ||
return ( | ||
"MXNetModel({})", | ||
"sagemaker.mxnet.MXNetModel({})", | ||
"sagemaker.mxnet.model.MXNetModel({})", | ||
"ChainerModel({})", | ||
"sagemaker.chainer.ChainerModel({})", | ||
"sagemaker.chainer.model.ChainerModel({})", | ||
"PyTorchModel({})", | ||
"sagemaker.pytorch.PyTorchModel({})", | ||
"sagemaker.pytorch.model.PyTorchModel({})", | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def constructor_templates(constructor_framework_templates, constructor_model_templates): | ||
return tuple(list(constructor_framework_templates) + list(constructor_model_templates)) | ||
|
||
|
||
@pytest.fixture | ||
def constructors_no_version(constructor_templates): | ||
return (ctr.format("") for ctr in constructor_templates) | ||
|
||
|
||
@pytest.fixture | ||
def constructors_with_version(constructor_templates): | ||
return (ctr.format("py_version='py3'") for ctr in constructor_templates) | ||
|
||
|
||
@pytest.fixture | ||
def constructors_with_image_name(constructor_framework_templates): | ||
return (ctr.format("image_name='my:image'") for ctr in constructor_framework_templates) | ||
|
||
|
||
@pytest.fixture | ||
def constructors_with_image(constructor_model_templates): | ||
return (ctr.format("image='my:image'") for ctr in constructor_model_templates) | ||
|
||
|
||
@pytest.fixture | ||
def constructors_version_not_needed(): | ||
return ( | ||
"TensorFlowModel()", | ||
"sagemaker.tensorflow.TensorFlowModel()", | ||
"sagemaker.tensorflow.model.TensorFlowModel()", | ||
"SKLearnModel()", | ||
"sagemaker.sklearn.SKLearnModel()", | ||
"sagemaker.sklearn.model.SKLearnModel()", | ||
) | ||
|
||
|
||
def _test_modified(constructors, should_be): | ||
modifier = py_version.PyVersionEnforcer() | ||
for constructor in constructors: | ||
node = ast_call(constructor) | ||
if should_be: | ||
assert modifier.node_should_be_modified(node) | ||
else: | ||
assert not modifier.node_should_be_modified(node) | ||
|
||
|
||
def test_node_should_be_modified_fw_constructor_no_version(constructors_no_version): | ||
_test_modified(constructors_no_version, should_be=True) | ||
|
||
|
||
def test_node_should_be_modified_fw_constructor_with_version(constructors_with_version): | ||
_test_modified(constructors_with_version, should_be=False) | ||
|
||
|
||
def test_node_should_be_modified_fw_constructor_with_image_name(constructors_with_image_name): | ||
_test_modified(constructors_with_image_name, should_be=False) | ||
|
||
|
||
def test_node_should_be_modified_fw_constructor_with_image(constructors_with_image): | ||
_test_modified(constructors_with_image, should_be=False) | ||
|
||
|
||
def test_node_should_be_modified_fw_constructor_not_needed(constructors_version_not_needed): | ||
Comment on lines
+114
to
+130
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it doesn't seem like any of these fixtures are used in multiple tests, so I'm not sure what the benefit is to defining fixtures for them rather than just creating the iterable in the test directly There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i was envisioning using them with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. gotcha - that makes sense! |
||
_test_modified(constructors_version_not_needed, should_be=False) | ||
|
||
|
||
def test_node_should_be_modified_random_function_call(): | ||
_test_modified(["sagemaker.session.Session()"], should_be=False) | ||
|
||
|
||
def test_modify_node(constructor_templates): | ||
modifier = py_version.PyVersionEnforcer() | ||
for template in constructor_templates: | ||
no_version, with_version = template.format(""), template.format("py_version='py3'") | ||
node = ast_call(no_version) | ||
modifier.modify_node(node) | ||
|
||
assert with_version == pasta.dump(node) |
Uh oh!
There was an error while loading. Please reload this page.