25
25
26
26
logger = logging .getLogger ("sagemaker" )
27
27
28
+ DEFAULT_RL_ACCOUNT = "462105765813"
28
29
SAGEMAKER_ESTIMATOR = "sagemaker_estimator"
29
30
SAGEMAKER_ESTIMATOR_VALUE = "RLEstimator"
30
31
PYTHON_VERSION = "py3"
41
42
"0.5" : {"tensorflow" : "1.11" },
42
43
"0.6.5" : {"tensorflow" : "1.12" },
43
44
"0.6" : {"tensorflow" : "1.12" },
45
+ "0.8.2" :{"tensorflow" : "2.1" },
46
+ "0.8.5" :{"tensorflow" : "2.1" , "pytorch" : "1.5" }
44
47
},
45
48
}
46
49
@@ -57,14 +60,15 @@ class RLFramework(enum.Enum):
57
60
58
61
TENSORFLOW = "tensorflow"
59
62
MXNET = "mxnet"
63
+ PYTORCH = "pytorch"
60
64
61
65
62
66
class RLEstimator (Framework ):
63
67
"""Handle end-to-end training and deployment of custom RLEstimator code."""
64
68
65
69
COACH_LATEST_VERSION_TF = "0.11.1"
66
70
COACH_LATEST_VERSION_MXNET = "0.11.0"
67
- RAY_LATEST_VERSION = "0.6 .5"
71
+ RAY_LATEST_VERSION = "0.8 .5"
68
72
69
73
def __init__ (
70
74
self ,
@@ -277,6 +281,18 @@ def train_image(self):
277
281
"""
278
282
if self .image_name :
279
283
return self .image_name
284
+
285
+ # use different account for rl images if ray version is later than 0.8.2
286
+ if self .toolkit == RLToolkit .RAY .value and self .toolkit_version >= "0.8.2" :
287
+ return fw_utils .create_image_uri (
288
+ self .sagemaker_session .boto_region_name ,
289
+ "rl-ray-container" ,
290
+ self .train_instance_type ,
291
+ self ._image_version (),
292
+ py_version = "py36" ,
293
+ account = DEFAULT_RL_ACCOUNT
294
+ )
295
+
280
296
return fw_utils .create_image_uri (
281
297
self .sagemaker_session .boto_region_name ,
282
298
self ._image_framework (),
@@ -454,6 +470,13 @@ def _validate_toolkit_support(cls, toolkit, toolkit_version, framework):
454
470
455
471
def _image_version (self ):
456
472
"""Placeholder docstring"""
473
+ if self .toolkit == RLToolkit .RAY .value and self .toolkit_version >= "0.8.2" :
474
+ frameworkd_tag = None
475
+ if self .framework == RLFramework .TENSORFLOW .value :
476
+ frameworkd_tag = "tf"
477
+ elif self .framework == RLFramework .PYTORCH .value :
478
+ frameworkd_tag = "torch"
479
+ return "{}-{}-{}" .format (self .toolkit , self .toolkit_version , frameworkd_tag )
457
480
return "{}{}" .format (self .toolkit , self .toolkit_version )
458
481
459
482
def _image_framework (self ):
0 commit comments