Skip to content

Commit 694f8e9

Browse files
beniericpintaoz-aws
authored and committed
Pass hyperparameters as CLI args (#1577)
1 parent 0847a16 commit 694f8e9

File tree

21 files changed

+672
-95
lines changed

21 files changed

+672
-95
lines changed

src/sagemaker/modules/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,3 @@
1616
from sagemaker_core.main.utils import logger as sagemaker_core_logger
1717

1818
logger = sagemaker_core_logger
19-
20-
from sagemaker.modules.train.model_trainer import ( # noqa: F401 E402 # pylint: disable=C0413
21-
ModelTrainer,
22-
)

src/sagemaker/modules/configs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ class Compute(shapes.ResourceConfig):
123123
Attributes:
124124
instance_type (Optional[str]):
125125
The ML compute instance type. For information about available instance types,
126-
see https://aws.amazon.com/sagemaker/pricing/. Default: ml.m5.xlarge
126+
see https://aws.amazon.com/sagemaker/pricing/.
127127
instance_count (Optional[int]): The number of ML compute instances to use. For distributed
128-
training, provide a value greater than 1. Default: 1
128+
training, provide a value greater than 1.
129129
volume_size_in_gb (Optional[int]):
130130
The size of the ML storage volume that you want to provision. ML storage volumes store
131131
model artifacts and incremental states. Training algorithms might also use the ML

src/sagemaker/modules/distributed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ def model_dump(self, *args, **kwargs):
3030

3131

3232
class Torchrun(DistributedRunner):
33-
"""TorchDistribution.
33+
"""TorchDistributed.
3434
35-
The TorchDistribution runner uses `torchrun` or `torch.distributed.launch` in the backend to
35+
The Torchrun distributed runner uses `torchrun` or `torch.distributed.launch` in the backend to
3636
launch distributed training.
3737
3838
Attributes:

src/sagemaker/modules/templates.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,15 @@
1515

1616
EXECUTE_BASE_COMMANDS = """
1717
CMD="{base_command}"
18-
echo "Running command: $CMD"
18+
echo "Executing command: $CMD"
1919
eval $CMD
2020
"""
2121

22+
EXECUTE_BASIC_SCRIPT_DRIVER = """
23+
echo "Running Basic Script driver"
24+
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/basic_script_driver.py
25+
"""
26+
2227
EXEUCTE_TORCHRUN_DRIVER = """
2328
echo "Running Torchrun driver"
2429
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/torchrun_driver.py

src/sagemaker/modules/testing_notebooks/base_model_trainer.ipynb

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,11 @@
117117
"metadata": {},
118118
"outputs": [],
119119
"source": [
120+
"from sagemaker.modules.train import ModelTrainer\n",
120121
"from sagemaker.modules.configs import SourceCode\n",
121122
"\n",
123+
"pytorch_image = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310\"\n",
124+
"\n",
122125
"source_code = SourceCode(\n",
123126
" source_dir=\"basic-script-mode\",\n",
124127
" requirements=\"requirements.txt\",\n",
@@ -460,13 +463,6 @@
460463
")\n",
461464
"model_trainer.train(input_data_config=[test_data], wait=False)"
462465
]
463-
},
464-
{
465-
"cell_type": "code",
466-
"execution_count": null,
467-
"metadata": {},
468-
"outputs": [],
469-
"source": []
470466
}
471467
],
472468
"metadata": {
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module is the entry point for the Basic Script Driver."""
14+
from __future__ import absolute_import
15+
16+
import sys
17+
import shlex
18+
19+
from typing import List
20+
21+
from utils import (
22+
logger,
23+
get_python_executable,
24+
read_source_code_json,
25+
read_hyperparameters_json,
26+
execute_commands,
27+
write_failure_file,
28+
hyperparameters_to_cli_args,
29+
)
30+
31+
32+
def create_commands() -> List[str]:
    """Build the command used to launch the user's entry script.

    Reads the source-code and hyperparameter JSON configs, converts the
    hyperparameters into CLI arguments, and returns the full command to run.

    Returns:
        List[str]: The command tokens to execute.

    Raises:
        ValueError: If the entry script is neither a ``.py`` nor a ``.sh`` file.
    """
    source_code = read_source_code_json()
    hyperparameters = read_hyperparameters_json()
    python_executable = get_python_executable()

    entry_script = source_code["entry_script"]
    args = hyperparameters_to_cli_args(hyperparameters)
    if entry_script.endswith(".py"):
        # Python scripts are executed as an argv list, so no shell quoting is
        # needed for the arguments.
        commands = [python_executable, entry_script]
        commands += args
    elif entry_script.endswith(".sh"):
        # Shell scripts are run through `/bin/sh -c`, so every token embedded in
        # the command string must be shell-quoted. The original left
        # entry_script unquoted, which breaks on paths containing spaces or
        # shell metacharacters.
        quoted_script = shlex.quote(entry_script)
        args_str = " ".join(shlex.quote(arg) for arg in args)
        commands = [
            "/bin/sh",
            "-c",
            f"chmod +x {quoted_script} && ./{quoted_script} {args_str}",
        ]
    else:
        raise ValueError(
            f"Unsupported entry script type: {entry_script}. Only .py and .sh are supported."
        )
    return commands
55+
56+
57+
def main():
    """Entry point for the Basic Script Driver.

    Execution Lifecycle:
        1. Read the source code and hyperparameters JSON files.
        2. Set hyperparameters as command line arguments.
        3. Create the commands to execute.
        4. Execute the commands.
    """

    command = create_commands()

    logger.info(f"Executing command: {' '.join(command)}")
    exit_code, traceback = execute_commands(command)
    if exit_code == 0:
        return
    # Record the failure reason for SageMaker before propagating the exit code.
    write_failure_file(traceback)
    sys.exit(exit_code)


if __name__ == "__main__":
    main()

src/sagemaker/modules/train/container_drivers/mpi_driver.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
logger,
2222
read_source_code_json,
2323
read_distributed_runner_json,
24+
read_hyperparameters_json,
25+
hyperparameters_to_cli_args,
2426
get_process_count,
2527
execute_commands,
2628
write_failure_file,
@@ -58,6 +60,7 @@ def main():
5860
"""
5961
source_code = read_source_code_json()
6062
distribution = read_distributed_runner_json()
63+
hyperparameters = read_hyperparameters_json()
6164

6265
sm_current_host = os.environ["SM_CURRENT_HOST"]
6366
sm_hosts = json.loads(os.environ["SM_HOSTS"])
@@ -87,7 +90,10 @@ def main():
8790
entry_script_path=os.path.join(USER_CODE_PATH, source_code["entry_script"]),
8891
)
8992

90-
logger.info(f"Executing command: {mpi_command}")
93+
args = hyperparameters_to_cli_args(hyperparameters)
94+
mpi_command += args
95+
96+
logger.info(f"Executing command: {' '.join(mpi_command)}")
9197
exit_code, error_traceback = execute_commands(mpi_command)
9298
write_status_file_to_workers(worker_hosts)
9399

src/sagemaker/modules/train/container_drivers/mpi_utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@
2222
from utils import logger, SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable
2323

2424
FINISHED_STATUS_FILE = "/tmp/done.algo-1"
25+
READY_FILE = "/tmp/ready.%s"
2526
DEFAULT_SSH_PORT = 22
2627

2728

28-
def _write_status_file(host: str, status_file: str) -> bool:
29-
"""Write the status file to the provided host."""
29+
def _write_file_to_host(host: str, status_file: str) -> bool:
30+
"""Write the a file to the provided host."""
3031
try:
31-
logger.info("Writing finished status file (%s) to %s", status_file, host)
32+
logger.info(f"Writing {status_file} to {host}")
3233
subprocess.run(
3334
["ssh", host, "touch", f"{status_file}"],
3435
capture_output=True,
@@ -46,7 +47,7 @@ def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FIN
4647
"""Write the status file to all worker nodes."""
4748
for worker in worker_hosts:
4849
retry = 0
49-
while not _write_status_file(worker, status_file):
50+
while not _write_file_to_host(worker, status_file):
5051
time.sleep(5)
5152
retry += 1
5253
if retry > 5:
@@ -102,7 +103,10 @@ def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, tim
102103

103104
while True:
104105
logger.info("Master is attempting to connect to all workers...")
105-
all_workers_connected = all(_can_connect(worker, port) for worker in worker_hosts)
106+
all_workers_connected = all(
107+
_can_connect(worker, port) and os.path.exists(READY_FILE % worker)
108+
for worker in worker_hosts
109+
)
106110

107111
if all_workers_connected:
108112
logger.info("Master can connect to all worker nodes.")
@@ -131,6 +135,7 @@ def bootstrap_worker_node(master_host: str, status_file: str = FINISHED_STATUS_F
131135
"""Bootstrap the worker nodes."""
132136
logger.info("Bootstrapping worker node...")
133137
_wait_for_master(master_host)
138+
_write_file_to_host(master_host, READY_FILE % os.environ["SM_CURRENT_HOST"])
134139
_wait_for_status_file(status_file)
135140

136141

src/sagemaker/modules/train/container_drivers/scripts/environment.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@
2121
import sys
2222
import logging
2323

24+
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
25+
sys.path.insert(0, parent_dir)
26+
27+
from utils import safe_serialize # noqa: E402 # pylint: disable=C0413
28+
2429
# Initialize logger
2530
SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20)
2631
logger = logging.getLogger(__name__)
@@ -147,7 +152,7 @@ def set_env(
147152
# Hyperparameters
148153
env_vars["SM_HPS"] = hyperparameters_config
149154
for key, value in hyperparameters_config.items():
150-
env_vars[f"SM_HP_{key.upper()}"] = value
155+
env_vars[f"SM_HP_{key.upper()}"] = safe_serialize(value)
151156

152157
# Host Variables
153158
current_host = resource_config["current_host"]
@@ -197,10 +202,7 @@ def set_env(
197202
}
198203
with open(output_file, "w") as f:
199204
for key, value in env_vars.items():
200-
if isinstance(value, (list, dict)):
201-
f.write(f"export {key}='{json.dumps(value)}'\n")
202-
else:
203-
f.write(f"export {key}='{value}'\n")
205+
f.write(f"export {key}='{safe_serialize(value)}'\n")
204206

205207
logger.info("Environment Variables:")
206208
log_env_variables(env_vars_dict=env_vars)

src/sagemaker/modules/train/container_drivers/torchrun_driver.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@
2222
logger,
2323
read_source_code_json,
2424
read_distributed_runner_json,
25+
read_hyperparameters_json,
26+
hyperparameters_to_cli_args,
2527
get_process_count,
2628
get_python_executable,
27-
SM_EFA_NCCL_INSTANCES,
28-
SM_EFA_RDMA_INSTANCES,
2929
execute_commands,
3030
write_failure_file,
3131
USER_CODE_PATH,
32+
SM_EFA_NCCL_INSTANCES,
33+
SM_EFA_RDMA_INSTANCES,
3234
)
3335

3436

@@ -65,6 +67,7 @@ def create_commands():
6567
"""Create the Torch Distributed command to execute"""
6668
source_code = read_source_code_json()
6769
distribution = read_distributed_runner_json()
70+
hyperparameters = read_hyperparameters_json()
6871

6972
process_count = get_process_count(distribution)
7073
host_count = int(os.environ["SM_HOST_COUNT"])
@@ -92,6 +95,10 @@ def create_commands():
9295
)
9396

9497
torch_cmd.extend([os.path.join(USER_CODE_PATH, source_code["entry_script"])])
98+
99+
args = hyperparameters_to_cli_args(hyperparameters)
100+
torch_cmd += args
101+
95102
return torch_cmd
96103

97104

@@ -110,7 +117,7 @@ def main():
110117
"""
111118
setup_env()
112119
torch_cmd = create_commands()
113-
logger.info(f"Executing command: {torch_cmd}")
120+
logger.info(f"Executing command: {' '.join(torch_cmd)}")
114121
exit_code, traceback = execute_commands(torch_cmd)
115122
if exit_code != 0:
116123
write_failure_file(traceback)

0 commit comments

Comments
 (0)