Skip to content

Commit c968c8d

Browse files
committed
use generated fixtures for RL versions
1 parent e2e3cb2 commit c968c8d

File tree

5 files changed

+67
-88
lines changed

5 files changed

+67
-88
lines changed

src/sagemaker/rl/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"0.11.0": {"tensorflow": "1.11", "mxnet": "1.3"},
3737
"0.11.1": {"tensorflow": "1.12"},
3838
"0.11": {"tensorflow": "1.12", "mxnet": "1.3"},
39+
"1.0.0": {"tensorflow": "1.12"},
3940
},
4041
"ray": {
4142
"0.5.3": {"tensorflow": "1.11"},

tests/conftest.py

Lines changed: 5 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
from sagemaker import Session, image_uris, utils
2424
from sagemaker.local import LocalSession
25-
from sagemaker.rl import RLEstimator
2625
import tests.integ
2726

2827
DEFAULT_REGION = "us-west-2"
@@ -41,15 +40,20 @@
4140

4241
FRAMEWORKS_FOR_GENERATED_VERSION_FIXTURES = (
4342
"chainer",
43+
"coach_mxnet",
44+
"coach_tensorflow",
4445
"inferentia_mxnet",
4546
"inferentia_tensorflow",
4647
"mxnet",
4748
"neo_mxnet",
4849
"neo_pytorch",
4950
"neo_tensorflow",
5051
"pytorch",
52+
"ray_pytorch",
53+
"ray_tensorflow",
5154
"sklearn",
5255
"tensorflow",
56+
"vw",
5357
"xgboost",
5458
)
5559

@@ -181,46 +185,6 @@ def _tf_py_version(tf_version, request):
181185
return "py37"
182186

183187

184-
@pytest.fixture(scope="module", params=["0.10.1", "0.10.1", "0.11", "0.11.0", "0.11.1"])
185-
def rl_coach_tf_version(request):
186-
return request.param
187-
188-
189-
@pytest.fixture(scope="module", params=["0.11", "0.11.0"])
190-
def rl_coach_mxnet_version(request):
191-
return request.param
192-
193-
194-
@pytest.fixture(scope="module", params=["0.5", "0.5.3", "0.6", "0.6.5", "0.8.2", "0.8.5"])
195-
def rl_ray_tf_version(request):
196-
return request.param
197-
198-
199-
@pytest.fixture(scope="module", params=["0.8.5"])
200-
def rl_ray_pytorch_version(request):
201-
return request.param
202-
203-
204-
@pytest.fixture(scope="module", params=["8.7.0"])
205-
def rl_vw_version(request):
206-
return request.param
207-
208-
209-
@pytest.fixture(scope="module")
210-
def rl_coach_mxnet_full_version():
211-
return RLEstimator.COACH_LATEST_VERSION_MXNET
212-
213-
214-
@pytest.fixture(scope="module")
215-
def rl_coach_tf_full_version():
216-
return RLEstimator.COACH_LATEST_VERSION_TF
217-
218-
219-
@pytest.fixture(scope="module")
220-
def rl_ray_full_version():
221-
return RLEstimator.RAY_LATEST_VERSION
222-
223-
224188
@pytest.fixture(scope="module")
225189
def tf_full_version(tensorflow_training_latest_version, tensorflow_inference_latest_version):
226190
"""Fixture for TF tests that test both training and inference.

tests/integ/test_rl.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424

2525

2626
@pytest.mark.canary_quick
27-
def test_coach_mxnet(sagemaker_session, rl_coach_mxnet_full_version, cpu_instance_type):
27+
def test_coach_mxnet(sagemaker_session, coach_mxnet_latest_version, cpu_instance_type):
2828
estimator = _test_coach(
29-
sagemaker_session, RLFramework.MXNET, rl_coach_mxnet_full_version, cpu_instance_type
29+
sagemaker_session, RLFramework.MXNET, coach_mxnet_latest_version, cpu_instance_type
3030
)
3131
job_name = unique_name_from_base("test-coach-mxnet")
3232

@@ -51,9 +51,12 @@ def test_coach_mxnet(sagemaker_session, rl_coach_mxnet_full_version, cpu_instanc
5151
assert 0 < action[0][1] < 1
5252

5353

54-
def test_coach_tf(sagemaker_session, rl_coach_tf_full_version, cpu_instance_type):
54+
def test_coach_tf(sagemaker_session, coach_tensorflow_latest_version, cpu_instance_type):
5555
estimator = _test_coach(
56-
sagemaker_session, RLFramework.TENSORFLOW, rl_coach_tf_full_version, cpu_instance_type
56+
sagemaker_session,
57+
RLFramework.TENSORFLOW,
58+
coach_tensorflow_latest_version,
59+
cpu_instance_type,
5760
)
5861
job_name = unique_name_from_base("test-coach-tf")
5962

@@ -96,7 +99,7 @@ def _test_coach(sagemaker_session, rl_framework, rl_coach_version, cpu_instance_
9699

97100

98101
@pytest.mark.canary_quick
99-
def test_ray_tf(sagemaker_session, rl_ray_full_version, cpu_instance_type):
102+
def test_ray_tf(sagemaker_session, ray_tensorflow_latest_version, cpu_instance_type):
100103
source_dir = os.path.join(DATA_DIR, "ray_cartpole")
101104
cartpole = "train_ray.py"
102105

@@ -105,7 +108,7 @@ def test_ray_tf(sagemaker_session, rl_ray_full_version, cpu_instance_type):
105108
source_dir=source_dir,
106109
toolkit=RLToolkit.RAY,
107110
framework=RLFramework.TENSORFLOW,
108-
toolkit_version=rl_ray_full_version,
111+
toolkit_version=ray_tensorflow_latest_version,
109112
sagemaker_session=sagemaker_session,
110113
role="SageMakerRole",
111114
instance_type=cpu_instance_type,

tests/unit/sagemaker/image_uris/test_rl.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,15 @@ def _version_for_tag(toolkit, toolkit_version, framework, framework_in_tag=False
3131
return "{}{}".format(toolkit, toolkit_version)
3232

3333

34-
def test_coach_tf(rl_coach_tf_version):
34+
def test_coach_tf(coach_tensorflow_version):
3535
for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS:
3636
uri = image_uris.retrieve(
37-
"coach-tensorflow", REGION, version=rl_coach_tf_version, instance_type=instance_type
37+
"coach-tensorflow",
38+
REGION,
39+
version=coach_tensorflow_version,
40+
instance_type=instance_type,
3841
)
39-
assert _expected_coach_tf_uri(rl_coach_tf_version, processor) == uri
42+
assert _expected_coach_tf_uri(coach_tensorflow_version, processor) == uri
4043

4144

4245
def _expected_coach_tf_uri(coach_tf_version, processor):
@@ -58,28 +61,28 @@ def _expected_coach_tf_uri(coach_tf_version, processor):
5861
)
5962

6063

61-
def test_coach_mxnet(rl_coach_mxnet_version):
64+
def test_coach_mxnet(coach_mxnet_version):
6265
for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS:
6366
uri = image_uris.retrieve(
64-
"coach-mxnet", REGION, version=rl_coach_mxnet_version, instance_type=instance_type
67+
"coach-mxnet", REGION, version=coach_mxnet_version, instance_type=instance_type
6568
)
6669

6770
expected = expected_uris.framework_uri(
6871
"sagemaker-rl-mxnet",
69-
"coach{}".format(rl_coach_mxnet_version),
72+
"coach{}".format(coach_mxnet_version),
7073
SAGEMAKER_ACCOUNT,
7174
py_version="py3",
7275
processor=processor,
7376
)
7477
assert expected == uri
7578

7679

77-
def test_ray_tf(rl_ray_tf_version):
80+
def test_ray_tf(ray_tensorflow_version):
7881
for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS:
7982
uri = image_uris.retrieve(
80-
"ray-tensorflow", REGION, version=rl_ray_tf_version, instance_type=instance_type
83+
"ray-tensorflow", REGION, version=ray_tensorflow_version, instance_type=instance_type
8184
)
82-
assert _expected_ray_tf_uri(rl_ray_tf_version, processor) == uri
85+
assert _expected_ray_tf_uri(ray_tensorflow_version, processor) == uri
8386

8487

8588
def _expected_ray_tf_uri(ray_tf_version, processor):
@@ -101,15 +104,15 @@ def _expected_ray_tf_uri(ray_tf_version, processor):
101104
)
102105

103106

104-
def test_ray_pytorch(rl_ray_pytorch_version):
107+
def test_ray_pytorch(ray_pytorch_version):
105108
for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS:
106109
uri = image_uris.retrieve(
107-
"ray-pytorch", REGION, version=rl_ray_pytorch_version, instance_type=instance_type
110+
"ray-pytorch", REGION, version=ray_pytorch_version, instance_type=instance_type
108111
)
109112

110113
expected = expected_uris.framework_uri(
111114
"sagemaker-rl-ray-container",
112-
"ray-{}-torch".format(rl_ray_pytorch_version),
115+
"ray-{}-torch".format(ray_pytorch_version),
113116
RL_ACCOUNT,
114117
py_version="py36",
115118
processor=processor,
@@ -118,8 +121,8 @@ def test_ray_pytorch(rl_ray_pytorch_version):
118121
assert expected == uri
119122

120123

121-
def test_vw(rl_vw_version):
122-
version = "vw-{}".format(rl_vw_version)
124+
def test_vw(vw_version):
125+
version = "vw-{}".format(vw_version)
123126
uri = image_uris.retrieve("vw", REGION, version=version, instance_type="ml.c4.xlarge")
124127

125128
expected = expected_uris.framework_uri("sagemaker-rl-vw-container", version, RL_ACCOUNT)

0 commit comments

Comments
 (0)