Skip to content

change: update v2 migration tool to rename TFS classes/imports #1552

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/sagemaker/cli/compatibility/v2/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),
modifiers.deprecated_params.TensorFlowScriptModeParameterRemover(),
modifiers.tfs.TensorFlowServingConstructorRenamer(),
]

IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]

IMPORT_FROM_MODIFIERS = [modifiers.tfs.TensorFlowServingImportFromRenamer()]


class ASTTransformer(ast.NodeTransformer):
"""An ``ast.NodeTransformer`` subclass that walks the abstract syntax tree and
Expand All @@ -46,3 +51,37 @@ def visit_Call(self, node):

ast.fix_missing_locations(node)
return node

def visit_Import(self, node):
"""Visits an ``ast.Import`` node and returns a modified node, if needed.
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.

Args:
node (ast.Import): a node that represents an import statement.

Returns:
ast.Import: a node that represents an import statement, which has
potentially been modified from the original input.
"""
for import_checker in IMPORT_MODIFIERS:
import_checker.check_and_modify_node(node)

ast.fix_missing_locations(node)
return node

def visit_ImportFrom(self, node):
"""Visits an ``ast.ImportFrom`` node and returns a modified node, if needed.
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.

Args:
node (ast.ImportFrom): a node that represents an import statement.

Returns:
ast.ImportFrom: a node that represents an import statement, which has
potentially been modified from the original input.
"""
for import_checker in IMPORT_FROM_MODIFIERS:
import_checker.check_and_modify_node(node)

ast.fix_missing_locations(node)
return node
1 change: 1 addition & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
deprecated_params,
framework_version,
tf_legacy_mode,
tfs,
)
121 changes: 121 additions & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/tfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Classes to modify TensorFlow Serving code to be compatible with SageMaker Python SDK v2."""
from __future__ import absolute_import

import ast

from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier


class TensorFlowServingConstructorRenamer(Modifier):
"""A class to rename TensorFlow Serving classes."""

def node_should_be_modified(self, node):
"""Checks if the ``ast.Call`` node instantiates a TensorFlow Serving class.

This looks for the following calls:

- ``sagemaker.tensorflow.serving.Model``
- ``sagemaker.tensorflow.serving.Predictor``
- ``Predictor``

Because ``Model`` can refer to either ``sagemaker.tensorflow.serving.Model``
or :class:`~sagemaker.model.Model`, ``Model`` on its own is not sufficient
for indicating a TFS Model object.

Args:
node (ast.Call): a node that represents a function call. For more,
see https://docs.python.org/3/library/ast.html#abstract-grammar.

Returns:
bool: If the ``ast.Call`` instantiates a TensorFlow Serving class.
"""
if isinstance(node.func, ast.Name):
return node.func.id == "Predictor"

if not (isinstance(node.func, ast.Attribute) and node.func.attr in ("Model", "Predictor")):
return False

return (
isinstance(node.func.value, ast.Attribute)
and node.func.value.attr == "serving"
and isinstance(node.func.value.value, ast.Attribute)
and node.func.value.value.attr == "tensorflow"
and isinstance(node.func.value.value.value, ast.Name)
and node.func.value.value.value.id == "sagemaker"
)

def modify_node(self, node):
"""Modifies the ``ast.Call`` node to use the v2 classes for TensorFlow Serving:

- ``sagemaker.tensorflow.TensorFlowModel``
- ``sagemaker.tensorflow.TensorFlowPredictor``

Args:
node (ast.Call): a node that represents a TensorFlow Serving constructor.
"""
if isinstance(node.func, ast.Name):
node.func.id = self._new_cls_name(node.func.id)
else:
node.func.attr = self._new_cls_name(node.func.attr)
node.func.value = node.func.value.value

def _new_cls_name(self, cls_name):
"""Returns the v2 class name."""
return "TensorFlow{}".format(cls_name)


class TensorFlowServingImportFromRenamer(Modifier):
"""A class to update import statements starting with ``from sagemaker.tensorflow.serving``."""

def node_should_be_modified(self, node):
"""Checks if the import statement imports from the ``sagemaker.tensorflow.serving`` module.

Args:
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.

Returns:
bool: If the ``ast.ImportFrom`` uses the ``sagemaker.tensorflow.serving`` module.
"""
return node.module == "sagemaker.tensorflow.serving"

