diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index 128ddbe235..df4a7fd85f 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -18,6 +18,7 @@ from sagemaker.cli.compatibility.v2 import modifiers FUNCTION_CALL_MODIFIERS = [ + modifiers.py_version.PyVersionEnforcer(), modifiers.framework_version.FrameworkVersionEnforcer(), modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(), modifiers.tf_legacy_mode.TensorBoardParameterRemover(), diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py b/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py index 3a68a0b735..b729fee535 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py @@ -16,6 +16,7 @@ from sagemaker.cli.compatibility.v2.modifiers import ( # noqa: F401 (imported but unused) deprecated_params, framework_version, + py_version, tf_legacy_mode, tfs, ) diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/py_version.py b/src/sagemaker/cli/compatibility/v2/modifiers/py_version.py new file mode 100644 index 0000000000..a66672cb0e --- /dev/null +++ b/src/sagemaker/cli/compatibility/v2/modifiers/py_version.py @@ -0,0 +1,124 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""A class to ensure that ``py_version`` is defined when constructing framework classes.""" +from __future__ import absolute_import + +import ast + +from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier + +PY_VERSION_ARG = "py_version" +PY_VERSION_DEFAULT = "py3" + +FRAMEWORK_MODEL_REQUIRES_PY_VERSION = { + "Chainer": True, + "MXNet": True, + "PyTorch": True, + "SKLearn": False, + "TensorFlow": False, +} + +FRAMEWORK_CLASSES = list(FRAMEWORK_MODEL_REQUIRES_PY_VERSION.keys()) + +MODEL_CLASSES = [ + "{}Model".format(fw) for fw, required in FRAMEWORK_MODEL_REQUIRES_PY_VERSION.items() if required +] +FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORK_CLASSES] +FRAMEWORK_SUBMODULES = ("model", "estimator") + + +class PyVersionEnforcer(Modifier): + """A class to ensure that ``py_version`` is defined when + instantiating a framework estimator or model, where appropriate. + """ + + def node_should_be_modified(self, node): + """Checks if the ast.Call node should be modified to include ``py_version``. + + If the ast.Call node instantiates a framework estimator or model, but doesn't + specify the ``py_version`` parameter when required, then the node should be + modified. However, if ``image_name`` for a framework estimator or ``image`` + for a model is supplied to the call, then ``py_version`` is not required. + + This looks for the following formats: + + - ``PyTorch`` + - ``sagemaker.pytorch.PyTorch`` + + where "PyTorch" can be Chainer, MXNet, PyTorch, SKLearn, or TensorFlow. + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the ``ast.Call`` is instantiating a framework class that + should specify ``py_version``, but doesn't. + """ + if _is_named_constructor(node, FRAMEWORK_CLASSES): + return _version_arg_needed(node, "image_name", PY_VERSION_ARG) + + if _is_named_constructor(node, MODEL_CLASSES): + return _version_arg_needed(node, "image", PY_VERSION_ARG) + + return False + + def modify_node(self, node): + """Modifies the ``ast.Call`` node's keywords to include ``py_version``. + + Args: + node (ast.Call): a node that represents the constructor of a framework class. + """ + node.keywords.append(ast.keyword(arg=PY_VERSION_ARG, value=ast.Str(s=PY_VERSION_DEFAULT))) + + +def _is_named_constructor(node, names): + """Checks if the ``ast.Call`` node represents a call to particular named constructors. + + Forms that qualify are either or sagemaker.. + where belongs to the list of names passed in. + """ + # Check for call from particular names of constructors + if isinstance(node.func, ast.Name): + return node.func.id in names + + # Check for something.that.ends.with.. call for Framework in names + if not (isinstance(node.func, ast.Attribute) and node.func.attr in names): + return False + + # Check for sagemaker... call + if isinstance(node.func.value, ast.Attribute) and node.func.value.attr in FRAMEWORK_SUBMODULES: + return _is_in_framework_module(node.func.value) + + # Check for sagemaker.. call + return _is_in_framework_module(node.func) + + +def _is_in_framework_module(node): + """Checks if node is an ``ast.Attribute`` representing a ``sagemaker.`` module.""" + return ( + isinstance(node.value, ast.Attribute) + and node.value.attr in FRAMEWORK_MODULES + and isinstance(node.value.value, ast.Name) + and node.value.value.id == "sagemaker" + ) + + +def _version_arg_needed(node, image_arg, version_arg): + """Determines if image_arg or version_arg was supplied""" + return not (_arg_supplied(node, image_arg) or _arg_supplied(node, version_arg)) + + +def _arg_supplied(node, arg): + """Checks if the ``ast.Call`` node's keywords contain ``arg``.""" + return any(kw.arg == arg and kw.value for kw in node.keywords) diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_py_version.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_py_version.py new file mode 100644 index 0000000000..f7331f6c34 --- /dev/null +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_py_version.py @@ -0,0 +1,145 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import sys + +import pasta +import pytest + +from sagemaker.cli.compatibility.v2.modifiers import py_version +from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call + + +@pytest.fixture(autouse=True) +def skip_if_py2(): + # Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed. + if sys.version_info.major < 3: + pytest.skip("v2 migration script doesn't support Python 2.") + + +@pytest.fixture +def constructor_framework_templates(): + return ( + "TensorFlow({})", + "sagemaker.tensorflow.TensorFlow({})", + "sagemaker.tensorflow.estimator.TensorFlow({})", + "MXNet({})", + "sagemaker.mxnet.MXNet({})", + "sagemaker.mxnet.estimator.MXNet({})", + "Chainer({})", + "sagemaker.chainer.Chainer({})", + "sagemaker.chainer.estimator.Chainer({})", + "PyTorch({})", + "sagemaker.pytorch.PyTorch({})", + "sagemaker.pytorch.estimator.PyTorch({})", + "SKLearn({})", + "sagemaker.sklearn.SKLearn({})", + "sagemaker.sklearn.estimator.SKLearn({})", + ) + + +@pytest.fixture +def constructor_model_templates(): + return ( + "MXNetModel({})", + "sagemaker.mxnet.MXNetModel({})", + "sagemaker.mxnet.model.MXNetModel({})", + "ChainerModel({})", + "sagemaker.chainer.ChainerModel({})", + "sagemaker.chainer.model.ChainerModel({})", + "PyTorchModel({})", + "sagemaker.pytorch.PyTorchModel({})", + "sagemaker.pytorch.model.PyTorchModel({})", + ) + + +@pytest.fixture +def constructor_templates(constructor_framework_templates, constructor_model_templates): + return tuple(list(constructor_framework_templates) + list(constructor_model_templates)) + + +@pytest.fixture +def constructors_no_version(constructor_templates): + return (ctr.format("") for ctr in constructor_templates) + + +@pytest.fixture +def constructors_with_version(constructor_templates): + return (ctr.format("py_version='py3'") for ctr in constructor_templates) + + +@pytest.fixture +def constructors_with_image_name(constructor_framework_templates): + return (ctr.format("image_name='my:image'") for ctr in constructor_framework_templates) + + +@pytest.fixture +def constructors_with_image(constructor_model_templates): + return (ctr.format("image='my:image'") for ctr in constructor_model_templates) + + +@pytest.fixture +def constructors_version_not_needed(): + return ( + "TensorFlowModel()", + "sagemaker.tensorflow.TensorFlowModel()", + "sagemaker.tensorflow.model.TensorFlowModel()", + "SKLearnModel()", + "sagemaker.sklearn.SKLearnModel()", + "sagemaker.sklearn.model.SKLearnModel()", + ) + + +def _test_modified(constructors, should_be): + modifier = py_version.PyVersionEnforcer() + for constructor in constructors: + node = ast_call(constructor) + if should_be: + assert modifier.node_should_be_modified(node) + else: + assert not modifier.node_should_be_modified(node) + + +def test_node_should_be_modified_fw_constructor_no_version(constructors_no_version): + _test_modified(constructors_no_version, should_be=True) + + +def test_node_should_be_modified_fw_constructor_with_version(constructors_with_version): + _test_modified(constructors_with_version, should_be=False) + + +def test_node_should_be_modified_fw_constructor_with_image_name(constructors_with_image_name): + _test_modified(constructors_with_image_name, should_be=False) + + +def test_node_should_be_modified_fw_constructor_with_image(constructors_with_image): + _test_modified(constructors_with_image, should_be=False) + + +def test_node_should_be_modified_fw_constructor_not_needed(constructors_version_not_needed): + _test_modified(constructors_version_not_needed, should_be=False) + + +def test_node_should_be_modified_random_function_call(): + _test_modified(["sagemaker.session.Session()"], should_be=False) + + +def test_modify_node(constructor_templates): + modifier = py_version.PyVersionEnforcer() + for template in constructor_templates: + no_version, with_version = template.format(""), template.format("py_version='py3'") + node = ast_call(no_version) + modifier.modify_node(node) + + assert with_version == pasta.dump(node)