File tree 4 files changed +7
-43
lines changed
4 files changed +7
-43
lines changed Original file line number Diff line number Diff line change 49
49
SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge" , "ml.p3.2xlarge" )
50
50
51
51
52
- def is_version_equal_or_higher (lowest_version , framework_version ):
53
- """Determine whether the ``framework_version`` is equal to or higher than
54
- ``lowest_version``
55
-
56
- Args:
57
- lowest_version (List[int]): lowest version represented in an integer
58
- list
59
- framework_version (str): framework version string
60
-
61
- Returns:
62
- bool: Whether or not ``framework_version`` is equal to or higher than
63
- ``lowest_version``
64
- """
65
- version_list = [int (s ) for s in framework_version .split ("." )]
66
- return version_list >= lowest_version [0 : len (version_list )]
67
-
68
-
69
- def is_version_equal_or_lower (highest_version , framework_version ):
70
- """Determine whether the ``framework_version`` is equal to or lower than
71
- ``highest_version``
72
-
73
- Args:
74
- highest_version (List[int]): highest version represented in an integer
75
- list
76
- framework_version (str): framework version string
77
-
78
- Returns:
79
- bool: Whether or not ``framework_version`` is equal to or lower than
80
- ``highest_version``
81
- """
82
- version_list = [int (s ) for s in framework_version .split ("." )]
83
- return version_list <= highest_version [0 : len (version_list )]
84
-
85
-
86
52
def validate_source_dir (script , directory ):
87
53
"""Validate that the source directory exists and it contains the user script
88
54
Args:
Original file line number Diff line number Diff line change 15
15
16
16
import logging
17
17
18
+ from packaging .version import Version
19
+
18
20
from sagemaker .estimator import Framework
19
21
from sagemaker .fw_utils import (
20
22
framework_name_from_image ,
21
23
framework_version_from_tag ,
22
- is_version_equal_or_higher ,
23
24
python_deprecation_warning ,
24
25
validate_version_or_image_args ,
25
26
warn_if_parameter_server_with_multi_gpu ,
@@ -157,9 +158,7 @@ def __init__(
157
158
158
159
if "enable_sagemaker_metrics" not in kwargs :
159
160
# enable sagemaker metrics for MXNet v1.6 or greater:
160
- if self .framework_version and is_version_equal_or_higher (
161
- [1 , 6 ], self .framework_version
162
- ):
161
+ if self .framework_version and Version (self .framework_version ) >= Version ("1.6" ):
163
162
kwargs ["enable_sagemaker_metrics" ] = True
164
163
165
164
super (MXNet , self ).__init__ (
Original file line number Diff line number Diff line change 15
15
16
16
import logging
17
17
18
+ from packaging .version import Version
19
+
18
20
from sagemaker .estimator import Framework
19
21
from sagemaker .fw_utils import (
20
22
framework_name_from_image ,
21
23
framework_version_from_tag ,
22
- is_version_equal_or_higher ,
23
24
python_deprecation_warning ,
24
25
validate_version_or_image_args ,
25
26
)
@@ -116,9 +117,7 @@ def __init__(
116
117
117
118
if "enable_sagemaker_metrics" not in kwargs :
118
119
# enable sagemaker metrics for PT v1.3 or greater:
119
- if self .framework_version and is_version_equal_or_higher (
120
- [1 , 3 ], self .framework_version
121
- ):
120
+ if self .framework_version and Version (self .framework_version ) >= Version ("1.3" ):
122
121
kwargs ["enable_sagemaker_metrics" ] = True
123
122
124
123
super (PyTorch , self ).__init__ (
Original file line number Diff line number Diff line change @@ -129,7 +129,7 @@ def __init__(
129
129
130
130
if "enable_sagemaker_metrics" not in kwargs :
131
131
# enable sagemaker metrics for TF v1.15 or greater:
132
- if framework_version and fw . is_version_equal_or_higher ([ 1 , 15 ], framework_version ):
132
+ if framework_version and version . Version ( framework_version ) >= version . Version ( "1.15" ):
133
133
kwargs ["enable_sagemaker_metrics" ] = True
134
134
135
135
super (TensorFlow , self ).__init__ (image_uri = image_uri , ** kwargs )
You can’t perform that action at this time.
0 commit comments