Skip to content

Commit 0d46276

Browse files
schinmayee and pintaoz-aws
authored and committed
Change default source directory to current, add option to specify source dir (#1593)
1 parent c3d1384 commit 0d46276

File tree

2 files changed

+188
-75
lines changed

2 files changed

+188
-75
lines changed

src/sagemaker/pytorch/estimator.py

Lines changed: 103 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,22 @@
4747
logger = logging.getLogger("sagemaker")
4848

4949

50+
def _setup_omegaconf_resolvers():
    """Register the arithmetic resolvers used by training recipes with OmegaConf.

    Registration is idempotent: each resolver is registered only when it is not
    already present, so this function is safe to call multiple times.
    """
    if not OmegaConf.has_resolver("multiply"):
        OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
    if not OmegaConf.has_resolver("divide_ceil"):
        OmegaConf.register_new_resolver(
            "divide_ceil", lambda x, y: int(math.ceil(x / y)), replace=True
        )
    if not OmegaConf.has_resolver("divide_floor"):
        OmegaConf.register_new_resolver(
            "divide_floor", lambda x, y: int(math.floor(x / y)), replace=True
        )
    if not OmegaConf.has_resolver("add"):
        # replace=True added for consistency with the sibling registrations above;
        # the has_resolver guard means behavior is unchanged.
        OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers), replace=True)
64+
65+
5066
def _try_resolve_recipe(recipe, key=None):
5167
"""Try to resolve recipe and return resolved recipe."""
5268
if key is not None:
@@ -60,6 +76,49 @@ def _try_resolve_recipe(recipe, key=None):
6076
return recipe[key]
6177

6278

79+
def _get_training_recipe_image_uri(image_cfg, region_name):
80+
"""Fetch image uri given image spec and region name to use for training."""
81+
if isinstance(image_cfg, str):
82+
return image_cfg
83+
return retrieve(
84+
image_cfg.get("framework"),
85+
region=region_name,
86+
version=image_cfg.get("version"),
87+
image_scope="training",
88+
**image_cfg.get("additional_args"),
89+
)
90+
91+
92+
def _get_training_recipe_gpu_script(code_dir, recipe, source_dir):
93+
"""Return path to training script (entry point) when running a gpu recipe."""
94+
model_type_to_script = {
95+
"llama_v3": ("llama", "llama_pretrain.py"),
96+
"mistral": ("mistral", "mistral_pretrain.py"),
97+
"mixtral": ("mixtral", "mixtral_pretrain.py"),
98+
}
99+
100+
if "model" not in recipe:
101+
raise ValueError("Supplied recipe does not contain required field model.")
102+
if "model_type" not in recipe["model"]:
103+
raise ValueError("Supplied recipe does not contain required field model_type.")
104+
model_type = recipe["model"]["model_type"]
105+
if model_type not in model_type_to_script:
106+
raise ValueError(f"Model type {model_type} not supported")
107+
108+
script_dir = os.path.join(code_dir, "examples", model_type_to_script[model_type][0])
109+
script = model_type_to_script[model_type][1]
110+
shutil.copyfile(os.path.join(script_dir, script), os.path.join(source_dir, script))
111+
return script
112+
113+
114+
def _get_training_recipe_trainium_script(code_dir, source_dir):
115+
"""Return path to training script (entry point) when running a trainium recipe."""
116+
script_dir = os.path.join(code_dir, "examples")
117+
script = "training_orchestrator.py"
118+
shutil.copytree(script_dir, source_dir, dirs_exist_ok=True)
119+
return script
120+
121+
63122
class PyTorch(Framework):
64123
"""Handle end-to-end training and deployment of custom PyTorch code."""
65124

