Skip to content

Commit 30dfdca

Browse files
schinmayeepintaoz-aws
authored and committed
Remove default values for fields in recipe_overrides and fix recipe path. (#1566)
1 parent 9480ee0 commit 30dfdca

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

src/sagemaker/pytorch/estimator.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,6 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
592592
cls.recipe_launcher_dir.name,
593593
"recipes_collection",
594594
"recipes",
595-
"training",
596595
training_recipe + ".yaml",
597596
)
598597
if os.path.isfile(recipe):
@@ -602,8 +601,6 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
602601

603602
recipe = OmegaConf.load(temp_local_recipe)
604603
os.unlink(temp_local_recipe)
605-
recipe_overrides.setdefault("run", dict())["results_dir"] = "/opt/ml/model"
606-
recipe_overrides.setdefault("exp_manager", dict())["exp_dir"] = "/opt/ml/model/"
607604
recipe = OmegaConf.merge(recipe, recipe_overrides)
608605

609606
if "instance_type" not in kwargs:

tests/unit/test_pytorch.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,9 @@ def test_training_recipe_for_cpu(sagemaker_session):
839839
container_log_level = '"logging.INFO"'
840840

841841
recipe_overrides = {
842+
"run": {
843+
"results_dir": "/opt/ml/model",
844+
},
842845
"exp_manager": {
843846
"explicit_log_dir": "/opt/ml/output/tensorboard",
844847
"checkpoint_dir": "/opt/ml/checkpoints",
@@ -860,7 +863,7 @@ def test_training_recipe_for_cpu(sagemaker_session):
860863
instance_type=INSTANCE_TYPE,
861864
base_job_name="job",
862865
container_log_level=container_log_level,
863-
training_recipe="llama/hf_llama3_8b_seq8192_gpu",
866+
training_recipe="training/llama/hf_llama3_8b_seq8192_gpu",
864867
recipe_overrides=recipe_overrides,
865868
)
866869

@@ -877,6 +880,9 @@ def test_training_recipe_for_gpu(sagemaker_session, recipe, model):
877880
container_log_level = '"logging.INFO"'
878881

879882
recipe_overrides = {
883+
"run": {
884+
"results_dir": "/opt/ml/model",
885+
},
880886
"exp_manager": {
881887
"explicit_log_dir": "/opt/ml/output",
882888
"checkpoint_dir": "/opt/ml/checkpoints",
@@ -896,7 +902,7 @@ def test_training_recipe_for_gpu(sagemaker_session, recipe, model):
896902
instance_type=INSTANCE_TYPE_GPU,
897903
base_job_name="job",
898904
container_log_level=container_log_level,
899-
training_recipe=f"{model}/{recipe}",
905+
training_recipe=f"training/{model}/{recipe}",
900906
recipe_overrides=recipe_overrides,
901907
)
902908

@@ -922,6 +928,9 @@ def test_training_recipe_with_override(sagemaker_session):
922928
container_log_level = '"logging.INFO"'
923929

924930
recipe_overrides = {
931+
"run": {
932+
"results_dir": "/opt/ml/model",
933+
},
925934
"exp_manager": {
926935
"explicit_log_dir": "/opt/ml/output",
927936
"checkpoint_dir": "/opt/ml/checkpoints",
@@ -943,7 +952,7 @@ def test_training_recipe_with_override(sagemaker_session):
943952
instance_type=INSTANCE_TYPE_GPU,
944953
base_job_name="job",
945954
container_log_level=container_log_level,
946-
training_recipe="llama/hf_llama3_8b_seq8192_gpu",
955+
training_recipe="training/llama/hf_llama3_8b_seq8192_gpu",
947956
recipe_overrides=recipe_overrides,
948957
)
949958

@@ -956,6 +965,9 @@ def test_training_recipe_for_trainium(sagemaker_session):
956965
container_log_level = '"logging.INFO"'
957966

958967
recipe_overrides = {
968+
"run": {
969+
"results_dir": "/opt/ml/model",
970+
},
959971
"exp_manager": {
960972
"explicit_log_dir": "/opt/ml/output",
961973
},

0 commit comments

Comments
 (0)