Skip to content

change: use generated RL version fixtures and update Ray version #1769

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits on Jul 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/sagemaker/rl/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"0.11.0": {"tensorflow": "1.11", "mxnet": "1.3"},
"0.11.1": {"tensorflow": "1.12"},
"0.11": {"tensorflow": "1.12", "mxnet": "1.3"},
"1.0.0": {"tensorflow": "1.12"},
},
"ray": {
"0.5.3": {"tensorflow": "1.11"},
Expand Down Expand Up @@ -68,7 +69,7 @@ class RLEstimator(Framework):

COACH_LATEST_VERSION_TF = "0.11.1"
COACH_LATEST_VERSION_MXNET = "0.11.0"
RAY_LATEST_VERSION = "0.6.5"
RAY_LATEST_VERSION = "0.8.5"

def __init__(
self,
Expand Down
46 changes: 5 additions & 41 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from sagemaker import Session, image_uris, utils
from sagemaker.local import LocalSession
from sagemaker.rl import RLEstimator
import tests.integ

DEFAULT_REGION = "us-west-2"
Expand All @@ -41,15 +40,20 @@

FRAMEWORKS_FOR_GENERATED_VERSION_FIXTURES = (
"chainer",
"coach_mxnet",
"coach_tensorflow",
"inferentia_mxnet",
"inferentia_tensorflow",
"mxnet",
"neo_mxnet",
"neo_pytorch",
"neo_tensorflow",
"pytorch",
"ray_pytorch",
"ray_tensorflow",
"sklearn",
"tensorflow",
"vw",
"xgboost",
)

Expand Down Expand Up @@ -181,46 +185,6 @@ def _tf_py_version(tf_version, request):
return "py37"


@pytest.fixture(scope="module", params=["0.10.1", "0.10.1", "0.11", "0.11.0", "0.11.1"])
def rl_coach_tf_version(request):
return request.param


@pytest.fixture(scope="module", params=["0.11", "0.11.0"])
def rl_coach_mxnet_version(request):
return request.param


@pytest.fixture(scope="module", params=["0.5", "0.5.3", "0.6", "0.6.5", "0.8.2", "0.8.5"])
def rl_ray_tf_version(request):
return request.param


@pytest.fixture(scope="module", params=["0.8.5"])
def rl_ray_pytorch_version(request):
return request.param


@pytest.fixture(scope="module", params=["8.7.0"])
def rl_vw_version(request):
return request.param


@pytest.fixture(scope="module")
def rl_coach_mxnet_full_version():
return RLEstimator.COACH_LATEST_VERSION_MXNET


@pytest.fixture(scope="module")
def rl_coach_tf_full_version():
return RLEstimator.COACH_LATEST_VERSION_TF


@pytest.fixture(scope="module")
def rl_ray_full_version():
return RLEstimator.RAY_LATEST_VERSION


@pytest.fixture(scope="module")
def tf_full_version(tensorflow_training_latest_version, tensorflow_inference_latest_version):
"""Fixture for TF tests that test both training and inference.
Expand Down
4 changes: 2 additions & 2 deletions tests/data/ray_cartpole/train_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from ray.tune.logger import pretty_print

# Based on https://github.com/ray-project/ray/blob/master/doc/source/rllib-training.rst#python-api
ray.init(log_to_driver=False)
ray.init(log_to_driver=False, webui_host="127.0.0.1")
config = ppo.DEFAULT_CONFIG.copy()
config["num_gpus"] = int(os.environ.get("SM_NUM_GPUS", 0))
checkpoint_dir = os.environ.get("SM_MODEL_DIR", "/Users/nadzeya/gym")
config["num_workers"] = 1
agent = ppo.PPOAgent(config=config, env="CartPole-v0")
agent = ppo.PPOTrainer(config=config, env="CartPole-v0")

# Can optionally call agent.restore(path) to load a checkpoint.

Expand Down
15 changes: 9 additions & 6 deletions tests/integ/test_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@


@pytest.mark.canary_quick
def test_coach_mxnet(sagemaker_session, rl_coach_mxnet_full_version, cpu_instance_type):
def test_coach_mxnet(sagemaker_session, coach_mxnet_latest_version, cpu_instance_type):
estimator = _test_coach(
sagemaker_session, RLFramework.MXNET, rl_coach_mxnet_full_version, cpu_instance_type
sagemaker_session, RLFramework.MXNET, coach_mxnet_latest_version, cpu_instance_type
)
job_name = unique_name_from_base("test-coach-mxnet")

Expand All @@ -51,9 +51,12 @@ def test_coach_mxnet(sagemaker_session, rl_coach_mxnet_full_version, cpu_instanc
assert 0 < action[0][1] < 1


def test_coach_tf(sagemaker_session, rl_coach_tf_full_version, cpu_instance_type):
def test_coach_tf(sagemaker_session, coach_tensorflow_latest_version, cpu_instance_type):
estimator = _test_coach(
sagemaker_session, RLFramework.TENSORFLOW, rl_coach_tf_full_version, cpu_instance_type
sagemaker_session,
RLFramework.TENSORFLOW,
coach_tensorflow_latest_version,
cpu_instance_type,
)
job_name = unique_name_from_base("test-coach-tf")

Expand Down Expand Up @@ -96,7 +99,7 @@ def _test_coach(sagemaker_session, rl_framework, rl_coach_version, cpu_instance_


@pytest.mark.canary_quick
def test_ray_tf(sagemaker_session, rl_ray_full_version, cpu_instance_type):
def test_ray_tf(sagemaker_session, ray_tensorflow_latest_version, cpu_instance_type):
source_dir = os.path.join(DATA_DIR, "ray_cartpole")
cartpole = "train_ray.py"

Expand All @@ -105,7 +108,7 @@ def test_ray_tf(sagemaker_session, rl_ray_full_version, cpu_instance_type):
source_dir=source_dir,
toolkit=RLToolkit.RAY,
framework=RLFramework.TENSORFLOW,
toolkit_version=rl_ray_full_version,
toolkit_version=ray_tensorflow_latest_version,
sagemaker_session=sagemaker_session,
role="SageMakerRole",
instance_type=cpu_instance_type,
Expand Down
31 changes: 17 additions & 14 deletions tests/unit/sagemaker/image_uris/test_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,15 @@ def _version_for_tag(toolkit, toolkit_version, framework, framework_in_tag=False
return "{}{}".format(toolkit, toolkit_version)


def test_coach_tf(rl_coach_tf_version):
def test_coach_tf(coach_tensorflow_version):
for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS:
uri = image_uris.retrieve(
"coach-tensorflow", REGION, version=rl_coach_tf_version, instance_type=instance_type
"coach-tensorflow",
REGION,
version=coach_tensorflow_version,
instance_type=instance_type,
)
assert _expected_coach_tf_uri(rl_coach_tf_version, processor) == uri
assert _expected_coach_tf_uri(coach_tensorflow_version, processor) == uri


def _expected_coach_tf_uri(coach_tf_version, processor):
Expand All @@ -58,28 +61,28 @@ def _expected_coach_tf_uri(coach_tf_version, processor):
)


def test_coach_mxnet(rl_coach_mxnet_version):
def test_coach_mxnet(coach_mxnet_version):
for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS:
uri = image_uris.retrieve(
"coach-mxnet", REGION, version=rl_coach_mxnet_version, instance_type=instance_type
"coach-mxnet", REGION, version=coach_mxnet_version, instance_type=instance_type
)

expected = expected_uris.framework_uri(
"sagemaker-rl-mxnet",
"coach{}".format(rl_coach_mxnet_version),
"coach{}".format(coach_mxnet_version),
SAGEMAKER_ACCOUNT,
py_version="py3",
processor=processor,
)
assert expected == uri


def test_ray_tf(rl_ray_tf_version):
def test_ray_tf(ray_tensorflow_version):
for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS:
uri = image_uris.retrieve(
"ray-tensorflow", REGION, version=rl_ray_tf_version, instance_type=instance_type
"ray-tensorflow", REGION, version=ray_tensorflow_version, instance_type=instance_type
)
assert _expected_ray_tf_uri(rl_ray_tf_version, processor) == uri
assert _expected_ray_tf_uri(ray_tensorflow_version, processor) == uri


def _expected_ray_tf_uri(ray_tf_version, processor):
Expand All @@ -101,15 +104,15 @@ def _expected_ray_tf_uri(ray_tf_version, processor):
)


def test_ray_pytorch(rl_ray_pytorch_version):
def test_ray_pytorch(ray_pytorch_version):
for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS:
uri = image_uris.retrieve(
"ray-pytorch", REGION, version=rl_ray_pytorch_version, instance_type=instance_type
"ray-pytorch", REGION, version=ray_pytorch_version, instance_type=instance_type
)

expected = expected_uris.framework_uri(
"sagemaker-rl-ray-container",
"ray-{}-torch".format(rl_ray_pytorch_version),
"ray-{}-torch".format(ray_pytorch_version),
RL_ACCOUNT,
py_version="py36",
processor=processor,
Expand All @@ -118,8 +121,8 @@ def test_ray_pytorch(rl_ray_pytorch_version):
assert expected == uri


def test_vw(rl_vw_version):
version = "vw-{}".format(rl_vw_version)
def test_vw(vw_version):
version = "vw-{}".format(vw_version)
uri = image_uris.retrieve("vw", REGION, version=version, instance_type="ml.c4.xlarge")

expected = expected_uris.framework_uri("sagemaker-rl-vw-container", version, RL_ACCOUNT)
Expand Down
Loading