@@ -294,13 +353,13 @@ def __init__(
294353
if training_recipe is not None:
295354
if entry_point is not None:
296355
logger.warning("Argument entry_point will be ignored with training_recipe.")
297-
if source_dir is not None:
298-
logger.warning("Argument source_dir will be ignored with training_recipe.")
299356
if hyperparameters is not None:
300357
logger.warning("Argument hyperparameters will be ignored with training recipe.")
301358
if distribution is not None:
302359
logger.warning("Argument distribution will be ignored with training_recipe.")
303-
args = self._setup_for_training_recipe(training_recipe, recipe_overrides, kwargs)
360+
args = self._setup_for_training_recipe(
361+
training_recipe, recipe_overrides, source_dir, kwargs
362+
)
304363
entry_point = args["entry_point"]
305364
source_dir = args["source_dir"]
306365
hyperparameters = args["hyperparameters"]
@@ -538,7 +597,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
538597
return init_params
539598

540599
@classmethod
541-
def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
600+
def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, source_dir, kwargs):
542601
"""Performs training recipe specific setup and returns recipe specific args.
543602
544603
Updates kwargs and returns a dictionary of args to use for estimator
@@ -549,7 +608,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
549608
training_recipe (str): A recipe which is a local file path, a url or a
550609
sagemaker training recipe.
551610
recipe_overrides (Dict): Dictionary specifying key values to override in the
552-
training_recipe.
611+
source_dir (str): Path (absolute, or relative) to a directory where to copy
612+
the scripts for training recipe. requirements.txt can also
613+
go here.
553614
kwargs (dict): Dictionary of args used for estimator initializaiton.
554615
Returns:
555616
dict containing arg values for estimator initialization and setup.
@@ -559,6 +620,7 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
559620
region_name = kwargs.get("sagemaker_session").boto_region_name
560621
else:
561622
region_name = Session().boto_region_name
623+
562624
training_recipes_cfg_filename = os.path.join(
563625
os.path.dirname(__file__), "training_recipes.json"
564626
)
@@ -567,12 +629,16 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
567629

568630
if recipe_overrides is None:
569631
recipe_overrides = dict()
570-
cls.recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
571-
cls.recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")
632+
recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
633+
recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")
634+
args = dict()
635+
if source_dir is None:
636+
args["source_dir"] = "."
637+
else:
638+
args["source_dir"] = source_dir
572639

573-
temp_local_recipe = tempfile.NamedTemporaryFile(
574-
prefix="recipe_original", suffix=".yaml"
575-
).name
640+
recipe_name = os.path.splitext(os.path.basename(training_recipe))[0]
641+
temp_local_recipe = tempfile.NamedTemporaryFile(prefix=recipe_name, suffix=".yaml").name
576642
if training_recipe.endswith(".yaml"):
577643
if os.path.isfile(training_recipe):
578644
shutil.copy(training_recipe, temp_local_recipe)
@@ -587,9 +653,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
587653
launcher_repo = os.environ.get(
588654
"training_launcher_git", None
589655
) or training_recipes_cfg.get("launcher_repo")
590-
_run_clone_command(launcher_repo, cls.recipe_launcher_dir.name)
656+
_run_clone_command(launcher_repo, recipe_launcher_dir.name)
591657
recipe = os.path.join(
592-
cls.recipe_launcher_dir.name,
658+
recipe_launcher_dir.name,
593659
"recipes_collection",
594660
"recipes",
595661
training_recipe + ".yaml",
@@ -628,44 +694,19 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
628694
)
629695
kwargs["instance_count"] = recipe["trainer"]["num_nodes"]
630696

631-
args = dict()
632697
# [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
633698
# to retrieve the image uri below before we go GA.
634699
if device_type == "gpu":
635700
adapter_repo = os.environ.get("training_adapter_git", None) or training_recipes_cfg.get(
636701
"adapter_repo"
637702
)
638-
_run_clone_command(adapter_repo, cls.recipe_train_dir.name)
639-
640-
model_type_to_entry = {
641-
"llama_v3": ("llama", "llama_pretrain.py"),
642-
"mistral": ("mistral", "mistral_pretrain.py"),
643-
"mixtral": ("mixtral", "mixtral_pretrain.py"),
644-
}
645-
646-
if "model" not in recipe:
647-
raise ValueError("Supplied recipe does not contain required field model.")
648-
if "model_type" not in recipe["model"]:
649-
raise ValueError("Supplied recipe does not contain required field model_type.")
650-
model_type = recipe["model"]["model_type"]
651-
if model_type not in model_type_to_entry:
652-
raise ValueError(f"Model type {model_type} not supported")
653-
654-
args["source_dir"] = os.path.join(
655-
cls.recipe_train_dir.name, "examples", model_type_to_entry[model_type][0]
703+
_run_clone_command(adapter_repo, recipe_train_dir.name)
704+
script = _get_training_recipe_gpu_script(
705+
recipe_train_dir.name, recipe, args["source_dir"]
706+
)
707+
args["default_image_uri"] = _get_training_recipe_image_uri(
708+
training_recipes_cfg.get("gpu_image"), region_name
656709
)
657-
args["entry_point"] = model_type_to_entry[model_type][1]
658-
gpu_image_cfg = training_recipes_cfg.get("gpu_image")
659-
if isinstance(gpu_image_cfg, str):
660-
args["default_image_uri"] = gpu_image_cfg
661-
else:
662-
args["default_image_uri"] = retrieve(
663-
gpu_image_cfg.get("framework"),
664-
region=region_name,
665-
version=gpu_image_cfg.get("version"),
666-
image_scope="training",
667-
**gpu_image_cfg.get("additional_args"),
668-
)
669710
smp_options = {
670711
"enabled": True,
671712
"parameters": {
@@ -677,55 +718,45 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
677718
"torch_distributed": {"enabled": True},
678719
}
679720
elif device_type == "trainium":
680-
_run_clone_command(
681-
training_recipes_cfg.get("neuron_dist_repo"), cls.recipe_train_dir.name
721+
_run_clone_command(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name)
722+
script = _get_training_recipe_trainium_script(recipe_train_dir.name, args["source_dir"])
723+
args["default_image_uri"] = _get_training_recipe_image_uri(
724+
training_recipes_cfg.get("neuron_image"), region_name
682725
)
683-
args["source_dir"] = os.path.join(cls.recipe_train_dir.name, "examples")
684-
args["entry_point"] = "training_orchestrator.py"
685-
neuron_image_cfg = training_recipes_cfg.get("neuron_image")
686-
if isinstance(neuron_image_cfg, str):
687-
args["default_image_uri"] = neuron_image_cfg
688-
else:
689-
args["default_image_uri"] = retrieve(
690-
neuron_image_cfg.get("framework"),
691-
region=region_name,
692-
version=neuron_image_cfg.get("version"),
693-
image_scope="training",
694-
**neuron_image_cfg.get("additional_args"),
695-
)
696726
args["distribution"] = {
697727
"torch_distributed": {"enabled": True},
698728
}
699729
else:
700730
raise ValueError(
701731
f"Devices of type {device_type} are not supported with training recipes."
702732
)
733+
args["entry_point"] = os.path.basename(script)
734+
735+
recipe_train_dir.cleanup()
736+
recipe_launcher_dir.cleanup()
703737

704738
if "container" in recipe and not recipe["container"]:
705739
logger.warning(
706740
"Ignoring container from training_recipe. Use image_uri arg for estimator."
707741
)
708742

709-
if not OmegaConf.has_resolver("multiply"):
710-
OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
711-
if not OmegaConf.has_resolver("divide_ceil"):
712-
OmegaConf.register_new_resolver(
713-
"divide_ceil", lambda x, y: int(math.ceil(x / y)), replace=True
714-
)
715-
if not OmegaConf.has_resolver("divide_floor"):
716-
OmegaConf.register_new_resolver(
717-
"divide_floor", lambda x, y: int(math.floor(x / y)), replace=True
718-
)
719-
if not OmegaConf.has_resolver("add"):
720-
OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers))
743+
_setup_omegaconf_resolvers()
721744
final_recipe = _try_resolve_recipe(recipe)
722745
if final_recipe is None:
723746
final_recipe = _try_resolve_recipe(recipe, "recipes")
724747
if final_recipe is None:
725748
final_recipe = _try_resolve_recipe(recipe, "training")
726749
if final_recipe is None:
727750
raise RuntimeError("Could not resolve provided recipe.")
728-
OmegaConf.save(config=final_recipe, f=os.path.join(args["source_dir"], "recipe.yaml"))
729-
args["hyperparameters"] = {"config-path": ".", "config-name": "recipe.yaml"}
751+
cls.training_recipe_file = tempfile.NamedTemporaryFile(
752+
dir=args["source_dir"],
753+
prefix=recipe_name + "_",
754+
suffix=".yaml",
755+
)
756+
OmegaConf.save(config=final_recipe, f=cls.training_recipe_file.name)
757+
args["hyperparameters"] = {
758+
"config-path": ".",
759+
"config-name": os.path.basename(cls.training_recipe_file.name),
760+
}
730761

731762
return args

tests/unit/test_pytorch.py

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919
from mock import ANY, MagicMock, Mock, patch
2020
from packaging.version import Version
21+
import tempfile
2122

2223
from sagemaker import image_uris
2324
from sagemaker.pytorch import defaults
@@ -906,7 +907,7 @@ def test_training_recipe_for_gpu(sagemaker_session, recipe, model):
906907
recipe_overrides=recipe_overrides,
907908
)
908909

909-
assert pytorch.source_dir == os.path.join(pytorch.recipe_train_dir.name, "examples", model)
910+
assert pytorch.source_dir == "."
910911
assert pytorch.entry_point == f"{model}_pretrain.py"
911912
expected_distribution = {
912913
"torch_distributed": {
@@ -956,7 +957,46 @@ def test_training_recipe_with_override(sagemaker_session):
956957
recipe_overrides=recipe_overrides,
957958
)
958959

959-
assert pytorch.source_dir == os.path.join(pytorch.recipe_train_dir.name, "examples", "mistral")
960+
assert pytorch.source_dir == "."
961+
assert pytorch.entry_point == "mistral_pretrain.py"
962+
assert pytorch.image_uri == IMAGE_URI
963+
964+
965+
def test_training_recipe_gpu_custom_source_dir(sagemaker_session):
    """A user-supplied source_dir should be used verbatim for gpu recipes."""
    log_level = '"logging.INFO"'

    overrides = {
        "run": {"results_dir": "/opt/ml/model"},
        "exp_manager": {
            "explicit_log_dir": "/opt/ml/output",
            "checkpoint_dir": "/opt/ml/checkpoints",
        },
        "model": {
            "data": {
                "train_dir": "/opt/ml/input/data/train",
                "val_dir": "/opt/ml/input/data/val",
            },
            "model_type": "mistral",
        },
    }
    custom_source = tempfile.TemporaryDirectory(prefix="source_")
    estimator = PyTorch(
        output_path="s3://output_path",
        role=ROLE,
        image_uri=IMAGE_URI,
        source_dir=custom_source.name,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE_GPU,
        base_job_name="job",
        container_log_level=log_level,
        training_recipe="training/llama/hf_llama3_8b_seq8192_gpu",
        recipe_overrides=overrides,
    )

    # source_dir is honored as-is; entry point follows the overridden model_type.
    assert estimator.source_dir == custom_source.name
    assert estimator.entry_point == "mistral_pretrain.py"
    assert estimator.image_uri == IMAGE_URI
9621002

@@ -991,7 +1031,49 @@ def test_training_recipe_for_trainium(sagemaker_session):
9911031
recipe_overrides=recipe_overrides,
9921032
)
9931033

994-
assert pytorch.source_dir == os.path.join(pytorch.recipe_train_dir.name, "examples")
1034+
assert pytorch.source_dir == "."
1035+
assert pytorch.entry_point == "training_orchestrator.py"
1036+
expected_distribution = {
1037+
"torch_distributed": {
1038+
"enabled": True,
1039+
},
1040+
}
1041+
assert pytorch.distribution == expected_distribution
1042+
1043+
1044+
def test_training_recipe_for_trainium_custom_source_dir(sagemaker_session):
1045+
container_log_level = '"logging.INFO"'
1046+
1047+
recipe_overrides = {
1048+
"run": {
1049+
"results_dir": "/opt/ml/model",
1050+
},
1051+
"exp_manager": {
1052+
"explicit_log_dir": "/opt/ml/output",
1053+
},
1054+
"data": {
1055+
"train_dir": "/opt/ml/input/data/train",
1056+
},
1057+
"model": {
1058+
"model_config": "/opt/ml/input/data/train/config.json",
1059+
},
1060+
"compiler_cache_url": "s3://s3://output_path/neuron-cache",
1061+
}
1062+
source_dir = tempfile.TemporaryDirectory(prefix="source_")
1063+
pytorch = PyTorch(
1064+
output_path="s3://output_path",
1065+
role=ROLE,
1066+
source_dir=source_dir.name,
1067+
sagemaker_session=sagemaker_session,
1068+
instance_count=INSTANCE_COUNT,
1069+
instance_type=INSTANCE_TYPE_TRAINIUM,
1070+
base_job_name="job",
1071+
container_log_level=container_log_level,
1072+
training_recipe=NEURON_RECIPE,
1073+
recipe_overrides=recipe_overrides,
1074+
)
1075+
1076+
assert pytorch.source_dir == source_dir.name
9951077
assert pytorch.entry_point == "training_orchestrator.py"
9961078
expected_distribution = {
9971079
"torch_distributed": {

0 commit comments

Comments
 (0)