Skip to content

Commit 1b38cbc

Browse files
committed
update processing integ tests
1 parent d7a3a61 commit 1b38cbc

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

tests/integ/test_processing.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
import pytest
1818
from botocore.config import Config
19-
from sagemaker import Session
20-
from sagemaker.fw_registry import default_framework_uri
2119

20+
from sagemaker import image_uris, Session
21+
from sagemaker.network import NetworkConfig
2222
from sagemaker.processing import (
2323
ProcessingInput,
2424
ProcessingOutput,
@@ -27,7 +27,6 @@
2727
ProcessingJob,
2828
)
2929
from sagemaker.sklearn.processing import SKLearnProcessor
30-
from sagemaker.network import NetworkConfig
3130
from tests.integ import DATA_DIR
3231
from tests.integ.kms_utils import get_or_create_kms_key
3332

@@ -59,10 +58,18 @@ def sagemaker_session_with_custom_bucket(
5958

6059

6160
@pytest.fixture(scope="module")
62-
def image_uri(sagemaker_session):
63-
image_tag = "{}-{}-{}".format("0.20.0", "cpu", "py3")
64-
return default_framework_uri(
65-
"scikit-learn", sagemaker_session.boto_session.region_name, image_tag
61+
def image_uri(
62+
scikit_learn_latest_version,
63+
scikit_learn_latest_py_version,
64+
cpu_instance_type,
65+
sagemaker_session,
66+
):
67+
return image_uris.retrieve(
68+
"scikit-learn",
69+
sagemaker_session.boto_region_name,
70+
version=scikit_learn_latest_version,
71+
py_version=scikit_learn_latest_py_version,
72+
instance_type=cpu_instance_type,
6673
)
6774

6875

tests/integ/test_sklearn.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def test_training_with_additional_hyperparameters(
7474
job_name = unique_name_from_base("test-sklearn-hp")
7575

7676
sklearn.fit({"train": train_input, "test": test_input}, job_name=job_name)
77-
return sklearn.latest_training_job.name
7877

7978

8079
def test_training_with_network_isolation(
@@ -110,7 +109,6 @@ def test_training_with_network_isolation(
110109
assert sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=job_name)[
111110
"EnableNetworkIsolation"
112111
]
113-
return sklearn.latest_training_job.name
114112

115113

116114
@pytest.mark.canary_quick

0 commit comments

Comments
 (0)