Skip to content

Commit 8cc19a3

Browse files
beniericpintaoz-aws
authored andcommitted
Mask Sensitive Env Logs in Container (#1568)
1 parent a406f64 commit 8cc19a3

File tree

4 files changed

+99
-7
lines changed

4 files changed

+99
-7
lines changed

src/sagemaker/modules/templates.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,6 @@
7373
cat /opt/ml/input/config/inputdataconfig.json
7474
echo
7575
76-
echo "/opt/ml/input/config/hyperparameters.json:"
77-
cat /opt/ml/input/config/hyperparameters.json
78-
echo
79-
8076
echo "/opt/ml/input/data/sm_drivers/sourcecodeconfig.json"
8177
cat /opt/ml/input/data/sm_drivers/sourcecodeconfig.json
8278
echo

src/sagemaker/modules/testing_notebooks/base_model_trainer.ipynb

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,24 @@
7474
"\n",
7575
"pytorch_image = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310\"\n",
7676
"\n",
77-
"pytorch_image = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310\"\n",
78-
"\n",
7977
"source_code_config = SourceCodeConfig(\n",
8078
" source_dir=\"basic-script-mode\",\n",
8179
" command=\"python custom_script.py\",\n",
8280
")\n",
8381
"\n",
82+
"hyperparameters = {\n",
83+
" \"secret_token\": \"123456\",\n",
84+
"}\n",
85+
"\n",
86+
"env_vars = {\n",
87+
" \"PASSWORD\": \"123456\"\n",
88+
"}\n",
89+
"\n",
8490
"model_trainer = ModelTrainer(\n",
8591
" training_image=pytorch_image,\n",
8692
" source_code_config=source_code_config,\n",
93+
" hyperparameters=hyperparameters,\n",
94+
" environment=env_vars,\n",
8795
")\n",
8896
"\n",
8997
"model_trainer.train()"

src/sagemaker/modules/train/container_drivers/scripts/environment.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747

4848
ENV_OUTPUT_FILE = "/opt/ml/input/data/sm_drivers/scripts/sm_training.env"
4949

50+
SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"]
51+
HIDDEN_VALUE = "******"
52+
5053

5154
def num_cpus() -> int:
5255
"""Return the number of CPUs available in the current container.
@@ -199,6 +202,50 @@ def set_env(
199202
else:
200203
f.write(f"export {key}='{value}'\n")
201204

205+
logger.info("Environment Variables:")
206+
log_env_variables(env_vars_dict=env_vars)
207+
208+
209+
def mask_sensitive_info(data):
210+
"""Recursively mask sensitive information in a dictionary."""
211+
if isinstance(data, dict):
212+
for k, v in data.items():
213+
if isinstance(v, dict):
214+
data[k] = mask_sensitive_info(v)
215+
elif isinstance(v, str) and any(
216+
keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS
217+
):
218+
data[k] = HIDDEN_VALUE
219+
return data
220+
221+
222+
def log_key_value(key: str, value: str):
223+
"""Log a key-value pair, masking sensitive values if necessary."""
224+
if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS):
225+
logger.info("%s=%s", key, HIDDEN_VALUE)
226+
elif isinstance(value, dict):
227+
masked_value = mask_sensitive_info(value)
228+
logger.info("%s=%s", key, json.dumps(masked_value))
229+
else:
230+
try:
231+
decoded_value = json.loads(value)
232+
if isinstance(decoded_value, dict):
233+
masked_value = mask_sensitive_info(decoded_value)
234+
logger.info("%s=%s", key, json.dumps(masked_value))
235+
else:
236+
logger.info("%s=%s", key, decoded_value)
237+
except (json.JSONDecodeError, TypeError):
238+
logger.info("%s=%s", key, value)
239+
240+
241+
def log_env_variables(env_vars_dict: Dict[str, Any]):
242+
"""Log Environment Variables from the environment and an env_vars_dict."""
243+
for key, value in os.environ.items():
244+
log_key_value(key, value)
245+
246+
for key, value in env_vars_dict.items():
247+
log_key_value(key, value)
248+
202249

203250
def main():
204251
"""Main function to set the environment variables for the training job container."""

tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,18 @@
1414
from __future__ import absolute_import
1515

1616
import os
17+
import io
18+
import logging
1719

1820
from unittest.mock import patch
1921

20-
from sagemaker.modules.train.container_drivers.scripts.environment import set_env
22+
from sagemaker.modules.train.container_drivers.scripts.environment import (
23+
set_env,
24+
log_key_value,
25+
log_env_variables,
26+
mask_sensitive_info,
27+
HIDDEN_VALUE,
28+
)
2129

2230
RESOURCE_CONFIG = dict(
2331
current_host="algo-1",
@@ -129,6 +137,39 @@ def test_set_env(mock_num_cpus, mock_num_gpus, mock_num_neurons):
129137
assert not os.path.exists(OUTPUT_FILE)
130138

131139

140+
@patch.dict(os.environ, {"SECRET_TOKEN": "122345678", "CLEAR_DATA": "123456789"}, clear=True)
141+
def test_log_env_variables():
142+
log_stream = io.StringIO()
143+
handler = logging.StreamHandler(log_stream)
144+
145+
logger = logging.getLogger("sagemaker.modules.train.container_drivers.scripts.environment")
146+
logger.addHandler(handler)
147+
logger.setLevel(logging.INFO)
148+
149+
env_vars = {
150+
"SM_MODEL_DIR": "/opt/ml/model",
151+
"SM_INPUT_DIR": "/opt/ml/input",
152+
"SM_HPS": {"batch_size": 32, "learning_rate": 0.001, "access_token": "123456789"},
153+
"SM_HP_BATCH_SIZE": 32,
154+
"SM_HP_LEARNING_RATE": 0.001,
155+
"SM_HP_ACCESS_TOKEN": "123456789",
156+
}
157+
log_env_variables(env_vars_dict=env_vars)
158+
159+
log_output = log_stream.getvalue()
160+
161+
assert f"SECRET_TOKEN={HIDDEN_VALUE}" in log_output
162+
assert "CLEAR_DATA=123456789" in log_output
163+
assert "SM_MODEL_DIR=/opt/ml/model" in log_output
164+
assert (
165+
f'SM_HPS={{"batch_size": 32, "learning_rate": 0.001, "access_token": "{HIDDEN_VALUE}"}}'
166+
in log_output
167+
)
168+
assert "SM_HP_BATCH_SIZE=32" in log_output
169+
assert "SM_HP_LEARNING_RATE=0.001" in log_output
170+
assert f"SM_HP_ACCESS_TOKEN={HIDDEN_VALUE}" in log_output
171+
172+
132173
def _remove_extra_lines(string):
133174
"""Removes extra blank lines from a string."""
134175
return "\n".join([line for line in string.splitlines() if line.strip()])

0 commit comments

Comments
 (0)