Skip to content

change: make v2 migration script add py_version when needed #1598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sagemaker/cli/compatibility/v2/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sagemaker.cli.compatibility.v2 import modifiers

FUNCTION_CALL_MODIFIERS = [
modifiers.py_version.PyVersionEnforcer(),
modifiers.framework_version.FrameworkVersionEnforcer(),
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sagemaker.cli.compatibility.v2.modifiers import ( # noqa: F401 (imported but unused)
deprecated_params,
framework_version,
py_version,
tf_legacy_mode,
tfs,
)
124 changes: 124 additions & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/py_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class to ensure that ``py_version`` is defined when constructing framework classes."""
from __future__ import absolute_import

import ast

from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier

PY_VERSION_ARG = "py_version"
PY_VERSION_DEFAULT = "py3"

FRAMEWORK_MODEL_REQUIRES_PY_VERSION = {
"Chainer": True,
"MXNet": True,
"PyTorch": True,
"SKLearn": False,
"TensorFlow": False,
}

FRAMEWORK_CLASSES = list(FRAMEWORK_MODEL_REQUIRES_PY_VERSION.keys())

MODEL_CLASSES = [
"{}Model".format(fw) for fw, required in FRAMEWORK_MODEL_REQUIRES_PY_VERSION.items() if required
]
FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORK_CLASSES]
FRAMEWORK_SUBMODULES = ("model", "estimator")


class PyVersionEnforcer(Modifier):
"""A class to ensure that ``py_version`` is defined when
instantiating a framework estimator or model, where appropriate.
"""

def node_should_be_modified(self, node):
"""Checks if the ast.Call node should be modified to include ``py_version``.

If the ast.Call node instantiates a framework estimator or model, but doesn't
specify the ``py_version`` parameter when required, then the node should be
modified. However, if ``image_name`` for a framework estimator or ``image``
for a model is supplied to the call, then ``py_version`` is not required.

This looks for the following formats:

- ``PyTorch``
- ``sagemaker.pytorch.PyTorch``

where "PyTorch" can be Chainer, MXNet, PyTorch, SKLearn, or TensorFlow.

Args:
node (ast.Call): a node that represents a function call. For more,
see https://docs.python.org/3/library/ast.html#abstract-grammar.

Returns:
bool: If the ``ast.Call`` is instantiating a framework class that
should specify ``py_version``, but doesn't.
"""
if _is_named_constructor(node, FRAMEWORK_CLASSES):
return _version_arg_needed(node, "image_name", PY_VERSION_ARG)

if _is_named_constructor(node, MODEL_CLASSES):
return _version_arg_needed(node, "image", PY_VERSION_ARG)

return False

def modify_node(self, node):
"""Modifies the ``ast.Call`` node's keywords to include ``py_version``.

Args:
node (ast.Call): a node that represents the constructor of a framework class.
"""
node.keywords.append(ast.keyword(arg=PY_VERSION_ARG, value=ast.Str(s=PY_VERSION_DEFAULT)))


def _is_named_constructor(node, names):
"""Checks if the ``ast.Call`` node represents a call to particular named constructors.

Forms that qualify are either <Framework> or sagemaker.<framework>.<Framework>
where <Framework> belongs to the list of names passed in.
"""
# Check for call from particular names of constructors
if isinstance(node.func, ast.Name):
return node.func.id in names

# Check for something.that.ends.with.<framework>.<Framework> call for Framework in names
if not (isinstance(node.func, ast.Attribute) and node.func.attr in names):
return False

# Check for sagemaker.<frameworks>.<estimator/model>.<Framework> call
if isinstance(node.func.value, ast.Attribute) and node.func.value.attr in FRAMEWORK_SUBMODULES:
return _is_in_framework_module(node.func.value)

# Check for sagemaker.<framework>.<Framework> call
return _is_in_framework_module(node.func)


def _is_in_framework_module(node):
"""Checks if node is an ``ast.Attribute`` representing a ``sagemaker.<framework>`` module."""
return (
isinstance(node.value, ast.Attribute)
and node.value.attr in FRAMEWORK_MODULES
and isinstance(node.value.value, ast.Name)
and node.value.value.id == "sagemaker"
)


def _version_arg_needed(node, image_arg, version_arg):
"""Determines if image_arg or version_arg was supplied"""
return not (_arg_supplied(node, image_arg) or _arg_supplied(node, version_arg))


def _arg_supplied(node, arg):
"""Checks if the ``ast.Call`` node's keywords contain ``arg``."""
return any(kw.arg == arg and kw.value for kw in node.keywords)
145 changes: 145 additions & 0 deletions tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_py_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import sys

import pasta
import pytest

from sagemaker.cli.compatibility.v2.modifiers import py_version
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call


@pytest.fixture(autouse=True)
def skip_if_py2():
# Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed.
if sys.version_info.major < 3:
pytest.skip("v2 migration script doesn't support Python 2.")


@pytest.fixture
def constructor_framework_templates():
return (
"TensorFlow({})",
"sagemaker.tensorflow.TensorFlow({})",
"sagemaker.tensorflow.estimator.TensorFlow({})",
"MXNet({})",
"sagemaker.mxnet.MXNet({})",
"sagemaker.mxnet.estimator.MXNet({})",
"Chainer({})",
"sagemaker.chainer.Chainer({})",
"sagemaker.chainer.estimator.Chainer({})",
"PyTorch({})",
"sagemaker.pytorch.PyTorch({})",
"sagemaker.pytorch.estimator.PyTorch({})",
"SKLearn({})",
"sagemaker.sklearn.SKLearn({})",
"sagemaker.sklearn.estimator.SKLearn({})",
)


@pytest.fixture
def constructor_model_templates():
return (
"MXNetModel({})",
"sagemaker.mxnet.MXNetModel({})",
"sagemaker.mxnet.model.MXNetModel({})",
"ChainerModel({})",
"sagemaker.chainer.ChainerModel({})",
"sagemaker.chainer.model.ChainerModel({})",
"PyTorchModel({})",
"sagemaker.pytorch.PyTorchModel({})",
"sagemaker.pytorch.model.PyTorchModel({})",
)


@pytest.fixture
def constructor_templates(constructor_framework_templates, constructor_model_templates):
return tuple(list(constructor_framework_templates) + list(constructor_model_templates))


@pytest.fixture
def constructors_no_version(constructor_templates):
return (ctr.format("") for ctr in constructor_templates)


@pytest.fixture
def constructors_with_version(constructor_templates):
return (ctr.format("py_version='py3'") for ctr in constructor_templates)


@pytest.fixture
def constructors_with_image_name(constructor_framework_templates):
return (ctr.format("image_name='my:image'") for ctr in constructor_framework_templates)


@pytest.fixture
def constructors_with_image(constructor_model_templates):
return (ctr.format("image='my:image'") for ctr in constructor_model_templates)


@pytest.fixture
def constructors_version_not_needed():
return (
"TensorFlowModel()",
"sagemaker.tensorflow.TensorFlowModel()",
"sagemaker.tensorflow.model.TensorFlowModel()",
"SKLearnModel()",
"sagemaker.sklearn.SKLearnModel()",
"sagemaker.sklearn.model.SKLearnModel()",
)


def _test_modified(constructors, should_be):
modifier = py_version.PyVersionEnforcer()
for constructor in constructors:
node = ast_call(constructor)
if should_be:
assert modifier.node_should_be_modified(node)
else:
assert not modifier.node_should_be_modified(node)


def test_node_should_be_modified_fw_constructor_no_version(constructors_no_version):
_test_modified(constructors_no_version, should_be=True)


def test_node_should_be_modified_fw_constructor_with_version(constructors_with_version):
_test_modified(constructors_with_version, should_be=False)


def test_node_should_be_modified_fw_constructor_with_image_name(constructors_with_image_name):
_test_modified(constructors_with_image_name, should_be=False)


def test_node_should_be_modified_fw_constructor_with_image(constructors_with_image):
_test_modified(constructors_with_image, should_be=False)


def test_node_should_be_modified_fw_constructor_not_needed(constructors_version_not_needed):
Comment on lines +114 to +130
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't seem like any of these fixtures are used in multiple tests, so I'm not sure what the benefit is to defining fixtures for them rather than just creating the iterable in the test directly

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i was envisioning using them with framework_version tests in a separate pr.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gotcha - that makes sense!

_test_modified(constructors_version_not_needed, should_be=False)


def test_node_should_be_modified_random_function_call():
_test_modified(["sagemaker.session.Session()"], should_be=False)


def test_modify_node(constructor_templates):
modifier = py_version.PyVersionEnforcer()
for template in constructor_templates:
no_version, with_version = template.format(""), template.format("py_version='py3'")
node = ast_call(no_version)
modifier.modify_node(node)

assert with_version == pasta.dump(node)