47
47
logger = logging .getLogger ("sagemaker" )
48
48
49
49
50
def _setup_omegaconf_resolvers():
    """Register the OmegaConf resolvers used by training recipes.

    Recipes may use ``${multiply:...}``, ``${divide_ceil:...}``,
    ``${divide_floor:...}`` and ``${add:...}`` interpolations.  Each resolver
    is registered at most once (guarded by ``has_resolver``) so repeated calls
    are safe.
    """
    if not OmegaConf.has_resolver("multiply"):
        OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
    if not OmegaConf.has_resolver("divide_ceil"):
        OmegaConf.register_new_resolver(
            "divide_ceil", lambda x, y: int(math.ceil(x / y)), replace=True
        )
    if not OmegaConf.has_resolver("divide_floor"):
        OmegaConf.register_new_resolver(
            "divide_floor", lambda x, y: int(math.floor(x / y)), replace=True
        )
    if not OmegaConf.has_resolver("add"):
        # replace=True for consistency with the sibling registrations above;
        # the has_resolver guard already prevents duplicate registration.
        OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers), replace=True)
50
66
def _try_resolve_recipe (recipe , key = None ):
51
67
"""Try to resolve recipe and return resolved recipe."""
52
68
if key is not None :
@@ -60,6 +76,49 @@ def _try_resolve_recipe(recipe, key=None):
60
76
return recipe [key ]
61
77
62
78
79
def _get_training_recipe_image_uri(image_cfg, region_name):
    """Return the training image URI for the given image spec and region.

    Args:
        image_cfg (str or dict): Either a fully qualified image URI (returned
            as-is), or a dict with ``framework``, ``version`` and optional
            ``additional_args`` keys used to look the URI up via
            ``image_uris.retrieve``.
        region_name (str): AWS region to retrieve the image for.

    Returns:
        str: The image URI to use for training.
    """
    if isinstance(image_cfg, str):
        # A plain string is already a complete image URI.
        return image_cfg
    return retrieve(
        image_cfg.get("framework"),
        region=region_name,
        version=image_cfg.get("version"),
        image_scope="training",
        # Tolerate a missing/None "additional_args" key instead of raising
        # TypeError when unpacking.
        **(image_cfg.get("additional_args") or {}),
    )
91
+
92
def _get_training_recipe_gpu_script(code_dir, recipe, source_dir):
    """Return path to training script (entry point) when running a gpu recipe.

    Copies the model-specific pretrain script from the cloned adapter repo
    (``code_dir``) into ``source_dir`` and returns the script's file name.

    Raises:
        ValueError: If the recipe lacks ``model``/``model.model_type`` or the
            model type is not supported.
    """
    script_lookup = {
        "llama_v3": ("llama", "llama_pretrain.py"),
        "mistral": ("mistral", "mistral_pretrain.py"),
        "mixtral": ("mixtral", "mixtral_pretrain.py"),
    }

    # Validate the recipe structure before touching the filesystem.
    if "model" not in recipe:
        raise ValueError("Supplied recipe does not contain required field model.")
    if "model_type" not in recipe["model"]:
        raise ValueError("Supplied recipe does not contain required field model_type.")
    model_type = recipe["model"]["model_type"]
    if model_type not in script_lookup:
        raise ValueError(f"Model type {model_type} not supported")

    example_subdir, script = script_lookup[model_type]
    script_dir = os.path.join(code_dir, "examples", example_subdir)
    # Copy the entry point next to the user's sources so it gets packaged
    # with the training job.
    shutil.copyfile(os.path.join(script_dir, script), os.path.join(source_dir, script))
    return script
113
+
114
def _get_training_recipe_trainium_script(code_dir, source_dir):
    """Return path to training script (entry point) when running a trainium recipe.

    Copies the whole ``examples`` directory from the cloned neuron repo
    (``code_dir``) into ``source_dir`` and returns the orchestrator script name.
    """
    entry_script = "training_orchestrator.py"
    examples_dir = os.path.join(code_dir, "examples")
    # The trainium examples are copied wholesale; existing files in
    # source_dir are preserved/overwritten in place.
    shutil.copytree(examples_dir, source_dir, dirs_exist_ok=True)
    return entry_script
