Skip to content

Commit 059c98e

Browse files
authored
fix: check ast node on later renamers for cli v2 updater (#1848)
1 parent 03ac216 commit 059c98e

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

src/sagemaker/cli/compatibility/v2/modifiers/image_uris.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,10 @@ def node_should_be_modified(self, node):
112112
Returns:
113113
bool: If the import statement imports ``get_image_uri`` from the correct module.
114114
"""
115-
return node.module in GET_IMAGE_URI_NAMESPACES and any(
116-
name.name == GET_IMAGE_URI_NAME for name in node.names
115+
return (
116+
node is not None
117+
and node.module in GET_IMAGE_URI_NAMESPACES
118+
and any(name.name == GET_IMAGE_URI_NAME for name in node.names)
117119
)
118120

119121
def modify_node(self, node):
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import textwrap
16+
17+
import pasta
18+
import pytest
19+
20+
from sagemaker.cli.compatibility.v2.ast_transformer import ASTTransformer
21+
22+
23+
@pytest.fixture
24+
def input_code():
25+
return textwrap.dedent(
26+
"""
27+
from sagemaker.predictor import csv_serializer
28+
29+
csv_serializer.__doc__
30+
"""
31+
)
32+
33+
34+
@pytest.fixture
35+
def output_code():
36+
return textwrap.dedent(
37+
"""
38+
from sagemaker import serializers
39+
40+
serializers.CSVSerializer().__doc__
41+
"""
42+
)
43+
44+
45+
def test_simple_script(input_code, output_code):
46+
input_ast = pasta.parse(input_code)
47+
output_ast = ASTTransformer().visit(input_ast)
48+
assert pasta.dump(output_ast) == output_code

0 commit comments

Comments
 (0)