@@ -32,24 +32,34 @@ def test_node_should_be_modified_fw_constructor_no_fw_version():
32
32
fw_constructors = (
33
33
"TensorFlow()" ,
34
34
"sagemaker.tensorflow.TensorFlow()" ,
35
+ "sagemaker.tensorflow.estimator.TensorFlow()" ,
35
36
"TensorFlowModel()" ,
36
37
"sagemaker.tensorflow.TensorFlowModel()" ,
38
+ "sagemaker.tensorflow.model.TensorFlowModel()" ,
37
39
"MXNet()" ,
38
40
"sagemaker.mxnet.MXNet()" ,
41
+ "sagemaker.mxnet.estimator.MXNet()" ,
39
42
"MXNetModel()" ,
40
43
"sagemaker.mxnet.MXNetModel()" ,
44
+ "sagemaker.mxnet.model.MXNetModel()" ,
41
45
"Chainer()" ,
42
46
"sagemaker.chainer.Chainer()" ,
47
+ "sagemaker.chainer.estimator.Chainer()" ,
43
48
"ChainerModel()" ,
44
49
"sagemaker.chainer.ChainerModel()" ,
50
+ "sagemaker.chainer.model.ChainerModel()" ,
45
51
"PyTorch()" ,
46
52
"sagemaker.pytorch.PyTorch()" ,
53
+ "sagemaker.pytorch.estimator.PyTorch()" ,
47
54
"PyTorchModel()" ,
48
55
"sagemaker.pytorch.PyTorchModel()" ,
56
+ "sagemaker.pytorch.model.PyTorchModel()" ,
49
57
"SKLearn()" ,
50
58
"sagemaker.sklearn.SKLearn()" ,
59
+ "sagemaker.sklearn.estimator.SKLearn()" ,
51
60
"SKLearnModel()" ,
52
61
"sagemaker.sklearn.SKLearnModel()" ,
62
+ "sagemaker.sklearn.model.SKLearnModel()" ,
53
63
)
54
64
55
65
modifier = framework_version .FrameworkVersionEnforcer ()
@@ -63,24 +73,34 @@ def test_node_should_be_modified_fw_constructor_with_fw_version():
63
73
fw_constructors = (
64
74
"TensorFlow(framework_version='2.2')" ,
65
75
"sagemaker.tensorflow.TensorFlow(framework_version='2.2')" ,
76
+ "sagemaker.tensorflow.estimator.TensorFlow(framework_version='2.2')" ,
66
77
"TensorFlowModel(framework_version='1.10')" ,
67
78
"sagemaker.tensorflow.TensorFlowModel(framework_version='1.10')" ,
79
+ "sagemaker.tensorflow.model.TensorFlowModel(framework_version='1.10')" ,
68
80
"MXNet(framework_version='1.6')" ,
69
81
"sagemaker.mxnet.MXNet(framework_version='1.6')" ,
82
+ "sagemaker.mxnet.estimator.MXNet(framework_version='1.6')" ,
70
83
"MXNetModel(framework_version='1.6')" ,
71
84
"sagemaker.mxnet.MXNetModel(framework_version='1.6')" ,
85
+ "sagemaker.mxnet.model.MXNetModel(framework_version='1.6')" ,
72
86
"PyTorch(framework_version='1.4')" ,
73
87
"sagemaker.pytorch.PyTorch(framework_version='1.4')" ,
88
+ "sagemaker.pytorch.estimator.PyTorch(framework_version='1.4')" ,
74
89
"PyTorchModel(framework_version='1.4')" ,
75
90
"sagemaker.pytorch.PyTorchModel(framework_version='1.4')" ,
91
+ "sagemaker.pytorch.model.PyTorchModel(framework_version='1.4')" ,
76
92
"Chainer(framework_version='5.0')" ,
77
93
"sagemaker.chainer.Chainer(framework_version='5.0')" ,
94
+ "sagemaker.chainer.estimator.Chainer(framework_version='5.0')" ,
78
95
"ChainerModel(framework_version='5.0')" ,
79
96
"sagemaker.chainer.ChainerModel(framework_version='5.0')" ,
97
+ "sagemaker.chainer.model.ChainerModel(framework_version='5.0')" ,
80
98
"SKLearn(framework_version='0.20.0')" ,
81
99
"sagemaker.sklearn.SKLearn(framework_version='0.20.0')" ,
100
+ "sagemaker.sklearn.estimator.SKLearn(framework_version='0.20.0')" ,
82
101
"SKLearnModel(framework_version='0.20.0')" ,
83
102
"sagemaker.sklearn.SKLearnModel(framework_version='0.20.0')" ,
103
+ "sagemaker.sklearn.model.SKLearnModel(framework_version='0.20.0')" ,
84
104
)
85
105
86
106
modifier = framework_version .FrameworkVersionEnforcer ()
@@ -97,51 +117,36 @@ def test_node_should_be_modified_random_function_call():
97
117
98
118
99
119
def test_modify_node_tf ():
100
- classes = (
101
- "TensorFlow" "sagemaker.tensorflow.TensorFlow" ,
102
- "TensorFlowModel" ,
103
- "sagemaker.tensorflow.TensorFlowModel" ,
104
- )
105
- _test_modify_node (classes , "1.11.0" )
120
+ _test_modify_node ("TensorFlow" , "1.11.0" )
106
121
107
122
108
123
def test_modify_node_mx ():
109
- classes = ("MXNet" , "sagemaker.mxnet.MXNet" , "MXNetModel" , "sagemaker.mxnet.MXNetModel" )
110
- _test_modify_node (classes , "1.2.0" )
124
+ _test_modify_node ("MXNet" , "1.2.0" )
111
125
112
126
113
127
def test_modify_node_chainer ():
114
- classes = (
115
- "Chainer" ,
116
- "sagemaker.chainer.Chainer" ,
117
- "ChainerModel" ,
118
- "sagemaker.chainer.ChainerModel" ,
119
- )
120
- _test_modify_node (classes , "4.1.0" )
128
+ _test_modify_node ("Chainer" , "4.1.0" )
121
129
122
130
123
131
def test_modify_node_pt ():
124
- classes = (
125
- "PyTorch" ,
126
- "sagemaker.pytorch.PyTorch" ,
127
- "PyTorchModel" ,
128
- "sagemaker.pytorch.PyTorchModel" ,
129
- )
130
- _test_modify_node (classes , "0.4.0" )
132
+ _test_modify_node ("PyTorch" , "0.4.0" )
131
133
132
134
133
135
def test_modify_node_sklearn ():
134
- classes = (
135
- "SKLearn" ,
136
- "sagemaker.sklearn.SKLearn" ,
137
- "SKLearnModel" ,
138
- "sagemaker.sklearn.SKLearnModel" ,
139
- )
140
- _test_modify_node (classes , "0.20.0" )
136
+ _test_modify_node ("SKLearn" , "0.20.0" )
141
137
142
138
143
- def _test_modify_node (classes , default_version ):
139
+ def _test_modify_node (framework , default_version ):
144
140
modifier = framework_version .FrameworkVersionEnforcer ()
141
+
142
+ classes = (
143
+ "{}" .format (framework ),
144
+ "sagemaker.{}.{}" .format (framework .lower (), framework ),
145
+ "sagemaker.{}.estimator.{}" .format (framework .lower (), framework ),
146
+ "{}Model" .format (framework ),
147
+ "sagemaker.{}.{}Model" .format (framework .lower (), framework ),
148
+ "sagemaker.{}.model.{}Model" .format (framework .lower (), framework ),
149
+ )
145
150
for cls in classes :
146
151
node = ast_call ("{}()" .format (cls ))
147
152
modifier .modify_node (node )
0 commit comments