121
+
63
122
class PyTorch (Framework ):
64
123
"""Handle end-to-end training and deployment of custom PyTorch code."""
65
124
@@ -294,13 +353,13 @@ def __init__(
294
353
if training_recipe is not None :
295
354
if entry_point is not None :
296
355
logger .warning ("Argument entry_point will be ignored with training_recipe." )
297
- if source_dir is not None :
298
- logger .warning ("Argument source_dir will be ignored with training_recipe." )
299
356
if hyperparameters is not None :
300
357
logger .warning ("Argument hyperparameters will be ignored with training recipe." )
301
358
if distribution is not None :
302
359
logger .warning ("Argument distribution will be ignored with training_recipe." )
303
- args = self ._setup_for_training_recipe (training_recipe , recipe_overrides , kwargs )
360
+ args = self ._setup_for_training_recipe (
361
+ training_recipe , recipe_overrides , source_dir , kwargs
362
+ )
304
363
entry_point = args ["entry_point" ]
305
364
source_dir = args ["source_dir" ]
306
365
hyperparameters = args ["hyperparameters" ]
@@ -538,7 +597,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
538
597
return init_params
539
598
540
599
@classmethod
541
- def _setup_for_training_recipe (cls , training_recipe , recipe_overrides , kwargs ):
600
+ def _setup_for_training_recipe (cls , training_recipe , recipe_overrides , source_dir , kwargs ):
542
601
"""Performs training recipe specific setup and returns recipe specific args.
543
602
544
603
Updates kwargs and returns a dictionary of args to use for estimator
@@ -549,7 +608,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
549
608
training_recipe (str): A recipe which is a local file path, a url or a
550
609
sagemaker training recipe.
551
610
recipe_overrides (Dict): Dictionary specifying key values to override in the
552
- training_recipe.
611
+ source_dir (str): Path (absolute, or relative) to a directory where to copy
612
+ the scripts for training recipe. requirements.txt can also
613
+ go here.
553
614
kwargs (dict): Dictionary of args used for estimator initializaiton.
554
615
Returns:
555
616
dict containing arg values for estimator initialization and setup.
@@ -559,6 +620,7 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
559
620
region_name = kwargs .get ("sagemaker_session" ).boto_region_name
560
621
else :
561
622
region_name = Session ().boto_region_name
623
+
562
624
training_recipes_cfg_filename = os .path .join (
563
625
os .path .dirname (__file__ ), "training_recipes.json"
564
626
)
@@ -567,12 +629,16 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
567
629
568
630
if recipe_overrides is None :
569
631
recipe_overrides = dict ()
570
- cls .recipe_train_dir = tempfile .TemporaryDirectory (prefix = "training_" )
571
- cls .recipe_launcher_dir = tempfile .TemporaryDirectory (prefix = "launcher_" )
632
+ recipe_train_dir = tempfile .TemporaryDirectory (prefix = "training_" )
633
+ recipe_launcher_dir = tempfile .TemporaryDirectory (prefix = "launcher_" )
634
+ args = dict ()
635
+ if source_dir is None :
636
+ args ["source_dir" ] = "."
637
+ else :
638
+ args ["source_dir" ] = source_dir
572
639
573
- temp_local_recipe = tempfile .NamedTemporaryFile (
574
- prefix = "recipe_original" , suffix = ".yaml"
575
- ).name
640
+ recipe_name = os .path .splitext (os .path .basename (training_recipe ))[0 ]
641
+ temp_local_recipe = tempfile .NamedTemporaryFile (prefix = recipe_name , suffix = ".yaml" ).name
576
642
if training_recipe .endswith (".yaml" ):
577
643
if os .path .isfile (training_recipe ):
578
644
shutil .copy (training_recipe , temp_local_recipe )
@@ -587,9 +653,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
587
653
launcher_repo = os .environ .get (
588
654
"training_launcher_git" , None
589
655
) or training_recipes_cfg .get ("launcher_repo" )
590
- _run_clone_command (launcher_repo , cls . recipe_launcher_dir .name )
656
+ _run_clone_command (launcher_repo , recipe_launcher_dir .name )
591
657
recipe = os .path .join (
592
- cls . recipe_launcher_dir .name ,
658
+ recipe_launcher_dir .name ,
593
659
"recipes_collection" ,
594
660
"recipes" ,
595
661
training_recipe + ".yaml" ,
@@ -628,44 +694,19 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
628
694
)
629
695
kwargs ["instance_count" ] = recipe ["trainer" ]["num_nodes" ]
630
696
631
- args = dict ()
632
697
# [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
633
698
# to retrieve the image uri below before we go GA.
634
699
if device_type == "gpu" :
635
700
adapter_repo = os .environ .get ("training_adapter_git" , None ) or training_recipes_cfg .get (
636
701
"adapter_repo"
637
702
)
638
- _run_clone_command (adapter_repo , cls .recipe_train_dir .name )
639
-
640
- model_type_to_entry = {
641
- "llama_v3" : ("llama" , "llama_pretrain.py" ),
642
- "mistral" : ("mistral" , "mistral_pretrain.py" ),
643
- "mixtral" : ("mixtral" , "mixtral_pretrain.py" ),
644
- }
645
-
646
- if "model" not in recipe :
647
- raise ValueError ("Supplied recipe does not contain required field model." )
648
- if "model_type" not in recipe ["model" ]:
649
- raise ValueError ("Supplied recipe does not contain required field model_type." )
650
- model_type = recipe ["model" ]["model_type" ]
651
- if model_type not in model_type_to_entry :
652
- raise ValueError (f"Model type { model_type } not supported" )
653
-
654
- args ["source_dir" ] = os .path .join (
655
- cls .recipe_train_dir .name , "examples" , model_type_to_entry [model_type ][0 ]
703
+ _run_clone_command (adapter_repo , recipe_train_dir .name )
704
+ script = _get_training_recipe_gpu_script (
705
+ recipe_train_dir .name , recipe , args ["source_dir" ]
706
+ )
707
+ args ["default_image_uri" ] = _get_training_recipe_image_uri (
708
+ training_recipes_cfg .get ("gpu_image" ), region_name
656
709
)
657
- args ["entry_point" ] = model_type_to_entry [model_type ][1 ]
658
- gpu_image_cfg = training_recipes_cfg .get ("gpu_image" )
659
- if isinstance (gpu_image_cfg , str ):
660
- args ["default_image_uri" ] = gpu_image_cfg
661
- else :
662
- args ["default_image_uri" ] = retrieve (
663
- gpu_image_cfg .get ("framework" ),
664
- region = region_name ,
665
- version = gpu_image_cfg .get ("version" ),
666
- image_scope = "training" ,
667
- ** gpu_image_cfg .get ("additional_args" ),
668
- )
669
710
smp_options = {
670
711
"enabled" : True ,
671
712
"parameters" : {
@@ -677,55 +718,45 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
677
718
"torch_distributed" : {"enabled" : True },
678
719
}
679
720
elif device_type == "trainium" :
680
- _run_clone_command (
681
- training_recipes_cfg .get ("neuron_dist_repo" ), cls .recipe_train_dir .name
721
+ _run_clone_command (training_recipes_cfg .get ("neuron_dist_repo" ), recipe_train_dir .name )
722
+ script = _get_training_recipe_trainium_script (recipe_train_dir .name , args ["source_dir" ])
723
+ args ["default_image_uri" ] = _get_training_recipe_image_uri (
724
+ training_recipes_cfg .get ("neuron_image" ), region_name
682
725
)
683
- args ["source_dir" ] = os .path .join (cls .recipe_train_dir .name , "examples" )
684
- args ["entry_point" ] = "training_orchestrator.py"
685
- neuron_image_cfg = training_recipes_cfg .get ("neuron_image" )
686
- if isinstance (neuron_image_cfg , str ):
687
- args ["default_image_uri" ] = neuron_image_cfg
688
- else :
689
- args ["default_image_uri" ] = retrieve (
690
- neuron_image_cfg .get ("framework" ),
691
- region = region_name ,
692
- version = neuron_image_cfg .get ("version" ),
693
- image_scope = "training" ,
694
- ** neuron_image_cfg .get ("additional_args" ),
695
- )
696
726
args ["distribution" ] = {
697
727
"torch_distributed" : {"enabled" : True },
698
728
}
699
729
else :
700
730
raise ValueError (
701
731
f"Devices of type { device_type } are not supported with training recipes."
702
732
)
733
+ args ["entry_point" ] = os .path .basename (script )
734
+
735
+ recipe_train_dir .cleanup ()
736
+ recipe_launcher_dir .cleanup ()
703
737
704
738
if "container" in recipe and not recipe ["container" ]:
705
739
logger .warning (
706
740
"Ignoring container from training_recipe. Use image_uri arg for estimator."
707
741
)
708
742
709
- if not OmegaConf .has_resolver ("multiply" ):
710
- OmegaConf .register_new_resolver ("multiply" , lambda x , y : x * y , replace = True )
711
- if not OmegaConf .has_resolver ("divide_ceil" ):
712
- OmegaConf .register_new_resolver (
713
- "divide_ceil" , lambda x , y : int (math .ceil (x / y )), replace = True
714
- )
715
- if not OmegaConf .has_resolver ("divide_floor" ):
716
- OmegaConf .register_new_resolver (
717
- "divide_floor" , lambda x , y : int (math .floor (x / y )), replace = True
718
- )
719
- if not OmegaConf .has_resolver ("add" ):
720
- OmegaConf .register_new_resolver ("add" , lambda * numbers : sum (numbers ))
743
+ _setup_omegaconf_resolvers ()
721
744
final_recipe = _try_resolve_recipe (recipe )
722
745
if final_recipe is None :
723
746
final_recipe = _try_resolve_recipe (recipe , "recipes" )
724
747
if final_recipe is None :
725
748
final_recipe = _try_resolve_recipe (recipe , "training" )
726
749
if final_recipe is None :
727
750
raise RuntimeError ("Could not resolve provided recipe." )
728
- OmegaConf .save (config = final_recipe , f = os .path .join (args ["source_dir" ], "recipe.yaml" ))
729
- args ["hyperparameters" ] = {"config-path" : "." , "config-name" : "recipe.yaml" }
751
+ cls .training_recipe_file = tempfile .NamedTemporaryFile (
752
+ dir = args ["source_dir" ],
753
+ prefix = recipe_name + "_" ,
754
+ suffix = ".yaml" ,
755
+ )
756
+ OmegaConf .save (config = final_recipe , f = cls .training_recipe_file .name )
757
+ args ["hyperparameters" ] = {
758
+ "config-path" : "." ,
759
+ "config-name" : os .path .basename (cls .training_recipe_file .name ),
760
+ }
730
761
731
762
return args
0 commit comments