Skip to content

Commit 1bf8c66

Browse files
committed
modify rl ray images mapping for newer versions
1 parent c919830 commit 1bf8c66

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

src/sagemaker/fw_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
"{} framework does not support version {}. Please use one of the following: {}."
7070
)
7171

72-
VALID_PY_VERSIONS = ["py2", "py3", "py37"]
72+
VALID_PY_VERSIONS = ["py2", "py3", "py37", "py36"]
7373
VALID_EIA_FRAMEWORKS = [
7474
"tensorflow",
7575
"tensorflow-serving",

src/sagemaker/rl/estimator.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
logger = logging.getLogger("sagemaker")
2727

28+
DEFAULT_RL_ACCOUNT = "462105765813"
2829
SAGEMAKER_ESTIMATOR = "sagemaker_estimator"
2930
SAGEMAKER_ESTIMATOR_VALUE = "RLEstimator"
3031
PYTHON_VERSION = "py3"
@@ -41,6 +42,8 @@
4142
"0.5": {"tensorflow": "1.11"},
4243
"0.6.5": {"tensorflow": "1.12"},
4344
"0.6": {"tensorflow": "1.12"},
45+
"0.8.2":{"tensorflow": "2.1"},
46+
"0.8.5":{"tensorflow": "2.1", "pytorch": "1.5"}
4447
},
4548
}
4649

@@ -57,14 +60,15 @@ class RLFramework(enum.Enum):
5760

5861
TENSORFLOW = "tensorflow"
5962
MXNET = "mxnet"
63+
PYTORCH = "pytorch"
6064

6165

6266
class RLEstimator(Framework):
6367
"""Handle end-to-end training and deployment of custom RLEstimator code."""
6468

6569
COACH_LATEST_VERSION_TF = "0.11.1"
6670
COACH_LATEST_VERSION_MXNET = "0.11.0"
67-
RAY_LATEST_VERSION = "0.6.5"
71+
RAY_LATEST_VERSION = "0.8.5"
6872

6973
def __init__(
7074
self,
@@ -277,6 +281,18 @@ def train_image(self):
277281
"""
278282
if self.image_name:
279283
return self.image_name
284+
285+
# use different account for rl images if ray version is later than 0.8.2
286+
if self.toolkit == RLToolkit.RAY.value and self.toolkit_version >= "0.8.2":
287+
return fw_utils.create_image_uri(
288+
self.sagemaker_session.boto_region_name,
289+
"rl-ray-container",
290+
self.train_instance_type,
291+
self._image_version(),
292+
py_version="py36",
293+
account=DEFAULT_RL_ACCOUNT
294+
)
295+
280296
return fw_utils.create_image_uri(
281297
self.sagemaker_session.boto_region_name,
282298
self._image_framework(),
@@ -454,6 +470,13 @@ def _validate_toolkit_support(cls, toolkit, toolkit_version, framework):
454470

455471
def _image_version(self):
456472
"""Placeholder docstring"""
473+
if self.toolkit == RLToolkit.RAY.value and self.toolkit_version >= "0.8.2":
474+
frameworkd_tag = None
475+
if self.framework == RLFramework.TENSORFLOW.value:
476+
frameworkd_tag = "tf"
477+
elif self.framework == RLFramework.PYTORCH.value:
478+
frameworkd_tag = "torch"
479+
return "{}-{}-{}".format(self.toolkit, self.toolkit_version, frameworkd_tag)
457480
return "{}{}".format(self.toolkit, self.toolkit_version)
458481

459482
def _image_framework(self):

0 commit comments

Comments
 (0)