@@ -12,15 +12,17 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

-import gzip
import os
-import pickle
-import sys
-import pytest
-import tests.integ

+import airflow
+import pytest
import numpy as np
+from airflow import DAG
+from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator
+from airflow.contrib.operators.sagemaker_transform_operator import SageMakerTransformOperator
+from six.moves.urllib.parse import urlparse

+import tests.integ
from sagemaker import (
    KMeans,
    FactorizationMachines,
@@ -40,21 +42,13 @@
from sagemaker.pytorch.estimator import PyTorch
from sagemaker.sklearn import SKLearn
from sagemaker.tensorflow import TensorFlow
-from sagemaker.workflow import airflow as sm_airflow
from sagemaker.utils import sagemaker_timestamp
-
-import airflow
-from airflow import DAG
-from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator
-from airflow.contrib.operators.sagemaker_transform_operator import SageMakerTransformOperator
-
+from sagemaker.workflow import airflow as sm_airflow
from sagemaker.xgboost import XGBoost
-from tests.integ import DATA_DIR, PYTHON_VERSION
+from tests.integ import datasets, DATA_DIR, PYTHON_VERSION
from tests.integ.record_set import prepare_record_set_from_local_files
from tests.integ.timeout import timeout

-from six.moves.urllib.parse import urlparse
-
PYTORCH_MNIST_DIR = os.path.join(DATA_DIR, "pytorch_mnist")
PYTORCH_MNIST_SCRIPT = os.path.join(PYTORCH_MNIST_DIR, "mnist.py")
AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS = 10
@@ -101,13 +95,6 @@ def test_byo_airflow_config_uploads_data_source_to_s3_when_inputs_provided(
@pytest.mark.canary_quick
def test_kmeans_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
    with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
        kmeans = KMeans(
            role=ROLE,
            train_instance_count=SINGLE_INSTANCE_COUNT,
@@ -126,7 +113,7 @@ def test_kmeans_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_
        kmeans.center_factor = 1
        kmeans.eval_metrics = ["ssd", "msd"]

-        records = kmeans.record_set(train_set[0][:100])
+        records = kmeans.record_set(datasets.one_p_mnist()[0][:100])

        training_config = _build_airflow_workflow(
            estimator=kmeans, instance_type=cpu_instance_type, inputs=records
@@ -140,13 +127,6 @@ def test_kmeans_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_

def test_fm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
    with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
        fm = FactorizationMachines(
            role=ROLE,
            train_instance_count=SINGLE_INSTANCE_COUNT,
@@ -160,7 +140,8 @@ def test_fm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_inst
            sagemaker_session=sagemaker_session,
        )

-        records = fm.record_set(train_set[0][:200], train_set[1][:200].astype("float32"))
+        training_set = datasets.one_p_mnist()
+        records = fm.record_set(training_set[0][:200], training_set[1][:200].astype("float32"))

        training_config = _build_airflow_workflow(
            estimator=fm, instance_type=cpu_instance_type, inputs=records
@@ -206,13 +187,6 @@ def test_ipinsights_airflow_config_uploads_data_source_to_s3(sagemaker_session,

def test_knn_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
    with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
        knn = KNN(
            role=ROLE,
            train_instance_count=SINGLE_INSTANCE_COUNT,
@@ -223,7 +197,8 @@ def test_knn_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
            sagemaker_session=sagemaker_session,
        )

-        records = knn.record_set(train_set[0][:200], train_set[1][:200].astype("float32"))
+        training_set = datasets.one_p_mnist()
+        records = knn.record_set(training_set[0][:200], training_set[1][:200].astype("float32"))

        training_config = _build_airflow_workflow(
            estimator=knn, instance_type=cpu_instance_type, inputs=records
@@ -277,16 +252,10 @@ def test_linearlearner_airflow_config_uploads_data_source_to_s3(
    sagemaker_session, cpu_instance_type
):
    with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
-        train_set[1][:100] = 1
-        train_set[1][100:200] = 0
-        train_set = train_set[0], train_set[1].astype(np.dtype("float32"))
+        training_set = datasets.one_p_mnist()
+        training_set[1][:100] = 1
+        training_set[1][100:200] = 0
+        training_set = training_set[0], training_set[1].astype(np.dtype("float32"))

        ll = LinearLearner(
            ROLE,
@@ -331,7 +300,7 @@ def test_linearlearner_airflow_config_uploads_data_source_to_s3(
        ll.early_stopping_tolerance = 0.0001
        ll.early_stopping_patience = 3

-        records = ll.record_set(train_set[0][:200], train_set[1][:200])
+        records = ll.record_set(training_set[0][:200], training_set[1][:200])

        training_config = _build_airflow_workflow(
            estimator=ll, instance_type=cpu_instance_type, inputs=records
@@ -380,13 +349,6 @@ def test_ntm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
@pytest.mark.canary_quick
def test_pca_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
    with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
        pca = PCA(
            role=ROLE,
            train_instance_count=SINGLE_INSTANCE_COUNT,
@@ -399,7 +361,7 @@ def test_pca_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
        pca.subtract_mean = True
        pca.extra_components = 5

-        records = pca.record_set(train_set[0][:100])
+        records = pca.record_set(datasets.one_p_mnist()[0][:100])

        training_config = _build_airflow_workflow(
            estimator=pca, instance_type=cpu_instance_type, inputs=records
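
For readers following the change: the inline gzip/pickle loading removed above is what the new tests.integ.datasets.one_p_mnist() helper is expected to encapsulate. Below is a minimal sketch of such a helper, inferred from the deleted lines rather than taken from the actual tests/integ/datasets.py, so names and location are assumptions.

# Hypothetical helper, reconstructed from the loading code this diff removes;
# the real tests/integ/datasets.py may differ.
import gzip
import os
import pickle
import sys

from tests.integ import DATA_DIR


def one_p_mnist():
    """Return the 1%-MNIST training split as a (features, labels) tuple of numpy arrays."""
    data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
    pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}

    # Load the pickled (train, valid, test) splits and keep only the training set
    with gzip.open(data_path, "rb") as f:
        train_set, _, _ = pickle.load(f, **pickle_args)

    return train_set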