Skip to content

Commit 5bbaf52

Browse files
benieric authored and pintaoz-aws committed
Fix: Correctly serialize SM_HPS env var (#1611)
1 parent a3fce08 commit 5bbaf52

File tree

4 files changed

+67
-8
lines changed

4 files changed

+67
-8
lines changed

src/sagemaker/modules/train/container_drivers/scripts/environment.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
2525
sys.path.insert(0, parent_dir)
2626

27-
from utils import safe_serialize # noqa: E402 # pylint: disable=C0413
27+
from utils import safe_serialize, safe_deserialize # noqa: E402 # pylint: disable=C0413
2828

2929
# Initialize logger
3030
SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20)
@@ -115,6 +115,21 @@ def num_neurons() -> int:
115115
return 0
116116

117117

118+
def deserialize_hyperparameters(hyperparameters: Dict[str, str]) -> Dict[str, Any]:
    """Convert string-encoded hyperparameters back into typed values.

    Args:
        hyperparameters (Dict[str, str]): Mapping of hyperparameter names to
            their string representations.

    Returns:
        Dict[str, Any]: A new mapping with the same keys, where each value is
            the result of ``safe_deserialize`` applied to the original string.
    """
    return {name: safe_deserialize(raw) for name, raw in hyperparameters.items()}
131+
132+
118133
def set_env(
119134
resource_config: Dict[str, Any],
120135
input_data_config: Dict[str, Any],
@@ -150,10 +165,11 @@ def set_env(
150165
env_vars["SM_CHANNELS"] = channels
151166

152167
# Hyperparameters
153-
env_vars["SM_HPS"] = hyperparameters_config
154-
for key, value in hyperparameters_config.items():
168+
hps = deserialize_hyperparameters(hyperparameters_config)
169+
for key, value in hps.items():
155170
key_upper = key.replace("-", "_").upper()
156-
env_vars[f"SM_HP_{key_upper}"] = safe_serialize(value)
171+
env_vars[f"SM_HP_{key_upper}"] = value
172+
env_vars["SM_HPS"] = hps
157173

158174
# Host Variables
159175
current_host = resource_config["current_host"]

src/sagemaker/modules/train/model_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class ModelTrainer(BaseModel):
119119
```
120120
121121
Attributes:
122-
session (Optiona(Session)):
122+
sagemaker_session (Optional(Session)):
123123
The SageMakerCore session. For convenience, can be imported like:
124124
`from sagemaker.modules import Session`.
125125
If not specified, a new session will be created.

tests/data/modules/params_script/train.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,43 @@ def main():
9999
assert json.loads(os.environ["SM_HP_LIST"]) == EXPECTED_HYPERPARAMETERS["list"]
100100
assert json.loads(os.environ["SM_HP_DICT"]) == EXPECTED_HYPERPARAMETERS["dict"]
101101

102+
params = json.loads(os.environ["SM_HPS"])
103+
print(f"SM_HPS: {params}")
104+
assert params["string"] == EXPECTED_HYPERPARAMETERS["string"]
105+
assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"]
106+
assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"]
107+
assert params["float"] == EXPECTED_HYPERPARAMETERS["float"]
108+
assert params["list"] == EXPECTED_HYPERPARAMETERS["list"]
109+
assert params["dict"] == EXPECTED_HYPERPARAMETERS["dict"]
110+
111+
assert isinstance(params, dict)
112+
assert isinstance(params["string"], str)
113+
assert isinstance(params["integer"], int)
114+
assert isinstance(params["boolean"], bool)
115+
assert isinstance(params["float"], float)
116+
assert isinstance(params["list"], list)
117+
assert isinstance(params["dict"], dict)
118+
119+
params = json.loads(os.environ["SM_TRAINING_ENV"])["hyperparameters"]
120+
print(params)
121+
assert params["string"] == EXPECTED_HYPERPARAMETERS["string"]
122+
assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"]
123+
assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"]
124+
assert params["float"] == EXPECTED_HYPERPARAMETERS["float"]
125+
assert params["list"] == EXPECTED_HYPERPARAMETERS["list"]
126+
assert params["dict"] == EXPECTED_HYPERPARAMETERS["dict"]
127+
128+
assert isinstance(params, dict)
129+
assert isinstance(params["string"], str)
130+
assert isinstance(params["integer"], int)
131+
assert isinstance(params["boolean"], bool)
132+
assert isinstance(params["float"], float)
133+
assert isinstance(params["list"], list)
134+
assert isinstance(params["dict"], dict)
135+
print(f"SM_TRAINING_ENV -> hyperparameters: {params}")
136+
137+
print("Test passed.")
138+
102139

103140
if __name__ == "__main__":
104141
main()

tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
mask_sensitive_info,
2727
HIDDEN_VALUE,
2828
)
29-
from sagemaker.modules.train.container_drivers.utils import safe_serialize
29+
from sagemaker.modules.train.container_drivers.utils import safe_serialize, safe_deserialize
3030

3131
RESOURCE_CONFIG = dict(
3232
current_host="algo-1",
@@ -92,11 +92,11 @@
9292
export SM_CHANNEL_TRAIN='/opt/ml/input/data/train'
9393
export SM_CHANNEL_VALIDATION='/opt/ml/input/data/validation'
9494
export SM_CHANNELS='["train", "validation"]'
95-
export SM_HPS='{"batch_size": 32, "learning_rate": 0.001, "hosts": ["algo-1", "algo-2"], "mp_parameters": {"microbatches": 2, "partitions": 2, "pipeline": "interleaved", "optimize": "memory", "horovod": true}}'
9695
export SM_HP_BATCH_SIZE='32'
9796
export SM_HP_LEARNING_RATE='0.001'
9897
export SM_HP_HOSTS='["algo-1", "algo-2"]'
9998
export SM_HP_MP_PARAMETERS='{"microbatches": 2, "partitions": 2, "pipeline": "interleaved", "optimize": "memory", "horovod": true}'
99+
export SM_HPS='{"batch_size": 32, "learning_rate": 0.001, "hosts": ["algo-1", "algo-2"], "mp_parameters": {"microbatches": 2, "partitions": 2, "pipeline": "interleaved", "optimize": "memory", "horovod": true}}'
100100
export SM_CURRENT_HOST='algo-1'
101101
export SM_CURRENT_INSTANCE_TYPE='ml.p3.16xlarge'
102102
export SM_HOSTS='["algo-1", "algo-2", "algo-3"]'
@@ -119,7 +119,13 @@
119119
"sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize",
120120
side_effect=safe_serialize,
121121
)
122-
def test_set_env(mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons):
122+
@patch(
123+
"sagemaker.modules.train.container_drivers.scripts.environment.safe_deserialize",
124+
side_effect=safe_deserialize,
125+
)
126+
def test_set_env(
127+
mock_safe_deserialize, mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons
128+
):
123129
with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}):
124130
set_env(
125131
resource_config=RESOURCE_CONFIG,

0 commit comments

Comments
 (0)