def modify_node(self, node):
"""Changes the ``ast.ImportFrom`` node's module to ``sagemaker.tensorflow`` and updates the
imported class names to ``TensorFlowModel`` and ``TensorFlowPredictor``, as applicable.

Args:
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
"""
node.module = "sagemaker.tensorflow"

for cls in node.names:
cls.name = "TensorFlow{}".format(cls.name)


class TensorFlowServingImportRenamer(Modifier):
"""A class to update ``import sagemaker.tensorflow.serving``."""

def check_and_modify_node(self, node):
"""Checks if the ``ast.Import`` node imports the ``sagemaker.tensorflow.serving`` module
and, if so, changes it to ``sagemaker.tensorflow``.

Args:
node (ast.Import): a node that represents an import statement. For more,
see https://docs.python.org/3/library/ast.html#abstract-grammar.
"""
for module in node.names:
if module.name == "sagemaker.tensorflow.serving":
module.name = "sagemaker.tensorflow"
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,7 @@

def ast_call(code):
return pasta.parse(code).body[0].value


def ast_import(code):
return pasta.parse(code).body[0]
107 changes: 107 additions & 0 deletions tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pasta

from sagemaker.cli.compatibility.v2.modifiers import tfs
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import


def test_constructor_node_should_be_modified_tfs_constructor():
tfs_constructors = (
"sagemaker.tensorflow.serving.Model()",
"sagemaker.tensorflow.serving.Predictor()",
"Predictor()",
)

modifier = tfs.TensorFlowServingConstructorRenamer()

for constructor in tfs_constructors:
node = ast_call(constructor)
assert modifier.node_should_be_modified(node) is True


def test_constructor_node_should_be_modified_random_function_call():
modifier = tfs.TensorFlowServingConstructorRenamer()
node = ast_call("Model()")
assert modifier.node_should_be_modified(node) is False


def test_constructor_modify_node():
modifier = tfs.TensorFlowServingConstructorRenamer()

node = ast_call("sagemaker.tensorflow.serving.Model()")
modifier.modify_node(node)
assert "sagemaker.tensorflow.TensorFlowModel()" == pasta.dump(node)

node = ast_call("sagemaker.tensorflow.serving.Predictor()")
modifier.modify_node(node)
assert "sagemaker.tensorflow.TensorFlowPredictor()" == pasta.dump(node)

node = ast_call("Predictor()")
modifier.modify_node(node)
assert "TensorFlowPredictor()" == pasta.dump(node)


def test_import_from_node_should_be_modified_tfs_module():
import_statements = (
"from sagemaker.tensorflow.serving import Model, Predictor",
"from sagemaker.tensorflow.serving import Predictor",
"from sagemaker.tensorflow.serving import Model as tfsModel",
)

modifier = tfs.TensorFlowServingImportFromRenamer()

for import_from in import_statements:
node = ast_import(import_from)
assert modifier.node_should_be_modified(node) is True


def test_import_from_node_should_be_modified_random_import():
modifier = tfs.TensorFlowServingImportFromRenamer()
node = ast_import("from sagemaker import Session")
assert modifier.node_should_be_modified(node) is False


def test_import_from_modify_node():
modifier = tfs.TensorFlowServingImportFromRenamer()

node = ast_import("from sagemaker.tensorflow.serving import Model, Predictor")
modifier.modify_node(node)
expected_result = "from sagemaker.tensorflow import TensorFlowModel, TensorFlowPredictor"
assert expected_result == pasta.dump(node)

node = ast_import("from sagemaker.tensorflow.serving import Predictor")
modifier.modify_node(node)
assert "from sagemaker.tensorflow import TensorFlowPredictor" == pasta.dump(node)

node = ast_import("from sagemaker.tensorflow.serving import Model as tfsModel")
modifier.modify_node(node)
assert "from sagemaker.tensorflow import TensorFlowModel as tfsModel" == pasta.dump(node)


def test_import_check_and_modify_node_tfs_import():
modifier = tfs.TensorFlowServingImportRenamer()
node = ast_import("import sagemaker.tensorflow.serving")
modifier.check_and_modify_node(node)
assert "import sagemaker.tensorflow" == pasta.dump(node)


def test_import_check_and_modify_node_random_import():
modifier = tfs.TensorFlowServingImportRenamer()

import_statement = "import random"
node = ast_import(import_statement)
modifier.check_and_modify_node(node)
assert import_statement == pasta.dump(node)