
Commit c211417

infra: clean up pickle.load logic in integ tests (#1611)
Because we no longer run our tests with Python 2, we no longer need the branched logic for pickle.load args.
1 parent 124d6e0 commit c211417

12 files changed: +163 -288 lines
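
For context, the simplification the commit description refers to looks roughly like this (an illustrative sketch; data_path here is a hypothetical stand-in for the path the tests build from tests.integ.DATA_DIR):

    import gzip
    import os
    import pickle
    import sys

    data_path = os.path.join("tests", "data", "one_p_mnist", "mnist.pkl.gz")

    # Old pattern (removed by this commit): branch on the interpreter version,
    # because Python 2's pickle.load() had no `encoding` parameter.
    pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
    with gzip.open(data_path, "rb") as f:
        train_set, _, _ = pickle.load(f, **pickle_args)

    # New pattern: the tests run only on Python 3, so pass the encoding directly.
    with gzip.open(data_path, "rb") as f:
        training_set, _, _ = pickle.load(f, encoding="latin1")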

tests/integ/datasets.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import gzip
+import os
+import pickle
+
+from tests.integ import DATA_DIR
+
+
+def one_p_mnist():
+    data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
+    with gzip.open(data_path, "rb") as f:
+        training_set, _, _ = pickle.load(f, encoding="latin1")
+
+    return training_set
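
The diffs that follow consume this helper either directly (datasets.one_p_mnist()) or through a small module-level pytest fixture. A minimal sketch of that fixture pattern, assuming the helper above is importable as tests.integ.datasets; the test function is hypothetical and only illustrates the shape of the returned data:

    import pytest

    from tests.integ import datasets


    @pytest.fixture
    def training_set():
        return datasets.one_p_mnist()


    def test_uses_mnist_subset(training_set):
        # Hypothetical test for illustration: training_set is (features, labels).
        features = training_set[0][:200]
        labels = training_set[1][:200].astype("float32")
        assert len(features) == len(labels) == 200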

tests/integ/test_airflow_config.py

Lines changed: 20 additions & 58 deletions
@@ -12,15 +12,17 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
-import gzip
 import os
-import pickle
-import sys
-import pytest
-import tests.integ
 
+import airflow
+import pytest
 import numpy as np
+from airflow import DAG
+from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator
+from airflow.contrib.operators.sagemaker_transform_operator import SageMakerTransformOperator
+from six.moves.urllib.parse import urlparse
 
+import tests.integ
 from sagemaker import (
     KMeans,
     FactorizationMachines,
@@ -40,21 +42,13 @@
 from sagemaker.pytorch.estimator import PyTorch
 from sagemaker.sklearn import SKLearn
 from sagemaker.tensorflow import TensorFlow
-from sagemaker.workflow import airflow as sm_airflow
 from sagemaker.utils import sagemaker_timestamp
-
-import airflow
-from airflow import DAG
-from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator
-from airflow.contrib.operators.sagemaker_transform_operator import SageMakerTransformOperator
-
+from sagemaker.workflow import airflow as sm_airflow
 from sagemaker.xgboost import XGBoost
-from tests.integ import DATA_DIR, PYTHON_VERSION
+from tests.integ import datasets, DATA_DIR, PYTHON_VERSION
 from tests.integ.record_set import prepare_record_set_from_local_files
 from tests.integ.timeout import timeout
 
-from six.moves.urllib.parse import urlparse
-
 PYTORCH_MNIST_DIR = os.path.join(DATA_DIR, "pytorch_mnist")
 PYTORCH_MNIST_SCRIPT = os.path.join(PYTORCH_MNIST_DIR, "mnist.py")
 AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS = 10
@@ -101,13 +95,6 @@ def test_byo_airflow_config_uploads_data_source_to_s3_when_inputs_provided(
 @pytest.mark.canary_quick
 def test_kmeans_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
         kmeans = KMeans(
             role=ROLE,
             train_instance_count=SINGLE_INSTANCE_COUNT,
@@ -126,7 +113,7 @@ def test_kmeans_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
         kmeans.center_factor = 1
         kmeans.eval_metrics = ["ssd", "msd"]
 
-        records = kmeans.record_set(train_set[0][:100])
+        records = kmeans.record_set(datasets.one_p_mnist()[0][:100])
 
         training_config = _build_airflow_workflow(
             estimator=kmeans, instance_type=cpu_instance_type, inputs=records
@@ -140,13 +127,6 @@ def test_kmeans_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
 
 def test_fm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
         fm = FactorizationMachines(
             role=ROLE,
             train_instance_count=SINGLE_INSTANCE_COUNT,
@@ -160,7 +140,8 @@ def test_fm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
             sagemaker_session=sagemaker_session,
         )
 
-        records = fm.record_set(train_set[0][:200], train_set[1][:200].astype("float32"))
+        training_set = datasets.one_p_mnist()
+        records = fm.record_set(training_set[0][:200], training_set[1][:200].astype("float32"))
 
         training_config = _build_airflow_workflow(
             estimator=fm, instance_type=cpu_instance_type, inputs=records
@@ -206,13 +187,6 @@ def test_ipinsights_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
 
 def test_knn_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
         knn = KNN(
             role=ROLE,
             train_instance_count=SINGLE_INSTANCE_COUNT,
@@ -223,7 +197,8 @@ def test_knn_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
             sagemaker_session=sagemaker_session,
        )
 
-        records = knn.record_set(train_set[0][:200], train_set[1][:200].astype("float32"))
+        training_set = datasets.one_p_mnist()
+        records = knn.record_set(training_set[0][:200], training_set[1][:200].astype("float32"))
 
         training_config = _build_airflow_workflow(
             estimator=knn, instance_type=cpu_instance_type, inputs=records
@@ -277,16 +252,10 @@ def test_linearlearner_airflow_config_uploads_data_source_to_s3(
     sagemaker_session, cpu_instance_type
 ):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
-        train_set[1][:100] = 1
-        train_set[1][100:200] = 0
-        train_set = train_set[0], train_set[1].astype(np.dtype("float32"))
+        training_set = datasets.one_p_mnist()
+        training_set[1][:100] = 1
+        training_set[1][100:200] = 0
+        training_set = training_set[0], training_set[1].astype(np.dtype("float32"))
 
         ll = LinearLearner(
             ROLE,
@@ -331,7 +300,7 @@ def test_linearlearner_airflow_config_uploads_data_source_to_s3(
         ll.early_stopping_tolerance = 0.0001
         ll.early_stopping_patience = 3
 
-        records = ll.record_set(train_set[0][:200], train_set[1][:200])
+        records = ll.record_set(training_set[0][:200], training_set[1][:200])
 
         training_config = _build_airflow_workflow(
             estimator=ll, instance_type=cpu_instance_type, inputs=records
@@ -380,13 +349,6 @@ def test_ntm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
 @pytest.mark.canary_quick
 def test_pca_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
         pca = PCA(
             role=ROLE,
             train_instance_count=SINGLE_INSTANCE_COUNT,
@@ -399,7 +361,7 @@ def test_pca_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
         pca.subtract_mean = True
         pca.extra_components = 5
 
-        records = pca.record_set(train_set[0][:100])
+        records = pca.record_set(datasets.one_p_mnist()[0][:100])
 
         training_config = _build_airflow_workflow(
             estimator=pca, instance_type=cpu_instance_type, inputs=records

tests/integ/test_byo_estimator.py

Lines changed: 10 additions & 20 deletions
@@ -12,19 +12,16 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
-import gzip
 import json
 import os
-import pickle
-import sys
 
 import pytest
 
 import sagemaker
 from sagemaker.amazon.amazon_estimator import get_image_uri
 from sagemaker.estimator import Estimator
 from sagemaker.utils import unique_name_from_base
-from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
+from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, datasets
 from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
 
 
@@ -33,6 +30,11 @@ def region(sagemaker_session):
     return sagemaker_session.boto_session.region_name
 
 
+@pytest.fixture
+def training_set():
+    return datasets.one_p_mnist()
+
+
 def fm_serializer(data):
     js = {"instances": []}
     for row in data:
@@ -41,7 +43,7 @@ def fm_serializer(data):
 
 
 @pytest.mark.canary_quick
-def test_byo_estimator(sagemaker_session, region, cpu_instance_type):
+def test_byo_estimator(sagemaker_session, region, cpu_instance_type, training_set):
     """Use Factorization Machines algorithm as an example here.
 
     First we need to prepare data for training. We take standard data set, convert it to the
@@ -57,12 +59,6 @@ def test_byo_estimator(sagemaker_session, region, cpu_instance_type):
     job_name = unique_name_from_base("byo")
 
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
         prefix = "test_byo_estimator"
         key = "recordio-pb-data"
 
@@ -92,26 +88,20 @@ def test_byo_estimator(sagemaker_session, region, cpu_instance_type):
         predictor.content_type = "application/json"
         predictor.deserializer = sagemaker.predictor.json_deserializer
 
-        result = predictor.predict(train_set[0][:10])
+        result = predictor.predict(training_set[0][:10])
 
         assert len(result["predictions"]) == 10
         for prediction in result["predictions"]:
             assert prediction["score"] is not None
 
 
-def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type):
+def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, training_set):
     image_name = get_image_uri(region, "factorization-machines")
     endpoint_name = unique_name_from_base("byo")
     training_data_path = os.path.join(DATA_DIR, "dummy_tensor")
     job_name = unique_name_from_base("byo")
 
     with timeout(minutes=5):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
         prefix = "test_byo_estimator"
         key = "recordio-pb-data"
 
@@ -144,7 +134,7 @@ def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type):
         predictor.content_type = "application/json"
         predictor.deserializer = sagemaker.predictor.json_deserializer
 
-        result = predictor.predict(train_set[0][:10])
+        result = predictor.predict(training_set[0][:10])
 
         assert len(result["predictions"]) == 10
         for prediction in result["predictions"]:

tests/integ/test_factorization_machines.py

Lines changed: 14 additions & 25 deletions
@@ -12,29 +12,25 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
-import gzip
-import os
-import pickle
-import sys
 import time
 
+import pytest
+
 from sagemaker import FactorizationMachines, FactorizationMachinesModel
 from sagemaker.utils import unique_name_from_base
-from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
+from tests.integ import datasets, TRAINING_DEFAULT_TIMEOUT_MINUTES
 from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
 
 
-def test_factorization_machines(sagemaker_session, cpu_instance_type):
-    job_name = unique_name_from_base("fm")
+@pytest.fixture
+def training_set():
+    return datasets.one_p_mnist()
 
-    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
 
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
+def test_factorization_machines(sagemaker_session, cpu_instance_type, training_set):
+    job_name = unique_name_from_base("fm")
 
+    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
         fm = FactorizationMachines(
             role="SageMakerRole",
             train_instance_count=1,
@@ -50,7 +46,7 @@ def test_factorization_machines(sagemaker_session, cpu_instance_type):
 
         # training labels must be 'float32'
         fm.fit(
-            fm.record_set(train_set[0][:200], train_set[1][:200].astype("float32")),
+            fm.record_set(training_set[0][:200], training_set[1][:200].astype("float32")),
             job_name=job_name,
         )
 
@@ -59,24 +55,17 @@
             fm.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session
         )
         predictor = model.deploy(1, cpu_instance_type, endpoint_name=job_name)
-        result = predictor.predict(train_set[0][:10])
+        result = predictor.predict(training_set[0][:10])
 
         assert len(result) == 10
         for record in result:
             assert record.label["score"] is not None
 
 
-def test_async_factorization_machines(sagemaker_session, cpu_instance_type):
+def test_async_factorization_machines(sagemaker_session, cpu_instance_type, training_set):
     job_name = unique_name_from_base("fm")
 
     with timeout(minutes=5):
-        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
-        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
-
-        # Load the data into memory as numpy arrays
-        with gzip.open(data_path, "rb") as f:
-            train_set, _, _ = pickle.load(f, **pickle_args)
-
         fm = FactorizationMachines(
             role="SageMakerRole",
             train_instance_count=1,
@@ -92,7 +81,7 @@ def test_async_factorization_machines(sagemaker_session, cpu_instance_type):
 
         # training labels must be 'float32'
         fm.fit(
-            fm.record_set(train_set[0][:200], train_set[1][:200].astype("float32")),
+            fm.record_set(training_set[0][:200], training_set[1][:200].astype("float32")),
             job_name=job_name,
             wait=False,
         )
@@ -109,7 +98,7 @@ def test_async_factorization_machines(sagemaker_session, cpu_instance_type):
             estimator.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session
         )
         predictor = model.deploy(1, cpu_instance_type, endpoint_name=job_name)
-        result = predictor.predict(train_set[0][:10])
+        result = predictor.predict(training_set[0][:10])
 
         assert len(result) == 10
         for record in result:
