Skip to content

Commit b967e49

Browse files
committed
change framework name to 'sklearn'
1 parent 1b38cbc commit b967e49

File tree

12 files changed

+88
-124
lines changed

12 files changed

+88
-124
lines changed

src/sagemaker/sklearn/defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
SKLEARN_NAME = "scikit-learn"
16+
SKLEARN_NAME = "sklearn"

src/sagemaker/sklearn/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ class constructor
246246
init_params["image_uri"] = image_uri
247247
return init_params
248248

249-
if framework and framework != cls.__framework_name__:
249+
if framework and framework != "scikit-learn":
250250
raise ValueError(
251251
"Training job: {} didn't use image for requested framework".format(
252252
job_details["TrainingJobName"]

tests/conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,8 @@ def pytest_generate_tests(metafunc):
285285

286286

287287
def _generate_all_framework_version_fixtures(metafunc):
288-
for fw in ("chainer", "mxnet", "pytorch", "scikit-learn", "tensorflow", "xgboost"):
288+
for fw in ("chainer", "mxnet", "pytorch", "sklearn", "tensorflow", "xgboost"):
289289
config = image_uris.config_for_framework(fw)
290-
fw = fw.replace("-", "_") # for fixture names
291290
if "scope" in config:
292291
_parametrize_framework_version_fixtures(metafunc, fw, config)
293292
else:

tests/integ/test_airflow_config.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -478,10 +478,7 @@ def test_mxnet_airflow_config_uploads_data_source_to_s3(
478478

479479
@pytest.mark.canary_quick
480480
def test_sklearn_airflow_config_uploads_data_source_to_s3(
481-
sagemaker_session,
482-
cpu_instance_type,
483-
scikit_learn_latest_version,
484-
scikit_learn_latest_py_version,
481+
sagemaker_session, cpu_instance_type, sklearn_latest_version, sklearn_latest_py_version,
485482
):
486483
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
487484
script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py")
@@ -491,8 +488,8 @@ def test_sklearn_airflow_config_uploads_data_source_to_s3(
491488
entry_point=script_path,
492489
role=ROLE,
493490
instance_type=cpu_instance_type,
494-
framework_version=scikit_learn_latest_version,
495-
py_version=scikit_learn_latest_py_version,
491+
framework_version=sklearn_latest_version,
492+
py_version=sklearn_latest_py_version,
496493
sagemaker_session=sagemaker_session,
497494
hyperparameters={"epochs": 1},
498495
)

tests/integ/test_git.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_private_github(
138138
@pytest.mark.local_mode
139139
@pytest.mark.skip("needs a secure authentication approach")
140140
def test_private_github_with_2fa(
141-
sagemaker_local_session, scikit_learn_latest_version, scikit_learn_latest_py_version
141+
sagemaker_local_session, sklearn_latest_version, sklearn_latest_py_version
142142
):
143143
script_path = "mnist.py"
144144
data_path = os.path.join(DATA_DIR, "sklearn_mnist")
@@ -155,11 +155,11 @@ def test_private_github_with_2fa(
155155
entry_point=script_path,
156156
role="SageMakerRole",
157157
source_dir=source_dir,
158-
py_version=scikit_learn_latest_py_version,
158+
py_version=sklearn_latest_py_version,
159159
instance_count=1,
160160
instance_type="local",
161161
sagemaker_session=sagemaker_local_session,
162-
framework_version=scikit_learn_latest_version,
162+
framework_version=sklearn_latest_version,
163163
hyperparameters={"epochs": 1},
164164
git_config=git_config,
165165
)
@@ -178,7 +178,7 @@ def test_private_github_with_2fa(
178178
model_data,
179179
"SageMakerRole",
180180
entry_point=script_path,
181-
framework_version=scikit_learn_latest_version,
181+
framework_version=sklearn_latest_version,
182182
source_dir=source_dir,
183183
sagemaker_session=sagemaker_local_session,
184184
git_config=git_config,
@@ -194,7 +194,7 @@ def test_private_github_with_2fa(
194194

195195
@pytest.mark.local_mode
196196
def test_github_with_ssh_passphrase_not_configured(
197-
sagemaker_local_session, scikit_learn_latest_version, scikit_learn_latest_py_version
197+
sagemaker_local_session, sklearn_latest_version, sklearn_latest_py_version
198198
):
199199
script_path = "mnist.py"
200200
data_path = os.path.join(DATA_DIR, "sklearn_mnist")
@@ -212,8 +212,8 @@ def test_github_with_ssh_passphrase_not_configured(
212212
instance_count=1,
213213
instance_type="local",
214214
sagemaker_session=sagemaker_local_session,
215-
framework_version=scikit_learn_latest_version,
216-
py_version=scikit_learn_latest_py_version,
215+
framework_version=sklearn_latest_version,
216+
py_version=sklearn_latest_py_version,
217217
hyperparameters={"epochs": 1},
218218
git_config=git_config,
219219
)

tests/integ/test_processing.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,13 @@ def sagemaker_session_with_custom_bucket(
5959

6060
@pytest.fixture(scope="module")
6161
def image_uri(
62-
scikit_learn_latest_version,
63-
scikit_learn_latest_py_version,
64-
cpu_instance_type,
65-
sagemaker_session,
62+
sklearn_latest_version, sklearn_latest_py_version, cpu_instance_type, sagemaker_session,
6663
):
6764
return image_uris.retrieve(
6865
"scikit-learn",
6966
sagemaker_session.boto_region_name,
70-
version=scikit_learn_latest_version,
71-
py_version=scikit_learn_latest_py_version,
67+
version=sklearn_latest_version,
68+
py_version=sklearn_latest_py_version,
7269
instance_type=cpu_instance_type,
7370
)
7471

@@ -97,12 +94,12 @@ def output_kms_key(sagemaker_session):
9794
)
9895

9996

100-
def test_sklearn(sagemaker_session, scikit_learn_latest_version, cpu_instance_type):
97+
def test_sklearn(sagemaker_session, sklearn_latest_version, cpu_instance_type):
10198
script_path = os.path.join(DATA_DIR, "dummy_script.py")
10299
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
103100

104101
sklearn_processor = SKLearnProcessor(
105-
framework_version=scikit_learn_latest_version,
102+
framework_version=sklearn_latest_version,
106103
role=ROLE,
107104
instance_type=cpu_instance_type,
108105
instance_count=1,
@@ -136,12 +133,12 @@ def test_sklearn(sagemaker_session, scikit_learn_latest_version, cpu_instance_ty
136133

137134
@pytest.mark.canary_quick
138135
def test_sklearn_with_customizations(
139-
sagemaker_session, image_uri, scikit_learn_latest_version, cpu_instance_type, output_kms_key
136+
sagemaker_session, image_uri, sklearn_latest_version, cpu_instance_type, output_kms_key
140137
):
141138
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
142139

143140
sklearn_processor = SKLearnProcessor(
144-
framework_version=scikit_learn_latest_version,
141+
framework_version=sklearn_latest_version,
145142
role=ROLE,
146143
command=["python3"],
147144
instance_type=cpu_instance_type,
@@ -218,14 +215,14 @@ def test_sklearn_with_custom_default_bucket(
218215
sagemaker_session_with_custom_bucket,
219216
custom_bucket_name,
220217
image_uri,
221-
scikit_learn_latest_version,
218+
sklearn_latest_version,
222219
cpu_instance_type,
223220
output_kms_key,
224221
):
225222
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
226223

227224
sklearn_processor = SKLearnProcessor(
228-
framework_version=scikit_learn_latest_version,
225+
framework_version=sklearn_latest_version,
229226
role=ROLE,
230227
command=["python3"],
231228
instance_type=cpu_instance_type,
@@ -301,10 +298,10 @@ def test_sklearn_with_custom_default_bucket(
301298

302299

303300
def test_sklearn_with_no_inputs_or_outputs(
304-
sagemaker_session, image_uri, scikit_learn_latest_version, cpu_instance_type
301+
sagemaker_session, image_uri, sklearn_latest_version, cpu_instance_type
305302
):
306303
sklearn_processor = SKLearnProcessor(
307-
framework_version=scikit_learn_latest_version,
304+
framework_version=sklearn_latest_version,
308305
role=ROLE,
309306
command=["python3"],
310307
instance_type=cpu_instance_type,
@@ -653,14 +650,12 @@ def test_processor_with_custom_bucket(
653650
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
654651

655652

656-
def test_sklearn_with_network_config(
657-
sagemaker_session, scikit_learn_latest_version, cpu_instance_type
658-
):
653+
def test_sklearn_with_network_config(sagemaker_session, sklearn_latest_version, cpu_instance_type):
659654
script_path = os.path.join(DATA_DIR, "dummy_script.py")
660655
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
661656

662657
sklearn_processor = SKLearnProcessor(
663-
framework_version=scikit_learn_latest_version,
658+
framework_version=sklearn_latest_version,
664659
role=ROLE,
665660
instance_type=cpu_instance_type,
666661
instance_count=1,

tests/integ/test_sklearn.py

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,16 @@
3131
"This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
3232
)
3333
def sklearn_training_job(
34-
sagemaker_session,
35-
scikit_learn_latest_version,
36-
scikit_learn_latest_py_version,
37-
cpu_instance_type,
34+
sagemaker_session, sklearn_latest_version, sklearn_latest_py_version, cpu_instance_type,
3835
):
3936
return _run_mnist_training_job(
40-
sagemaker_session,
41-
cpu_instance_type,
42-
scikit_learn_latest_version,
43-
scikit_learn_latest_py_version,
37+
sagemaker_session, cpu_instance_type, sklearn_latest_version, sklearn_latest_py_version,
4438
)
4539
sagemaker_session.boto_region_name
4640

4741

4842
def test_training_with_additional_hyperparameters(
49-
sagemaker_session,
50-
scikit_learn_latest_version,
51-
scikit_learn_latest_py_version,
52-
cpu_instance_type,
43+
sagemaker_session, sklearn_latest_version, sklearn_latest_py_version, cpu_instance_type,
5344
):
5445
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
5546
script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py")
@@ -59,8 +50,8 @@ def test_training_with_additional_hyperparameters(
5950
entry_point=script_path,
6051
role="SageMakerRole",
6152
instance_type=cpu_instance_type,
62-
framework_version=scikit_learn_latest_version,
63-
py_version=scikit_learn_latest_py_version,
53+
framework_version=sklearn_latest_version,
54+
py_version=sklearn_latest_py_version,
6455
sagemaker_session=sagemaker_session,
6556
hyperparameters={"epochs": 1},
6657
)
@@ -77,10 +68,7 @@ def test_training_with_additional_hyperparameters(
7768

7869

7970
def test_training_with_network_isolation(
80-
sagemaker_session,
81-
scikit_learn_latest_version,
82-
scikit_learn_latest_py_version,
83-
cpu_instance_type,
71+
sagemaker_session, sklearn_latest_version, sklearn_latest_py_version, cpu_instance_type,
8472
):
8573
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
8674
script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py")
@@ -90,8 +78,8 @@ def test_training_with_network_isolation(
9078
entry_point=script_path,
9179
role="SageMakerRole",
9280
instance_type=cpu_instance_type,
93-
framework_version=scikit_learn_latest_version,
94-
py_version=scikit_learn_latest_py_version,
81+
framework_version=sklearn_latest_version,
82+
py_version=sklearn_latest_py_version,
9583
sagemaker_session=sagemaker_session,
9684
hyperparameters={"epochs": 1},
9785
enable_network_isolation=True,
@@ -133,8 +121,8 @@ def test_deploy_model(
133121
sklearn_training_job,
134122
sagemaker_session,
135123
cpu_instance_type,
136-
scikit_learn_latest_version,
137-
scikit_learn_latest_py_version,
124+
sklearn_latest_version,
125+
sklearn_latest_py_version,
138126
):
139127
endpoint_name = "test-sklearn-deploy-model-{}".format(sagemaker_timestamp())
140128
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
@@ -147,7 +135,7 @@ def test_deploy_model(
147135
model_data,
148136
"SageMakerRole",
149137
entry_point=script_path,
150-
framework_version=scikit_learn_latest_version,
138+
framework_version=sklearn_latest_version,
151139
sagemaker_session=sagemaker_session,
152140
)
153141
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
@@ -159,18 +147,15 @@ def test_deploy_model(
159147
"This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
160148
)
161149
def test_async_fit(
162-
sagemaker_session,
163-
cpu_instance_type,
164-
scikit_learn_latest_version,
165-
scikit_learn_latest_py_version,
150+
sagemaker_session, cpu_instance_type, sklearn_latest_version, sklearn_latest_py_version,
166151
):
167152
endpoint_name = "test-sklearn-attach-deploy-{}".format(sagemaker_timestamp())
168153

169154
with timeout(minutes=5):
170155
training_job_name = _run_mnist_training_job(
171156
sagemaker_session,
172157
cpu_instance_type,
173-
sklearn_version=scikit_learn_latest_version,
158+
sklearn_version=sklearn_latest_version,
174159
wait=False,
175160
)
176161

@@ -187,10 +172,7 @@ def test_async_fit(
187172

188173

189174
def test_failed_training_job(
190-
sagemaker_session,
191-
scikit_learn_latest_version,
192-
scikit_learn_latest_py_version,
193-
cpu_instance_type,
175+
sagemaker_session, sklearn_latest_version, sklearn_latest_py_version, cpu_instance_type,
194176
):
195177
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
196178
script_path = os.path.join(DATA_DIR, "sklearn_mnist", "failure_script.py")
@@ -199,8 +181,8 @@ def test_failed_training_job(
199181
sklearn = SKLearn(
200182
entry_point=script_path,
201183
role="SageMakerRole",
202-
framework_version=scikit_learn_latest_version,
203-
py_version=scikit_learn_latest_py_version,
184+
framework_version=sklearn_latest_version,
185+
py_version=sklearn_latest_py_version,
204186
instance_count=1,
205187
instance_type=cpu_instance_type,
206188
sagemaker_session=sagemaker_session,

tests/unit/sagemaker/image_uris/test_sklearn.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,45 +43,45 @@
4343
}
4444

4545

46-
def test_valid_uris(scikit_learn_version):
46+
def test_valid_uris(sklearn_version):
4747
for region in regions.regions():
4848
uri = image_uris.retrieve(
4949
"scikit-learn",
5050
region=region,
51-
version=scikit_learn_version,
51+
version=sklearn_version,
5252
py_version="py3",
5353
instance_type="ml.c4.xlarge",
5454
)
5555

5656
expected = expected_uris.framework_uri(
5757
"sagemaker-scikit-learn",
58-
scikit_learn_version,
58+
sklearn_version,
5959
ACCOUNTS[region],
6060
py_version="py3",
6161
region=region,
6262
)
6363
assert expected == uri
6464

6565

66-
def test_py2_error(scikit_learn_version):
66+
def test_py2_error(sklearn_version):
6767
with pytest.raises(ValueError) as e:
6868
image_uris.retrieve(
6969
"scikit-learn",
7070
region="us-west-2",
71-
version=scikit_learn_version,
71+
version=sklearn_version,
7272
py_version="py2",
7373
instance_type="ml.c4.xlarge",
7474
)
7575

7676
assert "Unsupported Python version: py2." in str(e.value)
7777

7878

79-
def test_gpu_error(scikit_learn_version):
79+
def test_gpu_error(sklearn_version):
8080
with pytest.raises(ValueError) as e:
8181
image_uris.retrieve(
8282
"scikit-learn",
8383
region="us-west-2",
84-
version=scikit_learn_version,
84+
version=sklearn_version,
8585
instance_type="ml.p2.xlarge",
8686
)
8787

tests/unit/test_fw_registry.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
import pytest
1616

1717
from sagemaker.fw_registry import registry, default_framework_uri
18-
from sagemaker.sklearn import SKLearn
1918

2019

21-
scikit_learn_framework_name = SKLearn.__framework_name__
20+
scikit_learn_framework_name = "scikit-learn"
2221

2322

2423
def test_registry_sparkml_serving():

0 commit comments

Comments
 (0)