From 7babe4e25100212599ec23126f8576aec3485e61 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Tue, 8 Apr 2025 11:47:13 +0000 Subject: [PATCH 01/26] chore(plugin): Add torchtune-related constants & update current torch plugin. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/constants/constants.go | 15 ++++++ pkg/runtime/framework/plugins/torch/torch.go | 50 +++++++++++--------- sdk/kubeflow/trainer/types/types.py | 1 + 3 files changed, 44 insertions(+), 22 deletions(-) diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index e149e1fe70..623960f40b 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -134,6 +134,18 @@ const ( // TorchEnvMasterPort is the env name for the master node port. TorchEnvMasterPort string = "PET_MASTER_PORT" + + // TochTuneArgNumNodes is the arg anme for the number of training nodes. + TorchTuneArgNumNodes string = "--nnodes" + + // TorchTuneArgNumProcPerNode is the arg name for the number of procs per node (e.g. number of GPUs per Pod). + TorchTuneArgNumProcPerNode string = "--nproc_per_node" + + // TorchTuneArgRdzvId is the arg name for the rendezvous ID. + TorchTuneArgRdzvId string = "--rdzv_id" + + // TorchTuneArgRdzvEndpoint is the arg name for the rendezvous endpoint. + TorchTuneArgRdzvEndpoint string = "--rdzv_endpoint" ) var ( @@ -142,4 +154,7 @@ var ( // Torchrun reserved env names TorchRunReservedEnvNames = sets.New(TorchEnvNumNodes, TorchEnvNumProcPerNode, TorchEnvNodeRank, TorchEnvMasterAddr, TorchEnvMasterPort) + + // TorchTuneEntrypoint is the entrypoint for the torchtune. 
+ TorchTuneEntrypoint = []string{"tune", "run"} ) diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index 6c4783bddc..bd25c9db3f 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -19,6 +19,7 @@ package torch import ( "context" "fmt" + "slices" "strings" corev1 "k8s.io/api/core/v1" @@ -137,9 +138,6 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) } // Update envs for Info object. - // Add PyTorch distributed "PET_" values for torchrun - // TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other. - // Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940 var trainerContainer *runtime.Container if trainJob.Spec.Trainer != nil { if trainerContainer = info.FindContainerByPodSetAncestorContainerName(constants.AncestorTrainer, constants.Node); trainerContainer != nil { @@ -147,25 +145,33 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) } } if trainerContainer != nil { - apply.UpsertEnvVar(&trainerContainer.Env, - *corev1ac.EnvVar(). - WithName(constants.TorchEnvNumNodes). - WithValue(fmt.Sprintf("%d", ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1))), - *corev1ac.EnvVar(). - WithName(constants.TorchEnvNumProcPerNode). - WithValue(numProcPerNode.String()), - *corev1ac.EnvVar(). - WithName(constants.TorchEnvNodeRank). - WithValueFrom(corev1ac.EnvVarSource(). - WithFieldRef(corev1ac.ObjectFieldSelector(). - WithFieldPath(constants.JobCompletionIndexFieldPath))), - *corev1ac.EnvVar(). - WithName(constants.TorchEnvMasterAddr). - WithValue(fmt.Sprintf("%s-%s-0-0.%s", trainJob.Name, constants.Node, trainJob.Name)), - *corev1ac.EnvVar(). - WithName(constants.TorchEnvMasterPort). - WithValue(fmt.Sprintf("%d", constants.ContainerTrainerPort)), - ) + // Add PyTorch distributed "PET_" values for torchrun. 
+ // TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other. + // Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940 + if !slices.Equal(trainJob.Spec.Trainer.Command, constants.TorchTuneEntrypoint) { + apply.UpsertEnvVar(&trainerContainer.Env, + *corev1ac.EnvVar(). + WithName(constants.TorchEnvNumNodes). + WithValue(fmt.Sprintf("%d", ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1))), + *corev1ac.EnvVar(). + WithName(constants.TorchEnvNumProcPerNode). + WithValue(numProcPerNode.String()), + *corev1ac.EnvVar(). + WithName(constants.TorchEnvNodeRank). + WithValueFrom(corev1ac.EnvVarSource(). + WithFieldRef(corev1ac.ObjectFieldSelector(). + WithFieldPath(constants.JobCompletionIndexFieldPath))), + *corev1ac.EnvVar(). + WithName(constants.TorchEnvMasterAddr). + WithValue(fmt.Sprintf("%s-%s-0-0.%s", trainJob.Name, constants.Node, trainJob.Name)), + *corev1ac.EnvVar(). + WithName(constants.TorchEnvMasterPort). + WithValue(fmt.Sprintf("%d", constants.ContainerTrainerPort)), + ) + } else { + // Add PyTorch distributed command line args for torchtune. + + } // Add container port for the headless service. apply.UpsertPort(&trainerContainer.Ports, *corev1ac.ContainerPort().WithContainerPort(constants.ContainerTrainerPort)) } diff --git a/sdk/kubeflow/trainer/types/types.py b/sdk/kubeflow/trainer/types/types.py index 130dba4691..f36b0fd05b 100644 --- a/sdk/kubeflow/trainer/types/types.py +++ b/sdk/kubeflow/trainer/types/types.py @@ -220,6 +220,7 @@ class Initializer: "ghcr.io/kubeflow/trainer/torchtune-trainer": Trainer( trainer_type=TrainerType.BUILTIN_TRAINER, framework=Framework.TORCHTUNE, + entrypoint=constants.DEFAULT_TORCHTUNE_COMMAND, ), } From e6a9c540876a7299463d348f7cab7779fc2e157b Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Tue, 8 Apr 2025 12:15:27 +0000 Subject: [PATCH 02/26] chore(plugin): Add EnforceMLPolicy for torchtune. 
Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/apply/apply.go | 10 +++++ pkg/runtime/framework/plugins/torch/torch.go | 47 +++++++++++++++----- pkg/runtime/runtime.go | 1 + 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/pkg/apply/apply.go b/pkg/apply/apply.go index 4f4655fcb1..3d16ff1729 100644 --- a/pkg/apply/apply.go +++ b/pkg/apply/apply.go @@ -34,6 +34,12 @@ var ( errorRequestedFieldPathNotFound = errors.New("requested field path not found") ) +func UpsertArgs(args *[]string, upArgs ...string) { + for _, a := range upArgs { + upsert(args, a, byArgName) + } +} + func UpsertEnvVar(envVars *[]corev1ac.EnvVarApplyConfiguration, envVar ...corev1ac.EnvVarApplyConfiguration) { for _, e := range envVar { upsert(envVars, e, byEnvVarName) @@ -64,6 +70,10 @@ func UpsertVolumeMounts(mounts *[]corev1ac.VolumeMountApplyConfiguration, upMoun } } +func byArgName(a, b string) bool { + return a == b +} + func byEnvVarName(a, b corev1ac.EnvVarApplyConfiguration) bool { return ptr.Equal(a.Name, b.Name) } diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index bd25c9db3f..8c6055ede0 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -22,6 +22,7 @@ import ( "slices" "strings" + "github.com/google/uuid" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/apimachinery/pkg/util/sets" @@ -71,16 +72,20 @@ func (t *Torch) Validate(runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob) } } - torchEnvs := sets.New[string]() - for _, env := range newObj.Spec.Trainer.Env { - if constants.TorchRunReservedEnvNames.Has(env.Name) { - torchEnvs.Insert(env.Name) + // Check reserved envs for torchrun. + // TODO(Electronic-Waste): Add validation for torchtune args. 
+ if !slices.Equal(newObj.Spec.Trainer.Command, constants.TorchTuneEntrypoint) { + torchEnvs := sets.New[string]() + for _, env := range newObj.Spec.Trainer.Env { + if constants.TorchRunReservedEnvNames.Has(env.Name) { + torchEnvs.Insert(env.Name) + } } - } - if torchEnvs.Len() > 0 { - trainerEnvsPath := specPath.Child("trainer").Child("env") - allErrs = append(allErrs, field.Invalid(trainerEnvsPath, newObj.Spec.Trainer.Env, fmt.Sprintf("must not have reserved envs, invalid envs configured: %v", sets.List(torchEnvs)))) + if torchEnvs.Len() > 0 { + trainerEnvsPath := specPath.Child("trainer").Child("env") + allErrs = append(allErrs, field.Invalid(trainerEnvsPath, newObj.Spec.Trainer.Env, fmt.Sprintf("must not have reserved envs, invalid envs configured: %v", sets.List(torchEnvs)))) + } } } @@ -145,10 +150,10 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) } } if trainerContainer != nil { - // Add PyTorch distributed "PET_" values for torchrun. - // TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other. - // Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940 if !slices.Equal(trainJob.Spec.Trainer.Command, constants.TorchTuneEntrypoint) { + // Add PyTorch distributed "PET_" values for torchrun. + // TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other. + // Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940 apply.UpsertEnvVar(&trainerContainer.Env, *corev1ac.EnvVar(). WithName(constants.TorchEnvNumNodes). @@ -170,7 +175,25 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) ) } else { // Add PyTorch distributed command line args for torchtune. - + // TODO(Electronic-Waste): Add more args for torchtune if required. 
+ apply.UpsertArgs(&trainerContainer.Args, + fmt.Sprintf("%s %s", + constants.TorchTuneArgNumNodes, + fmt.Sprintf("%d", ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1)), + ), + fmt.Sprintf("%s %s", + constants.TorchTuneArgNumProcPerNode, + numProcPerNode.String(), + ), + fmt.Sprintf("%s %s", + constants.TorchTuneArgRdzvId, + uuid.New().String(), + ), + fmt.Sprintf("%s %s", + constants.TorchTuneArgRdzvEndpoint, + fmt.Sprintf("%s-%s-0-0.%s:%d", trainJob.Name, constants.Node, trainJob.Name, constants.ContainerTrainerPort), + ), + ) } // Add container port for the headless service. apply.UpsertPort(&trainerContainer.Ports, *corev1ac.ContainerPort().WithContainerPort(constants.ContainerTrainerPort)) diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 9ec15730f2..c3496d0f56 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -80,6 +80,7 @@ type PodSet struct { type Container struct { Name string + Args []string Env []corev1ac.EnvVarApplyConfiguration Ports []corev1ac.ContainerPortApplyConfiguration VolumeMounts []corev1ac.VolumeMountApplyConfiguration From afa49c4925958d918501dea32d36c3e5e1f765c0 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Tue, 8 Apr 2025 13:21:46 +0000 Subject: [PATCH 03/26] chore(plugin): Add UTs in torch plugin. 
Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/runtime/framework/plugins/torch/torch.go | 3 +- .../framework/plugins/torch/torch_test.go | 102 ++++++++++++++++++ 2 files changed, 103 insertions(+), 2 deletions(-) diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index 8c6055ede0..f48736ea63 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -22,7 +22,6 @@ import ( "slices" "strings" - "github.com/google/uuid" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/apimachinery/pkg/util/sets" @@ -187,7 +186,7 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) ), fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvId, - uuid.New().String(), + trainJob.Name, ), fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, diff --git a/pkg/runtime/framework/plugins/torch/torch_test.go b/pkg/runtime/framework/plugins/torch/torch_test.go index af8e6934ee..3a7195bd41 100644 --- a/pkg/runtime/framework/plugins/torch/torch_test.go +++ b/pkg/runtime/framework/plugins/torch/torch_test.go @@ -1126,6 +1126,74 @@ func TestTorch(t *testing.T) { Scheduler: &runtime.Scheduler{PodLabels: make(map[string]string)}, }, }, + "pass distributed params to Args when using torchtune": { + trainJob: utiltesting.MakeTrainJobWrapper("default", "torchtune-job"). + Trainer( + utiltesting.MakeTrainJobTrainerWrapper(). + NumNodes(4). + NumProcPerNode(intstr.FromString("auto")). + Container( + "ghcr.io/kubeflow/trainer/torchtune-trainer", + []string{"tune", "run"}, + []string{ + "dtype=fp16", + "batch_size=32", + "epochs=10", + "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", + }, + corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("8"), + corev1.ResourceMemory: resource.MustParse("16Gi"), + "nvidia.com/gpu": resource.MustParse("4"), // 4 GPUs per node + }, + ). + Obj(), + ). 
+ Obj(), + info: runtime.NewInfo( + runtime.WithMLPolicySource( + utiltesting.MakeMLPolicyWrapper(). + WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper(). + TorchPolicy(ptr.To(intstr.FromString("auto")), nil). + Obj(), + ). + Obj(), + ), + runtime.WithPodSet(constants.Node, ptr.To(constants.AncestorTrainer), 1, corev1.PodSpec{}, corev1ac.PodSpec(). + WithContainers(corev1ac.Container().WithName(constants.Node)), + ), + ), + wantInfo: &runtime.Info{ + Labels: make(map[string]string), + Annotations: make(map[string]string), + RuntimePolicy: runtime.RuntimePolicy{ + MLPolicySource: utiltesting.MakeMLPolicySourceWrapper(). + TorchPolicy(ptr.To(intstr.FromString("auto")), nil). + Obj(), + }, + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{{ + Name: constants.Node, + Ancestor: ptr.To(constants.AncestorTrainer), + Count: ptr.To[int32](1), + SinglePodRequests: make(corev1.ResourceList), + Containers: []runtime.Container{{ + Name: constants.Node, + Args: []string{ + fmt.Sprintf("%s %s", constants.TorchTuneArgNumNodes, "4"), + fmt.Sprintf("%s %s", constants.TorchTuneArgNumProcPerNode, "auto"), + fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvId, "torchtune-job"), + fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), + }, + Ports: []corev1ac.ContainerPortApplyConfiguration{{ + ContainerPort: ptr.To[int32](constants.ContainerTrainerPort), + }}, + }}, + }}, + }, + Scheduler: &runtime.Scheduler{PodLabels: make(map[string]string)}, + }, + }, } for name, tc := range cases { @@ -1327,6 +1395,40 @@ func TestValidate(t *testing.T) { ), }, }, + "no reserved environment variable for torchtune": { + info: runtime.NewInfo( + runtime.WithMLPolicySource(utiltesting.MakeMLPolicyWrapper(). + WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper(). + TorchPolicy(ptr.To(intstr.FromString("auto")), nil). + Obj(), + ). + Obj(), + ), + ), + newObj: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test"). 
+ Trainer(utiltesting.MakeTrainJobTrainerWrapper(). + NumProcPerNode(intstr.FromString("auto")). + Container( + "ghcr.io/kubeflow/trainer/torchtune-trainer", + []string{"tune", "run"}, + nil, corev1.ResourceList{}, + ). + Env( + []corev1.EnvVar{ + { + Name: "test", + Value: "value", + }, + { + Name: constants.TorchEnvNumProcPerNode, + Value: "value", + }, + }..., + ). + Obj(), + ). + Obj(), + }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { From ada86f88794e9786145c383915c37d092ae4a8f8 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Tue, 8 Apr 2025 15:59:46 +0000 Subject: [PATCH 04/26] fix(test): fix error in torch plugin UTs. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/runtime/framework/plugins/torch/torch_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/runtime/framework/plugins/torch/torch_test.go b/pkg/runtime/framework/plugins/torch/torch_test.go index 3a7195bd41..0b83845a02 100644 --- a/pkg/runtime/framework/plugins/torch/torch_test.go +++ b/pkg/runtime/framework/plugins/torch/torch_test.go @@ -1175,7 +1175,7 @@ func TestTorch(t *testing.T) { PodSets: []runtime.PodSet{{ Name: constants.Node, Ancestor: ptr.To(constants.AncestorTrainer), - Count: ptr.To[int32](1), + Count: ptr.To[int32](4), SinglePodRequests: make(corev1.ResourceList), Containers: []runtime.Container{{ Name: constants.Node, From c9e340ebde4f4d45883eae95c608c67109445ab0 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Wed, 9 Apr 2025 02:34:59 +0000 Subject: [PATCH 05/26] chore(plugin): Choose recipe according to numNodes & numProcPerNode & Args. 
Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/apply/apply.go | 10 ------ pkg/constants/constants.go | 9 +++++ pkg/runtime/framework/plugins/torch/torch.go | 37 ++++++++++++++++---- 3 files changed, 40 insertions(+), 16 deletions(-) diff --git a/pkg/apply/apply.go b/pkg/apply/apply.go index 3d16ff1729..4f4655fcb1 100644 --- a/pkg/apply/apply.go +++ b/pkg/apply/apply.go @@ -34,12 +34,6 @@ var ( errorRequestedFieldPathNotFound = errors.New("requested field path not found") ) -func UpsertArgs(args *[]string, upArgs ...string) { - for _, a := range upArgs { - upsert(args, a, byArgName) - } -} - func UpsertEnvVar(envVars *[]corev1ac.EnvVarApplyConfiguration, envVar ...corev1ac.EnvVarApplyConfiguration) { for _, e := range envVar { upsert(envVars, e, byEnvVarName) @@ -70,10 +64,6 @@ func UpsertVolumeMounts(mounts *[]corev1ac.VolumeMountApplyConfiguration, upMoun } } -func byArgName(a, b string) bool { - return a == b -} - func byEnvVarName(a, b corev1ac.EnvVarApplyConfiguration) bool { return ptr.Equal(a.Name, b.Name) } diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 623960f40b..0dc4bfaca5 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -146,6 +146,15 @@ const ( // TorchTuneArgRdzvEndpoint is the arg name for the rendezvous endpoint. TorchTuneArgRdzvEndpoint string = "--rdzv_endpoint" + + // TorchTuneFullFinetuneSingleDevice Recipe is the recipe for the single device full finetune. + TorchTuneFullFinetuneSingleDevice string = "full_finetune_single_device" + + // TorchTuneFullFinetuneDistributed Recipe is the recipe for the distributed full finetune. + TorchTuneFullFinetuneDistributed string = "full_finetune_distributed" + + // TorchTuneDefaultRecipe is the default recipe for the torchtune. 
+ TorchTuneDefaultRecipe string = TorchTuneFullFinetuneDistributed ) var ( diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index f48736ea63..d03559ca23 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -173,12 +173,17 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) WithValue(fmt.Sprintf("%d", constants.ContainerTrainerPort)), ) } else { - // Add PyTorch distributed command line args for torchtune. + // Mutate command line args for torchtune. + // Ref: https://github.com/kubeflow/trainer/tree/master/docs/proposals/2401-llm-trainer-v2#complement-torch-plugin + oldArgs, newArgs := trainerContainer.Args, []string{} + + // 1. Add PyTorch distributed command line args for torchtune. // TODO(Electronic-Waste): Add more args for torchtune if required. - apply.UpsertArgs(&trainerContainer.Args, - fmt.Sprintf("%s %s", + numNodes := ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1) + newArgs = append(newArgs, + fmt.Sprintf("%s %d", constants.TorchTuneArgNumNodes, - fmt.Sprintf("%d", ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1)), + numNodes, ), fmt.Sprintf("%s %s", constants.TorchTuneArgNumProcPerNode, @@ -188,11 +193,21 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) constants.TorchTuneArgRdzvId, trainJob.Name, ), - fmt.Sprintf("%s %s", + fmt.Sprintf("%s %s-%s-0-0.%s:%d", constants.TorchTuneArgRdzvEndpoint, - fmt.Sprintf("%s-%s-0-0.%s:%d", trainJob.Name, constants.Node, trainJob.Name, constants.ContainerTrainerPort), + trainJob.Name, constants.Node, trainJob.Name, constants.ContainerTrainerPort, ), ) + + // 2. Get the recipe and config from old args and append them to new args. + newArgs = append(newArgs, + getRecipeFromArgs(numNodes, numProcPerNode, oldArgs), + ) + + // 3. Reserve old arguments to override corresponding items in the config file. 
+ newArgs = append(newArgs, oldArgs...) + + trainerContainer.Args = newArgs } // Add container port for the headless service. apply.UpsertPort(&trainerContainer.Ports, *corev1ac.ContainerPort().WithContainerPort(constants.ContainerTrainerPort)) @@ -216,3 +231,13 @@ func calculateNumProcPerNode( } return intstr.FromInt32(defaultCPU), false } + +// getRecipeFromArgs extracts the recipe from the distributed parameters and command line arguments. +// TODO(Electronic-Waste): Add support for more recipes. +func getRecipeFromArgs(numNodes int32, numProcPerNode intstr.IntOrString, _ []string) string { + recipe := constants.TorchTuneDefaultRecipe + if numNodes == 1 && numProcPerNode.Type == intstr.Int && numProcPerNode.IntVal == 1 { + recipe = constants.TorchTuneFullFinetuneSingleDevice + } + return recipe +} From 8ecaa5bb179cf37018ad8fae73a41ab0b5bfeb10 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Wed, 9 Apr 2025 03:05:56 +0000 Subject: [PATCH 06/26] chore(sdk): Add PretrainedModel enum type. 
Signed-off-by: Electronic-Waste <2690692950@qq.com> --- sdk/kubeflow/trainer/__init__.py | 1 + sdk/kubeflow/trainer/types/types.py | 8 +++++++- sdk/kubeflow/trainer/utils/utils.py | 12 +++++++++++- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/sdk/kubeflow/trainer/__init__.py b/sdk/kubeflow/trainer/__init__.py index 9d07ea282e..a8950a9557 100644 --- a/sdk/kubeflow/trainer/__init__.py +++ b/sdk/kubeflow/trainer/__init__.py @@ -33,6 +33,7 @@ HuggingFaceModelInitializer, Initializer, Loss, + PretrainedModel, Runtime, Trainer, TrainerType, diff --git a/sdk/kubeflow/trainer/types/types.py b/sdk/kubeflow/trainer/types/types.py index f36b0fd05b..f43f3ee22e 100644 --- a/sdk/kubeflow/trainer/types/types.py +++ b/sdk/kubeflow/trainer/types/types.py @@ -115,6 +115,12 @@ class Framework(Enum): TORCHTUNE = "torchtune" +class PretrainedModel(Enum): + LLAMA3_2_1B = "llama3_2/1B" + LLAMA3_2_3B = "llama3_2/3B" + LLAMA3_3_70B = "llama3_3/70B" + + # Representation for the Trainer of the runtime. @dataclass class Trainer: @@ -130,7 +136,7 @@ class Trainer: class Runtime: name: str trainer: Trainer - pretrained_model: Optional[str] = None + pretrained_model: Optional[PretrainedModel] = None # Representation for the TrainJob steps. diff --git a/sdk/kubeflow/trainer/utils/utils.py b/sdk/kubeflow/trainer/utils/utils.py index e83e7f67a3..be142a8fc0 100644 --- a/sdk/kubeflow/trainer/utils/utils.py +++ b/sdk/kubeflow/trainer/utils/utils.py @@ -320,6 +320,7 @@ def get_entrypoint_using_train_func( def get_args_using_torchtune_config( + runtime: types.Runtime, fine_tuning_config: types.TorchTuneConfig, ) -> Tuple[List[str], List[str]]: """ @@ -346,6 +347,12 @@ def get_args_using_torchtune_config( if fine_tuning_config.loss: args.append(f"loss={fine_tuning_config.loss}") + # Provide pre-trained model information. + # TODO(Electronic-Waste): Move pre-trained model information to the runtime API fields. 
+ # Ref: https://github.com/kubeflow/trainer/pull/2410#pullrequestreview-2672356400 + if runtime.pretrained_model: + args.append(f"model={runtime.pretrained_model}") + return constants.DEFAULT_TORCHTUNE_COMMAND, args @@ -384,6 +391,7 @@ def get_trainer_crd_from_custom_trainer( def get_trainer_crd_from_builtin_trainer( trainer: types.BuiltinTrainer, + runtime: types.Runtime, ) -> models.TrainerV1alpha1Trainer: """ Get the Trainer CRD from the builtin trainer. @@ -406,7 +414,9 @@ def get_trainer_crd_from_builtin_trainer( # Parse args in the TorchTuneConfig to the Trainer, preparing for the mutation of # the torchtune config in the runtime plugin. # Ref:https://github.com/kubeflow/trainer/tree/master/docs/proposals/2401-llm-trainer-v2 - trainer_crd.command, trainer_crd.args = get_args_using_torchtune_config(trainer) + trainer_crd.command, trainer_crd.args = get_args_using_torchtune_config( + runtime, trainer + ) return trainer_crd From 06b555d6ef884257016880d2f719121652b63e23 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Wed, 9 Apr 2025 06:20:21 +0000 Subject: [PATCH 07/26] chore(plugin): Add torchtune config arg. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/constants/constants.go | 26 ++++++++++++++ pkg/runtime/framework/plugins/torch/torch.go | 37 +++++++++++++++++--- 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 0dc4bfaca5..01c0b8498f 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -150,13 +150,33 @@ const ( // TorchTuneFullFinetuneSingleDevice Recipe is the recipe for the single device full finetune. TorchTuneFullFinetuneSingleDevice string = "full_finetune_single_device" + // TorchTuneFullFinetuneSingleDeviceConfigSuffix is the config suffix for the single device full finetune. 
+ TorchTuneFullFinetuneSingleDeviceConfigSuffix string = "_full_single_device" + // TorchTuneFullFinetuneDistributed Recipe is the recipe for the distributed full finetune. TorchTuneFullFinetuneDistributed string = "full_finetune_distributed" + // TorchTuneFullFinetuneMultiDevicesConfigSuffix is the config suffix for the single node distributed full finetune. + TorchTuneFullFinetuneMultiDevicesConfigSuffix string = "_full" + + // TorchTuneFullFinetuneMultiNodesConfigSuffix is the config suffix for the multi node distributed full finetune. + TorchTuneFullFinetuneMultiNodesConfigSuffix string = "_full_multinode" + // TorchTuneDefaultRecipe is the default recipe for the torchtune. TorchTuneDefaultRecipe string = TorchTuneFullFinetuneDistributed ) +const ( + // MODEL_LLAMA3_2_1B is the model name for the Llama3.2 1B Instruct model. + MODEL_LLAMA3_2_1B = "llama3_2/1B" + + // MODEL_LLAMA3_2_7B is the model name for the Llama3.2 7B Instruct model. + MODEL_LLAMA3_2_7B = "llama3_2/7B" + + // MODEL_LLAMA3_3_70B is the model name for the Llama3.3 70B Instruct model. + MODEL_LLAMA3_3_70B = "llama3_3/70B" +) + var ( // JobCompletionIndexFieldPath is the field path for the Job completion index annotation. JobCompletionIndexFieldPath string = fmt.Sprintf("metadata.annotations['%s']", batchv1.JobCompletionIndexAnnotation) @@ -164,6 +184,12 @@ var ( // Torchrun reserved env names TorchRunReservedEnvNames = sets.New(TorchEnvNumNodes, TorchEnvNumProcPerNode, TorchEnvNodeRank, TorchEnvMasterAddr, TorchEnvMasterPort) + // Currently supported TorchTune recipes. + TorchTuneSupportedRecipes = sets.New(TorchTuneFullFinetuneSingleDevice, TorchTuneFullFinetuneDistributed) + + // Currently supported pretrained models for TorchTuen Trainer. + TorchTuneSupportedPretrainedModels = sets.New(MODEL_LLAMA3_2_1B, MODEL_LLAMA3_2_7B, MODEL_LLAMA3_3_70B) + // TorchTuneEntrypoint is the entrypoint for the torchtune. 
TorchTuneEntrypoint = []string{"tune", "run"} ) diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index d03559ca23..6e947eb8cd 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -200,12 +200,14 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) ) // 2. Get the recipe and config from old args and append them to new args. - newArgs = append(newArgs, - getRecipeFromArgs(numNodes, numProcPerNode, oldArgs), - ) + recipe := getRecipeFromArgs(numNodes, numProcPerNode, oldArgs) + config := getConfigFileFromArgs(numNodes, recipe, oldArgs) + newArgs = append(newArgs, recipe, fmt.Sprintf("--config %s", config)) // 3. Reserve old arguments to override corresponding items in the config file. - newArgs = append(newArgs, oldArgs...) + newArgs = append(newArgs, slices.DeleteFunc(oldArgs, func(arg string) bool { + return strings.HasPrefix(arg, "model") + })...) trainerContainer.Args = newArgs } @@ -241,3 +243,30 @@ func getRecipeFromArgs(numNodes int32, numProcPerNode intstr.IntOrString, _ []st } return recipe } + +// getConfigFromArgs extracts the config from distributed parameters, recipe and command line arguments. +func getConfigFileFromArgs(numNodes int32, recipe string, args []string) string { + // Extract model from command line args. + model := constants.MODEL_LLAMA3_2_1B + for _, arg := range args { + if strings.HasPrefix(arg, "model") { + model = strings.Split(arg, "=")[1] + break + } + } + + // Determine the config file name based on the recipe and number of nodes. 
+ var suffix string + switch recipe { + case constants.TorchTuneFullFinetuneDistributed: + if numNodes == 1 { + suffix = constants.TorchTuneFullFinetuneMultiDevicesConfigSuffix + } else { + suffix = constants.TorchTuneFullFinetuneMultiNodesConfigSuffix + } + case constants.TorchTuneFullFinetuneSingleDevice: + suffix = constants.TorchTuneFullFinetuneSingleDeviceConfigSuffix + } + + return fmt.Sprintf("%s%s.yaml", model, suffix) +} From ba55d4ce2279d71fbe7aa8118803572c523a101d Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Wed, 9 Apr 2025 08:23:29 +0000 Subject: [PATCH 08/26] chore(test): add UT for single-device full fine-tuning with torchtune. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/runtime/framework/plugins/torch/torch.go | 2 +- .../framework/plugins/torch/torch_test.go | 90 ++++++++++++++++++- 2 files changed, 87 insertions(+), 5 deletions(-) diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index 6e947eb8cd..d03382cf36 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -175,7 +175,7 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) } else { // Mutate command line args for torchtune. // Ref: https://github.com/kubeflow/trainer/tree/master/docs/proposals/2401-llm-trainer-v2#complement-torch-plugin - oldArgs, newArgs := trainerContainer.Args, []string{} + oldArgs, newArgs := trainJob.Spec.Trainer.Args, []string{} // 1. Add PyTorch distributed command line args for torchtune. // TODO(Electronic-Waste): Add more args for torchtune if required. 
diff --git a/pkg/runtime/framework/plugins/torch/torch_test.go b/pkg/runtime/framework/plugins/torch/torch_test.go index 0b83845a02..e4ef555494 100644 --- a/pkg/runtime/framework/plugins/torch/torch_test.go +++ b/pkg/runtime/framework/plugins/torch/torch_test.go @@ -1126,11 +1126,11 @@ func TestTorch(t *testing.T) { Scheduler: &runtime.Scheduler{PodLabels: make(map[string]string)}, }, }, - "pass distributed params to Args when using torchtune": { + "multi-devices full fine-tuning with torchtune": { trainJob: utiltesting.MakeTrainJobWrapper("default", "torchtune-job"). Trainer( utiltesting.MakeTrainJobTrainerWrapper(). - NumNodes(4). + NumNodes(1). NumProcPerNode(intstr.FromString("auto")). Container( "ghcr.io/kubeflow/trainer/torchtune-trainer", @@ -1140,6 +1140,7 @@ func TestTorch(t *testing.T) { "batch_size=32", "epochs=10", "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", + "model=llama3_2/1B", }, corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("8"), @@ -1175,15 +1176,96 @@ func TestTorch(t *testing.T) { PodSets: []runtime.PodSet{{ Name: constants.Node, Ancestor: ptr.To(constants.AncestorTrainer), - Count: ptr.To[int32](4), + Count: ptr.To[int32](1), SinglePodRequests: make(corev1.ResourceList), Containers: []runtime.Container{{ Name: constants.Node, Args: []string{ - fmt.Sprintf("%s %s", constants.TorchTuneArgNumNodes, "4"), + fmt.Sprintf("%s %s", constants.TorchTuneArgNumNodes, "1"), fmt.Sprintf("%s %s", constants.TorchTuneArgNumProcPerNode, "auto"), fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvId, "torchtune-job"), fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), + constants.TorchTuneFullFinetuneDistributed, + "--config llama3_2/1B_full.yaml", + "dtype=fp16", + "batch_size=32", + "epochs=10", + "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", + }, + Ports: []corev1ac.ContainerPortApplyConfiguration{{ + ContainerPort: ptr.To[int32](constants.ContainerTrainerPort), + }}, + }}, + 
}}, + }, + Scheduler: &runtime.Scheduler{PodLabels: make(map[string]string)}, + }, + }, + "single-device full fine-tuning with torchtune": { + trainJob: utiltesting.MakeTrainJobWrapper("default", "torchtune-job"). + Trainer( + utiltesting.MakeTrainJobTrainerWrapper(). + NumNodes(1). + NumProcPerNode(intstr.FromInt(1)). + Container( + "ghcr.io/kubeflow/trainer/torchtune-trainer", + []string{"tune", "run"}, + []string{ + "dtype=fp16", + "batch_size=32", + "epochs=10", + "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", + "model=llama3_2/1B", + }, + corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("8"), + corev1.ResourceMemory: resource.MustParse("16Gi"), + "nvidia.com/gpu": resource.MustParse("1"), // 1 GPU per node + }, + ). + Obj(), + ). + Obj(), + info: runtime.NewInfo( + runtime.WithMLPolicySource( + utiltesting.MakeMLPolicyWrapper(). + WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper(). + TorchPolicy(ptr.To(intstr.FromString("auto")), nil). + Obj(), + ). + Obj(), + ), + runtime.WithPodSet(constants.Node, ptr.To(constants.AncestorTrainer), 1, corev1.PodSpec{}, corev1ac.PodSpec(). + WithContainers(corev1ac.Container().WithName(constants.Node)), + ), + ), + wantInfo: &runtime.Info{ + Labels: make(map[string]string), + Annotations: make(map[string]string), + RuntimePolicy: runtime.RuntimePolicy{ + MLPolicySource: utiltesting.MakeMLPolicySourceWrapper(). + TorchPolicy(ptr.To(intstr.FromString("auto")), nil). 
+ Obj(), + }, + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{{ + Name: constants.Node, + Ancestor: ptr.To(constants.AncestorTrainer), + Count: ptr.To[int32](1), + SinglePodRequests: make(corev1.ResourceList), + Containers: []runtime.Container{{ + Name: constants.Node, + Args: []string{ + fmt.Sprintf("%s %s", constants.TorchTuneArgNumNodes, "1"), + fmt.Sprintf("%s %s", constants.TorchTuneArgNumProcPerNode, "1"), + fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvId, "torchtune-job"), + fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), + constants.TorchTuneFullFinetuneSingleDevice, + "--config llama3_2/1B_full_single_device.yaml", + "dtype=fp16", + "batch_size=32", + "epochs=10", + "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", }, Ports: []corev1ac.ContainerPortApplyConfiguration{{ ContainerPort: ptr.To[int32](constants.ContainerTrainerPort), From 206822ed09106c986d89a2b267979e71e05fbcad Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Wed, 9 Apr 2025 08:29:41 +0000 Subject: [PATCH 09/26] chore(test): Add test for multi-nodes full fine-tuning with torchtune. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- .../framework/plugins/torch/torch_test.go | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/pkg/runtime/framework/plugins/torch/torch_test.go b/pkg/runtime/framework/plugins/torch/torch_test.go index e4ef555494..7ebbff1ec9 100644 --- a/pkg/runtime/framework/plugins/torch/torch_test.go +++ b/pkg/runtime/framework/plugins/torch/torch_test.go @@ -1276,6 +1276,81 @@ func TestTorch(t *testing.T) { Scheduler: &runtime.Scheduler{PodLabels: make(map[string]string)}, }, }, + "multi-nodes full fine-tuning with torchtune": { + trainJob: utiltesting.MakeTrainJobWrapper("default", "torchtune-job"). + Trainer( + utiltesting.MakeTrainJobTrainerWrapper(). + NumNodes(2). + NumProcPerNode(intstr.FromInt(8)). 
+ Container( + "ghcr.io/kubeflow/trainer/torchtune-trainer", + []string{"tune", "run"}, + []string{ + "dtype=fp16", + "batch_size=32", + "epochs=10", + "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", + "model=llama3_3/70B", + }, + corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("8"), + corev1.ResourceMemory: resource.MustParse("16Gi"), + "nvidia.com/gpu": resource.MustParse("8"), // 8 GPUs per node + }, + ). + Obj(), + ). + Obj(), + info: runtime.NewInfo( + runtime.WithMLPolicySource( + utiltesting.MakeMLPolicyWrapper(). + WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper(). + TorchPolicy(ptr.To(intstr.FromString("auto")), nil). + Obj(), + ). + Obj(), + ), + runtime.WithPodSet(constants.Node, ptr.To(constants.AncestorTrainer), 1, corev1.PodSpec{}, corev1ac.PodSpec(). + WithContainers(corev1ac.Container().WithName(constants.Node)), + ), + ), + wantInfo: &runtime.Info{ + Labels: make(map[string]string), + Annotations: make(map[string]string), + RuntimePolicy: runtime.RuntimePolicy{ + MLPolicySource: utiltesting.MakeMLPolicySourceWrapper(). + TorchPolicy(ptr.To(intstr.FromString("auto")), nil). 
+ Obj(), + }, + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{{ + Name: constants.Node, + Ancestor: ptr.To(constants.AncestorTrainer), + Count: ptr.To[int32](2), + SinglePodRequests: make(corev1.ResourceList), + Containers: []runtime.Container{{ + Name: constants.Node, + Args: []string{ + fmt.Sprintf("%s %s", constants.TorchTuneArgNumNodes, "2"), + fmt.Sprintf("%s %s", constants.TorchTuneArgNumProcPerNode, "8"), + fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvId, "torchtune-job"), + fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), + constants.TorchTuneFullFinetuneDistributed, + "--config llama3_3/70B_full_multinode.yaml", + "dtype=fp16", + "batch_size=32", + "epochs=10", + "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", + }, + Ports: []corev1ac.ContainerPortApplyConfiguration{{ + ContainerPort: ptr.To[int32](constants.ContainerTrainerPort), + }}, + }}, + }}, + }, + Scheduler: &runtime.Scheduler{PodLabels: make(map[string]string)}, + }, + }, } for name, tc := range cases { From a29dddc76fcbe7856823aca91b58ee0585e2d21f Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Wed, 9 Apr 2025 09:33:24 +0000 Subject: [PATCH 10/26] chore(test): Update torch validate UTs. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/runtime/framework/plugins/torch/torch.go | 36 +++++++---- .../framework/plugins/torch/torch_test.go | 62 ++++++++++++++++++- 2 files changed, 85 insertions(+), 13 deletions(-) diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index d03382cf36..f541a0d039 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -71,9 +71,8 @@ func (t *Torch) Validate(runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob) } } - // Check reserved envs for torchrun. - // TODO(Electronic-Waste): Add validation for torchtune args. 
if !slices.Equal(newObj.Spec.Trainer.Command, constants.TorchTuneEntrypoint) { + // Check reserved envs for torchrun. torchEnvs := sets.New[string]() for _, env := range newObj.Spec.Trainer.Env { if constants.TorchRunReservedEnvNames.Has(env.Name) { @@ -85,6 +84,17 @@ func (t *Torch) Validate(runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob) trainerEnvsPath := specPath.Child("trainer").Child("env") allErrs = append(allErrs, field.Invalid(trainerEnvsPath, newObj.Spec.Trainer.Env, fmt.Sprintf("must not have reserved envs, invalid envs configured: %v", sets.List(torchEnvs)))) } + } else { + // Check supported pretrained models for torchtune. + // TODO(Electronic-Waste): Add more validation for torchtune when we support more arguments. + argPath := specPath.Child("trainer").Child("args") + model := getModelFromArgs(newObj.Spec.Trainer.Args) + + if model == nil { + allErrs = append(allErrs, field.Invalid(argPath, newObj.Spec.Trainer.Args, "must specify a pretrained model")) + } else if !constants.TorchTuneSupportedPretrainedModels.Has(*model) { + allErrs = append(allErrs, field.Invalid(argPath, newObj.Spec.Trainer.Args, fmt.Sprintf("must have a supported pretrained model, invalid model configured: %v", *model))) + } } } @@ -246,15 +256,6 @@ func getRecipeFromArgs(numNodes int32, numProcPerNode intstr.IntOrString, _ []st // getConfigFromArgs extracts the config from distributed parameters, recipe and command line arguments. func getConfigFileFromArgs(numNodes int32, recipe string, args []string) string { - // Extract model from command line args. - model := constants.MODEL_LLAMA3_2_1B - for _, arg := range args { - if strings.HasPrefix(arg, "model") { - model = strings.Split(arg, "=")[1] - break - } - } - // Determine the config file name based on the recipe and number of nodes. 
var suffix string switch recipe { @@ -268,5 +269,16 @@ func getConfigFileFromArgs(numNodes int32, recipe string, args []string) string suffix = constants.TorchTuneFullFinetuneSingleDeviceConfigSuffix } - return fmt.Sprintf("%s%s.yaml", model, suffix) + return fmt.Sprintf("%s%s.yaml", *getModelFromArgs(args), suffix) +} + +func getModelFromArgs(args []string) *string { + var model *string + for _, arg := range args { + if strings.HasPrefix(arg, "model") { + model = &strings.Split(arg, "=")[1] + break + } + } + return model } diff --git a/pkg/runtime/framework/plugins/torch/torch_test.go b/pkg/runtime/framework/plugins/torch/torch_test.go index 7ebbff1ec9..aeda40ceb4 100644 --- a/pkg/runtime/framework/plugins/torch/torch_test.go +++ b/pkg/runtime/framework/plugins/torch/torch_test.go @@ -1568,7 +1568,8 @@ func TestValidate(t *testing.T) { Container( "ghcr.io/kubeflow/trainer/torchtune-trainer", []string{"tune", "run"}, - nil, corev1.ResourceList{}, + []string{"model=llama3_2/1B"}, + corev1.ResourceList{}, ). Env( []corev1.EnvVar{ @@ -1586,6 +1587,65 @@ func TestValidate(t *testing.T) { ). Obj(), }, + "missing pretrained model": { + info: runtime.NewInfo( + runtime.WithMLPolicySource(utiltesting.MakeMLPolicyWrapper(). + WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper(). + TorchPolicy(ptr.To(intstr.FromString("auto")), nil). + Obj(), + ). + Obj(), + ), + ), + newObj: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test"). + Trainer(utiltesting.MakeTrainJobTrainerWrapper(). + NumProcPerNode(intstr.FromString("auto")). + Container( + "ghcr.io/kubeflow/trainer/torchtune-trainer", + []string{"tune", "run"}, + nil, corev1.ResourceList{}, + ). + Obj(), + ). 
+ Obj(), + wantError: field.ErrorList{ + field.Invalid( + field.NewPath("spec").Child("trainer").Child("args"), + []string(nil), + "must specify a pretrained model", + ), + }, + }, + "unsupported pretrained model": { + info: runtime.NewInfo( + runtime.WithMLPolicySource(utiltesting.MakeMLPolicyWrapper(). + WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper(). + TorchPolicy(ptr.To(intstr.FromString("auto")), nil). + Obj(), + ). + Obj(), + ), + ), + newObj: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test"). + Trainer(utiltesting.MakeTrainJobTrainerWrapper(). + NumProcPerNode(intstr.FromString("auto")). + Container( + "ghcr.io/kubeflow/trainer/torchtune-trainer", + []string{"tune", "run"}, + []string{"model=llama3_1/70B"}, + corev1.ResourceList{}, + ). + Obj(), + ). + Obj(), + wantError: field.ErrorList{ + field.Invalid( + field.NewPath("spec").Child("trainer").Child("args"), + []string{"model=llama3_1/70B"}, + fmt.Sprintf("must have a supported pretrained model, invalid model configured: %s", "llama3_1/70B"), + ), + }, + }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { From a9f993a43b427ff6b4b8e7d91e7cdaf9cf41703f Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Wed, 9 Apr 2025 09:36:21 +0000 Subject: [PATCH 11/26] fix(lint): fix lint error. 
Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/runtime/runtime.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index c3496d0f56..80af27a324 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -80,7 +80,7 @@ type PodSet struct { type Container struct { Name string - Args []string + Args []string Env []corev1ac.EnvVarApplyConfiguration Ports []corev1ac.ContainerPortApplyConfiguration VolumeMounts []corev1ac.VolumeMountApplyConfiguration From 7eee7a3610ef93679919cd2619d7bc85fe2b401d Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Thu, 10 Apr 2025 06:57:07 +0000 Subject: [PATCH 12/26] fix(sdk): remove pretrained model enum type in sdk. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- sdk/kubeflow/trainer/__init__.py | 1 - sdk/kubeflow/trainer/types/types.py | 8 +------- sdk/kubeflow/trainer/utils/utils.py | 6 ------ 3 files changed, 1 insertion(+), 14 deletions(-) diff --git a/sdk/kubeflow/trainer/__init__.py b/sdk/kubeflow/trainer/__init__.py index a8950a9557..9d07ea282e 100644 --- a/sdk/kubeflow/trainer/__init__.py +++ b/sdk/kubeflow/trainer/__init__.py @@ -33,7 +33,6 @@ HuggingFaceModelInitializer, Initializer, Loss, - PretrainedModel, Runtime, Trainer, TrainerType, diff --git a/sdk/kubeflow/trainer/types/types.py b/sdk/kubeflow/trainer/types/types.py index f43f3ee22e..f36b0fd05b 100644 --- a/sdk/kubeflow/trainer/types/types.py +++ b/sdk/kubeflow/trainer/types/types.py @@ -115,12 +115,6 @@ class Framework(Enum): TORCHTUNE = "torchtune" -class PretrainedModel(Enum): - LLAMA3_2_1B = "llama3_2/1B" - LLAMA3_2_3B = "llama3_2/3B" - LLAMA3_3_70B = "llama3_3/70B" - - # Representation for the Trainer of the runtime. 
@dataclass class Trainer: @@ -136,7 +130,7 @@ class Trainer: class Runtime: name: str trainer: Trainer - pretrained_model: Optional[PretrainedModel] = None + pretrained_model: Optional[str] = None # Representation for the TrainJob steps. diff --git a/sdk/kubeflow/trainer/utils/utils.py b/sdk/kubeflow/trainer/utils/utils.py index be142a8fc0..4e5ebdbb10 100644 --- a/sdk/kubeflow/trainer/utils/utils.py +++ b/sdk/kubeflow/trainer/utils/utils.py @@ -347,12 +347,6 @@ def get_args_using_torchtune_config( if fine_tuning_config.loss: args.append(f"loss={fine_tuning_config.loss}") - # Provide pre-trained model information. - # TODO(Electronic-Waste): Move pre-trained model information to the runtime API fields. - # Ref: https://github.com/kubeflow/trainer/pull/2410#pullrequestreview-2672356400 - if runtime.pretrained_model: - args.append(f"model={runtime.pretrained_model}") - return constants.DEFAULT_TORCHTUNE_COMMAND, args From 723673202a77055c02cae546ed691f7d0372ded4 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Thu, 10 Apr 2025 07:19:19 +0000 Subject: [PATCH 13/26] fix(plugin): retrieve model name from runtimeRef. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/runtime/framework/plugins/torch/torch.go | 32 +++++----- .../framework/plugins/torch/torch_test.go | 60 +++++++------------ 2 files changed, 37 insertions(+), 55 deletions(-) diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index f541a0d039..1a8e934a74 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -87,13 +87,11 @@ func (t *Torch) Validate(runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob) } else { // Check supported pretrained models for torchtune. // TODO(Electronic-Waste): Add more validation for torchtune when we support more arguments. 
- argPath := specPath.Child("trainer").Child("args") - model := getModelFromArgs(newObj.Spec.Trainer.Args) + runtimeRefNamePath := specPath.Child("runtimeRef").Child("name") + model := getModelFromRuntimeRef(newObj.Spec.RuntimeRef.Name) - if model == nil { - allErrs = append(allErrs, field.Invalid(argPath, newObj.Spec.Trainer.Args, "must specify a pretrained model")) - } else if !constants.TorchTuneSupportedPretrainedModels.Has(*model) { - allErrs = append(allErrs, field.Invalid(argPath, newObj.Spec.Trainer.Args, fmt.Sprintf("must have a supported pretrained model, invalid model configured: %v", *model))) + if !constants.TorchTuneSupportedPretrainedModels.Has(model) { + allErrs = append(allErrs, field.Invalid(runtimeRefNamePath, newObj.Spec.RuntimeRef.Name, fmt.Sprintf("must have a supported pretrained model, invalid model configured: %v", model))) } } } @@ -211,7 +209,7 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) // 2. Get the recipe and config from old args and append them to new args. recipe := getRecipeFromArgs(numNodes, numProcPerNode, oldArgs) - config := getConfigFileFromArgs(numNodes, recipe, oldArgs) + config := getConfigFileFromArgs(numNodes, recipe, trainJob.Spec.RuntimeRef.Name) newArgs = append(newArgs, recipe, fmt.Sprintf("--config %s", config)) // 3. Reserve old arguments to override corresponding items in the config file. @@ -254,8 +252,8 @@ func getRecipeFromArgs(numNodes int32, numProcPerNode intstr.IntOrString, _ []st return recipe } -// getConfigFromArgs extracts the config from distributed parameters, recipe and command line arguments. -func getConfigFileFromArgs(numNodes int32, recipe string, args []string) string { +// getConfigFromArgs extracts the config from distributed parameters, recipe and runtime reference name. +func getConfigFileFromArgs(numNodes int32, recipe, runtimeRefName string) string { // Determine the config file name based on the recipe and number of nodes. 
var suffix string switch recipe { @@ -269,16 +267,14 @@ func getConfigFileFromArgs(numNodes int32, recipe string, args []string) string suffix = constants.TorchTuneFullFinetuneSingleDeviceConfigSuffix } - return fmt.Sprintf("%s%s.yaml", *getModelFromArgs(args), suffix) + return fmt.Sprintf("%s%s.yaml", getModelFromRuntimeRef(runtimeRefName), suffix) } -func getModelFromArgs(args []string) *string { - var model *string - for _, arg := range args { - if strings.HasPrefix(arg, "model") { - model = &strings.Split(arg, "=")[1] - break - } +func getModelFromRuntimeRef(runtimeRefName string) string { + fields := strings.Split(runtimeRefName, "-") + if len(fields) != 3 { + return "" } - return model + + return fmt.Sprintf("%s/%s", strings.ReplaceAll(fields[1], ".", "_"), strings.ToUpper(fields[2])) } diff --git a/pkg/runtime/framework/plugins/torch/torch_test.go b/pkg/runtime/framework/plugins/torch/torch_test.go index aeda40ceb4..da26a3e296 100644 --- a/pkg/runtime/framework/plugins/torch/torch_test.go +++ b/pkg/runtime/framework/plugins/torch/torch_test.go @@ -1140,7 +1140,6 @@ func TestTorch(t *testing.T) { "batch_size=32", "epochs=10", "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", - "model=llama3_2/1B", }, corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("8"), @@ -1150,6 +1149,10 @@ func TestTorch(t *testing.T) { ). Obj(), ). + RuntimeRef( + trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), + "torchtune-llama3.2-1b", + ). Obj(), info: runtime.NewInfo( runtime.WithMLPolicySource( @@ -1215,7 +1218,6 @@ func TestTorch(t *testing.T) { "batch_size=32", "epochs=10", "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", - "model=llama3_2/1B", }, corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("8"), @@ -1225,6 +1227,10 @@ func TestTorch(t *testing.T) { ). Obj(), ). + RuntimeRef( + trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), + "torchtune-llama3.2-1b", + ). 
Obj(), info: runtime.NewInfo( runtime.WithMLPolicySource( @@ -1290,7 +1296,6 @@ func TestTorch(t *testing.T) { "batch_size=32", "epochs=10", "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", - "model=llama3_3/70B", }, corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("8"), @@ -1300,6 +1305,10 @@ func TestTorch(t *testing.T) { ). Obj(), ). + RuntimeRef( + trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), + "torchtune-llama3.3-70b", + ). Obj(), info: runtime.NewInfo( runtime.WithMLPolicySource( @@ -1568,8 +1577,7 @@ func TestValidate(t *testing.T) { Container( "ghcr.io/kubeflow/trainer/torchtune-trainer", []string{"tune", "run"}, - []string{"model=llama3_2/1B"}, - corev1.ResourceList{}, + nil, corev1.ResourceList{}, ). Env( []corev1.EnvVar{ @@ -1585,36 +1593,11 @@ func TestValidate(t *testing.T) { ). Obj(), ). - Obj(), - }, - "missing pretrained model": { - info: runtime.NewInfo( - runtime.WithMLPolicySource(utiltesting.MakeMLPolicyWrapper(). - WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper(). - TorchPolicy(ptr.To(intstr.FromString("auto")), nil). - Obj(), - ). - Obj(), - ), - ), - newObj: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test"). - Trainer(utiltesting.MakeTrainJobTrainerWrapper(). - NumProcPerNode(intstr.FromString("auto")). - Container( - "ghcr.io/kubeflow/trainer/torchtune-trainer", - []string{"tune", "run"}, - nil, corev1.ResourceList{}, - ). - Obj(), + RuntimeRef( + trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), + "torchtune-llama3.2-1b", ). 
Obj(), - wantError: field.ErrorList{ - field.Invalid( - field.NewPath("spec").Child("trainer").Child("args"), - []string(nil), - "must specify a pretrained model", - ), - }, }, "unsupported pretrained model": { info: runtime.NewInfo( @@ -1632,16 +1615,19 @@ func TestValidate(t *testing.T) { Container( "ghcr.io/kubeflow/trainer/torchtune-trainer", []string{"tune", "run"}, - []string{"model=llama3_1/70B"}, - corev1.ResourceList{}, + nil, corev1.ResourceList{}, ). Obj(), ). + RuntimeRef( + trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), + "torchtune-llama3.1-70b", + ). Obj(), wantError: field.ErrorList{ field.Invalid( - field.NewPath("spec").Child("trainer").Child("args"), - []string{"model=llama3_1/70B"}, + field.NewPath("spec").Child("runtimeRef").Child("name"), + "torchtune-llama3.1-70b", fmt.Sprintf("must have a supported pretrained model, invalid model configured: %s", "llama3_1/70B"), ), }, From cabd1457e53235d2f7a9a21b5aa69438a8eae4d2 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Thu, 10 Apr 2025 08:14:17 +0000 Subject: [PATCH 14/26] fix(lint): fix typo. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/constants/constants.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 01c0b8498f..14af374523 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -187,7 +187,7 @@ var ( // Currently supported TorchTune recipes. TorchTuneSupportedRecipes = sets.New(TorchTuneFullFinetuneSingleDevice, TorchTuneFullFinetuneDistributed) - // Currently supported pretrained models for TorchTuen Trainer. + // Currently supported pretrained models for TorchTune Trainer. TorchTuneSupportedPretrainedModels = sets.New(MODEL_LLAMA3_2_1B, MODEL_LLAMA3_2_7B, MODEL_LLAMA3_3_70B) // TorchTuneEntrypoint is the entrypoint for the torchtune. 
From 824cb259a1cf42751eca3a74a4b3e80b47dfedcd Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Fri, 11 Apr 2025 02:51:01 +0000 Subject: [PATCH 15/26] fix(plugin): make some adjustments according to the review. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/constants/constants.go | 17 +++++++---------- pkg/runtime/framework/plugins/torch/torch.go | 4 +--- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 14af374523..90f06b146b 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -167,14 +167,14 @@ const ( ) const ( - // MODEL_LLAMA3_2_1B is the model name for the Llama3.2 1B Instruct model. - MODEL_LLAMA3_2_1B = "llama3_2/1B" + // TORCHTUNE_MODEL_LLAMA3_2_1B is the model name for the Llama3.2 1B Instruct model. + TORCHTUNE_MODEL_LLAMA3_2_1B = "llama3_2/1B" - // MODEL_LLAMA3_2_7B is the model name for the Llama3.2 7B Instruct model. - MODEL_LLAMA3_2_7B = "llama3_2/7B" + // TORCHTUNE_MODEL_LLAMA3_2_7B is the model name for the Llama3.2 7B Instruct model. + TORCHTUNE_MODEL_LLAMA3_2_7B = "llama3_2/7B" - // MODEL_LLAMA3_3_70B is the model name for the Llama3.3 70B Instruct model. - MODEL_LLAMA3_3_70B = "llama3_3/70B" + // TORCHTUNE_MODEL_LLAMA3_3_70B is the model name for the Llama3.3 70B Instruct model. + TORCHTUNE_MODEL_LLAMA3_3_70B = "llama3_3/70B" ) var ( @@ -184,11 +184,8 @@ var ( // Torchrun reserved env names TorchRunReservedEnvNames = sets.New(TorchEnvNumNodes, TorchEnvNumProcPerNode, TorchEnvNodeRank, TorchEnvMasterAddr, TorchEnvMasterPort) - // Currently supported TorchTune recipes. - TorchTuneSupportedRecipes = sets.New(TorchTuneFullFinetuneSingleDevice, TorchTuneFullFinetuneDistributed) - // Currently supported pretrained models for TorchTune Trainer. 
- TorchTuneSupportedPretrainedModels = sets.New(MODEL_LLAMA3_2_1B, MODEL_LLAMA3_2_7B, MODEL_LLAMA3_3_70B) + TorchTuneSupportedPretrainedModels = sets.New(TORCHTUNE_MODEL_LLAMA3_2_1B, TORCHTUNE_MODEL_LLAMA3_2_7B, TORCHTUNE_MODEL_LLAMA3_3_70B) // TorchTuneEntrypoint is the entrypoint for the torchtune. TorchTuneEntrypoint = []string{"tune", "run"} diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index 1a8e934a74..dd888464b3 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -213,9 +213,7 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) newArgs = append(newArgs, recipe, fmt.Sprintf("--config %s", config)) // 3. Reserve old arguments to override corresponding items in the config file. - newArgs = append(newArgs, slices.DeleteFunc(oldArgs, func(arg string) bool { - return strings.HasPrefix(arg, "model") - })...) + newArgs = append(newArgs, oldArgs...) trainerContainer.Args = newArgs } From 0d1d3a23f0724e6b6faa2baf24b67555c34fb982 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Sat, 12 Apr 2025 03:38:17 +0000 Subject: [PATCH 16/26] fix(sdk): remove runtime in get_trainer_crd_from_builtin_trainer. 
Signed-off-by: Electronic-Waste <2690692950@qq.com> --- sdk/kubeflow/trainer/utils/utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sdk/kubeflow/trainer/utils/utils.py b/sdk/kubeflow/trainer/utils/utils.py index 4e5ebdbb10..e83e7f67a3 100644 --- a/sdk/kubeflow/trainer/utils/utils.py +++ b/sdk/kubeflow/trainer/utils/utils.py @@ -320,7 +320,6 @@ def get_entrypoint_using_train_func( def get_args_using_torchtune_config( - runtime: types.Runtime, fine_tuning_config: types.TorchTuneConfig, ) -> Tuple[List[str], List[str]]: """ @@ -385,7 +384,6 @@ def get_trainer_crd_from_custom_trainer( def get_trainer_crd_from_builtin_trainer( trainer: types.BuiltinTrainer, - runtime: types.Runtime, ) -> models.TrainerV1alpha1Trainer: """ Get the Trainer CRD from the builtin trainer. @@ -408,9 +406,7 @@ def get_trainer_crd_from_builtin_trainer( # Parse args in the TorchTuneConfig to the Trainer, preparing for the mutation of # the torchtune config in the runtime plugin. # Ref:https://github.com/kubeflow/trainer/tree/master/docs/proposals/2401-llm-trainer-v2 - trainer_crd.command, trainer_crd.args = get_args_using_torchtune_config( - runtime, trainer - ) + trainer_crd.command, trainer_crd.args = get_args_using_torchtune_config(trainer) return trainer_crd From d9c8b7c119273545c09319e781bf966dfb9d9de6 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Wed, 16 Apr 2025 03:18:22 +0000 Subject: [PATCH 17/26] fix(plugin): pass PET_ env variables in torch plugin for torchtune. 
Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/constants/constants.go | 9 --- pkg/runtime/framework/plugins/torch/torch.go | 49 ++++++--------- .../framework/plugins/torch/torch_test.go | 63 ++++++++++++++++--- 3 files changed, 74 insertions(+), 47 deletions(-) diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 90f06b146b..e00792a40d 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -135,15 +135,6 @@ const ( // TorchEnvMasterPort is the env name for the master node port. TorchEnvMasterPort string = "PET_MASTER_PORT" - // TochTuneArgNumNodes is the arg anme for the number of training nodes. - TorchTuneArgNumNodes string = "--nnodes" - - // TorchTuneArgNumProcPerNode is the arg name for the number of procs per node (e.g. number of GPUs per Pod). - TorchTuneArgNumProcPerNode string = "--nproc_per_node" - - // TorchTuneArgRdzvId is the arg name for the rendezvous ID. - TorchTuneArgRdzvId string = "--rdzv_id" - // TorchTuneArgRdzvEndpoint is the arg name for the rendezvous endpoint. TorchTuneArgRdzvEndpoint string = "--rdzv_endpoint" diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index dd888464b3..8e90b0cecf 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -157,22 +157,26 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) } } if trainerContainer != nil { + // Add PyTorch distributed "PET_" values for torchrun and torchtune. + // TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other. + // Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940 + apply.UpsertEnvVar(&trainerContainer.Env, + *corev1ac.EnvVar(). + WithName(constants.TorchEnvNumNodes). + WithValue(fmt.Sprintf("%d", ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1))), + *corev1ac.EnvVar(). 
+ WithName(constants.TorchEnvNumProcPerNode). + WithValue(numProcPerNode.String()), + *corev1ac.EnvVar(). + WithName(constants.TorchEnvNodeRank). + WithValueFrom(corev1ac.EnvVarSource(). + WithFieldRef(corev1ac.ObjectFieldSelector(). + WithFieldPath(constants.JobCompletionIndexFieldPath))), + ) + if !slices.Equal(trainJob.Spec.Trainer.Command, constants.TorchTuneEntrypoint) { - // Add PyTorch distributed "PET_" values for torchrun. - // TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other. - // Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940 + // Add PET_MASTER_ADDR and PET_MASTER_PORT envs for torchrun. apply.UpsertEnvVar(&trainerContainer.Env, - *corev1ac.EnvVar(). - WithName(constants.TorchEnvNumNodes). - WithValue(fmt.Sprintf("%d", ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1))), - *corev1ac.EnvVar(). - WithName(constants.TorchEnvNumProcPerNode). - WithValue(numProcPerNode.String()), - *corev1ac.EnvVar(). - WithName(constants.TorchEnvNodeRank). - WithValueFrom(corev1ac.EnvVarSource(). - WithFieldRef(corev1ac.ObjectFieldSelector(). - WithFieldPath(constants.JobCompletionIndexFieldPath))), *corev1ac.EnvVar(). WithName(constants.TorchEnvMasterAddr). WithValue(fmt.Sprintf("%s-%s-0-0.%s", trainJob.Name, constants.Node, trainJob.Name)), @@ -185,22 +189,8 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) // Ref: https://github.com/kubeflow/trainer/tree/master/docs/proposals/2401-llm-trainer-v2#complement-torch-plugin oldArgs, newArgs := trainJob.Spec.Trainer.Args, []string{} - // 1. Add PyTorch distributed command line args for torchtune. - // TODO(Electronic-Waste): Add more args for torchtune if required. - numNodes := ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1) + // 1. Add rendezvous backend arg for torchtune. 
newArgs = append(newArgs, - fmt.Sprintf("%s %d", - constants.TorchTuneArgNumNodes, - numNodes, - ), - fmt.Sprintf("%s %s", - constants.TorchTuneArgNumProcPerNode, - numProcPerNode.String(), - ), - fmt.Sprintf("%s %s", - constants.TorchTuneArgRdzvId, - trainJob.Name, - ), fmt.Sprintf("%s %s-%s-0-0.%s:%d", constants.TorchTuneArgRdzvEndpoint, trainJob.Name, constants.Node, trainJob.Name, constants.ContainerTrainerPort, @@ -208,6 +198,7 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) ) // 2. Get the recipe and config from old args and append them to new args. + numNodes := ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1) recipe := getRecipeFromArgs(numNodes, numProcPerNode, oldArgs) config := getConfigFileFromArgs(numNodes, recipe, trainJob.Spec.RuntimeRef.Name) newArgs = append(newArgs, recipe, fmt.Sprintf("--config %s", config)) diff --git a/pkg/runtime/framework/plugins/torch/torch_test.go b/pkg/runtime/framework/plugins/torch/torch_test.go index da26a3e296..048e7d92cd 100644 --- a/pkg/runtime/framework/plugins/torch/torch_test.go +++ b/pkg/runtime/framework/plugins/torch/torch_test.go @@ -1184,9 +1184,6 @@ func TestTorch(t *testing.T) { Containers: []runtime.Container{{ Name: constants.Node, Args: []string{ - fmt.Sprintf("%s %s", constants.TorchTuneArgNumNodes, "1"), - fmt.Sprintf("%s %s", constants.TorchTuneArgNumProcPerNode, "auto"), - fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvId, "torchtune-job"), fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), constants.TorchTuneFullFinetuneDistributed, "--config llama3_2/1B_full.yaml", @@ -1195,6 +1192,24 @@ func TestTorch(t *testing.T) { "epochs=10", "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", }, + Env: []corev1ac.EnvVarApplyConfiguration{ + { + Name: ptr.To(constants.TorchEnvNumNodes), + Value: ptr.To("1"), + }, + { + Name: ptr.To(constants.TorchEnvNumProcPerNode), + Value: ptr.To("auto"), + }, + { + Name: 
ptr.To(constants.TorchEnvNodeRank), + ValueFrom: &corev1ac.EnvVarSourceApplyConfiguration{ + FieldRef: &corev1ac.ObjectFieldSelectorApplyConfiguration{ + FieldPath: ptr.To(constants.JobCompletionIndexFieldPath), + }, + }, + }, + }, Ports: []corev1ac.ContainerPortApplyConfiguration{{ ContainerPort: ptr.To[int32](constants.ContainerTrainerPort), }}, @@ -1262,9 +1277,6 @@ func TestTorch(t *testing.T) { Containers: []runtime.Container{{ Name: constants.Node, Args: []string{ - fmt.Sprintf("%s %s", constants.TorchTuneArgNumNodes, "1"), - fmt.Sprintf("%s %s", constants.TorchTuneArgNumProcPerNode, "1"), - fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvId, "torchtune-job"), fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), constants.TorchTuneFullFinetuneSingleDevice, "--config llama3_2/1B_full_single_device.yaml", @@ -1273,6 +1285,24 @@ func TestTorch(t *testing.T) { "epochs=10", "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", }, + Env: []corev1ac.EnvVarApplyConfiguration{ + { + Name: ptr.To(constants.TorchEnvNumNodes), + Value: ptr.To("1"), + }, + { + Name: ptr.To(constants.TorchEnvNumProcPerNode), + Value: ptr.To("1"), + }, + { + Name: ptr.To(constants.TorchEnvNodeRank), + ValueFrom: &corev1ac.EnvVarSourceApplyConfiguration{ + FieldRef: &corev1ac.ObjectFieldSelectorApplyConfiguration{ + FieldPath: ptr.To(constants.JobCompletionIndexFieldPath), + }, + }, + }, + }, Ports: []corev1ac.ContainerPortApplyConfiguration{{ ContainerPort: ptr.To[int32](constants.ContainerTrainerPort), }}, @@ -1340,9 +1370,6 @@ func TestTorch(t *testing.T) { Containers: []runtime.Container{{ Name: constants.Node, Args: []string{ - fmt.Sprintf("%s %s", constants.TorchTuneArgNumNodes, "2"), - fmt.Sprintf("%s %s", constants.TorchTuneArgNumProcPerNode, "8"), - fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvId, "torchtune-job"), fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), 
constants.TorchTuneFullFinetuneDistributed, "--config llama3_3/70B_full_multinode.yaml", @@ -1351,6 +1378,24 @@ func TestTorch(t *testing.T) { "epochs=10", "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", }, + Env: []corev1ac.EnvVarApplyConfiguration{ + { + Name: ptr.To(constants.TorchEnvNumNodes), + Value: ptr.To("2"), + }, + { + Name: ptr.To(constants.TorchEnvNumProcPerNode), + Value: ptr.To("8"), + }, + { + Name: ptr.To(constants.TorchEnvNodeRank), + ValueFrom: &corev1ac.EnvVarSourceApplyConfiguration{ + FieldRef: &corev1ac.ObjectFieldSelectorApplyConfiguration{ + FieldPath: ptr.To(constants.JobCompletionIndexFieldPath), + }, + }, + }, + }, Ports: []corev1ac.ContainerPortApplyConfiguration{{ ContainerPort: ptr.To[int32](constants.ContainerTrainerPort), }}, From 3b22cf7c31b79b7484c9254df119e6e8a91c1bc9 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Wed, 16 Apr 2025 03:32:35 +0000 Subject: [PATCH 18/26] fix(plugin): add env validation for torchtune. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/runtime/framework/plugins/torch/torch.go | 28 +++++++++---------- .../framework/plugins/torch/torch_test.go | 22 ++++++++++++++- 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index 8e90b0cecf..1126fd7d3d 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -71,22 +71,22 @@ func (t *Torch) Validate(runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob) } } - if !slices.Equal(newObj.Spec.Trainer.Command, constants.TorchTuneEntrypoint) { - // Check reserved envs for torchrun. - torchEnvs := sets.New[string]() - for _, env := range newObj.Spec.Trainer.Env { - if constants.TorchRunReservedEnvNames.Has(env.Name) { - torchEnvs.Insert(env.Name) - } + // Check reserved envs for torchrun. 
+ torchEnvs := sets.New[string]() + for _, env := range newObj.Spec.Trainer.Env { + if constants.TorchRunReservedEnvNames.Has(env.Name) { + torchEnvs.Insert(env.Name) } + } - if torchEnvs.Len() > 0 { - trainerEnvsPath := specPath.Child("trainer").Child("env") - allErrs = append(allErrs, field.Invalid(trainerEnvsPath, newObj.Spec.Trainer.Env, fmt.Sprintf("must not have reserved envs, invalid envs configured: %v", sets.List(torchEnvs)))) - } - } else { - // Check supported pretrained models for torchtune. - // TODO(Electronic-Waste): Add more validation for torchtune when we support more arguments. + if torchEnvs.Len() > 0 { + trainerEnvsPath := specPath.Child("trainer").Child("env") + allErrs = append(allErrs, field.Invalid(trainerEnvsPath, newObj.Spec.Trainer.Env, fmt.Sprintf("must not have reserved envs, invalid envs configured: %v", sets.List(torchEnvs)))) + } + + // Check supported pretrained models for torchtune. + // TODO(Electronic-Waste): Add more validation for torchtune when we support more arguments. + if slices.Equal(newObj.Spec.Trainer.Command, constants.TorchTuneEntrypoint) { runtimeRefNamePath := specPath.Child("runtimeRef").Child("name") model := getModelFromRuntimeRef(newObj.Spec.RuntimeRef.Name) diff --git a/pkg/runtime/framework/plugins/torch/torch_test.go b/pkg/runtime/framework/plugins/torch/torch_test.go index 048e7d92cd..15c3b6c1a2 100644 --- a/pkg/runtime/framework/plugins/torch/torch_test.go +++ b/pkg/runtime/framework/plugins/torch/torch_test.go @@ -1606,7 +1606,7 @@ func TestValidate(t *testing.T) { ), }, }, - "no reserved environment variable for torchtune": { + "reserved environment variable for torchtune": { info: runtime.NewInfo( runtime.WithMLPolicySource(utiltesting.MakeMLPolicyWrapper(). WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper(). @@ -1643,6 +1643,26 @@ func TestValidate(t *testing.T) { "torchtune-llama3.2-1b", ). 
Obj(), + wantError: field.ErrorList{ + field.Invalid( + field.NewPath("spec").Child("trainer").Child("env"), + []corev1.EnvVar{ + { + Name: "test", + Value: "value", + }, + { + Name: constants.TorchEnvNumProcPerNode, + Value: "value", + }, + }, + fmt.Sprintf("must not have reserved envs, invalid envs configured: %v", func() []string { + torchEnvs := sets.New[string]() + torchEnvs.Insert(constants.TorchEnvNumProcPerNode) + return sets.List(torchEnvs) + }()), + ), + }, }, "unsupported pretrained model": { info: runtime.NewInfo( From ccbbefbad992beb7b97cf2f9cb8766100fefd95c Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Wed, 16 Apr 2025 03:42:04 +0000 Subject: [PATCH 19/26] fix(plugin): update comments. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/runtime/framework/plugins/torch/torch.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index 1126fd7d3d..0bb1c054de 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -71,7 +71,7 @@ func (t *Torch) Validate(runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob) } } - // Check reserved envs for torchrun. + // Check reserved envs. torchEnvs := sets.New[string]() for _, env := range newObj.Spec.Trainer.Env { if constants.TorchRunReservedEnvNames.Has(env.Name) { From d5d347afb893753029e177f11bc7faa442473cbb Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Tue, 22 Apr 2025 02:46:12 +0000 Subject: [PATCH 20/26] fix(plugins): fix the implementation according to the review. 
Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/apply/apply.go | 4 + pkg/constants/constants.go | 3 - pkg/runtime/core/trainingruntime.go | 4 + pkg/runtime/framework/plugins/torch/torch.go | 33 +++----- .../framework/plugins/torch/torch_test.go | 82 ++----------------- pkg/runtime/runtime.go | 3 +- 6 files changed, 31 insertions(+), 98 deletions(-) diff --git a/pkg/apply/apply.go b/pkg/apply/apply.go index 4f4655fcb1..ebeed46ba2 100644 --- a/pkg/apply/apply.go +++ b/pkg/apply/apply.go @@ -34,6 +34,10 @@ var ( errorRequestedFieldPathNotFound = errors.New("requested field path not found") ) +func UpsertCommand(command *[]string, cmd ...string) { + *command = append(*command, cmd...) +} + func UpsertEnvVar(envVars *[]corev1ac.EnvVarApplyConfiguration, envVar ...corev1ac.EnvVarApplyConfiguration) { for _, e := range envVar { upsert(envVars, e, byEnvVarName) diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index e00792a40d..e66106a05a 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -152,9 +152,6 @@ const ( // TorchTuneFullFinetuneMultiNodesConfigSuffix is the config suffix for the multi node distributed full finetune. TorchTuneFullFinetuneMultiNodesConfigSuffix string = "_full_multinode" - - // TorchTuneDefaultRecipe is the default recipe for the torchtune. - TorchTuneDefaultRecipe string = TorchTuneFullFinetuneDistributed ) const ( diff --git a/pkg/runtime/core/trainingruntime.go b/pkg/runtime/core/trainingruntime.go index e118d89729..fac68c452f 100644 --- a/pkg/runtime/core/trainingruntime.go +++ b/pkg/runtime/core/trainingruntime.go @@ -183,6 +183,10 @@ func syncPodSets(info *runtime.Info) { } apply.UpsertVolumes(&jsSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Volumes, ps.Volumes...) 
for containerIdx, container := range ps.Containers { + apply.UpsertCommand( + &jsSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Containers[containerIdx].Command, + container.Command..., + ) apply.UpsertEnvVar( &jsSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Containers[containerIdx].Env, container.Env..., diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index 0bb1c054de..ee511823a3 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -185,28 +185,23 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) WithValue(fmt.Sprintf("%d", constants.ContainerTrainerPort)), ) } else { - // Mutate command line args for torchtune. + // Mutate trainer command for torchtune. // Ref: https://github.com/kubeflow/trainer/tree/master/docs/proposals/2401-llm-trainer-v2#complement-torch-plugin - oldArgs, newArgs := trainJob.Spec.Trainer.Args, []string{} - // 1. Add rendezvous backend arg for torchtune. - newArgs = append(newArgs, + var newCommand []string + newCommand = append(newCommand, fmt.Sprintf("%s %s-%s-0-0.%s:%d", constants.TorchTuneArgRdzvEndpoint, trainJob.Name, constants.Node, trainJob.Name, constants.ContainerTrainerPort, ), ) - // 2. Get the recipe and config from old args and append them to new args. + // 2. Get the recipe and config from old args and append them to newCommand. numNodes := ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1) - recipe := getRecipeFromArgs(numNodes, numProcPerNode, oldArgs) - config := getConfigFileFromArgs(numNodes, recipe, trainJob.Spec.RuntimeRef.Name) - newArgs = append(newArgs, recipe, fmt.Sprintf("--config %s", config)) - - // 3. Reserve old arguments to override corresponding items in the config file. - newArgs = append(newArgs, oldArgs...) 
+ recipe, config := getRecipeAndConfig(numNodes, numProcPerNode, trainJob.Spec.RuntimeRef.Name, trainJob.Spec.Trainer.Args) + newCommand = append(newCommand, recipe, fmt.Sprintf("--config %s", config)) - trainerContainer.Args = newArgs + trainerContainer.Command = newCommand } // Add container port for the headless service. apply.UpsertPort(&trainerContainer.Ports, *corev1ac.ContainerPort().WithContainerPort(constants.ContainerTrainerPort)) @@ -231,18 +226,14 @@ func calculateNumProcPerNode( return intstr.FromInt32(defaultCPU), false } -// getRecipeFromArgs extracts the recipe from the distributed parameters and command line arguments. -// TODO(Electronic-Waste): Add support for more recipes. -func getRecipeFromArgs(numNodes int32, numProcPerNode intstr.IntOrString, _ []string) string { - recipe := constants.TorchTuneDefaultRecipe +// getRecipeAndConfig returns the recipe and config file name based on the number of nodes, +// number of processes per node, runtime reference name, and command line arguments. +func getRecipeAndConfig(numNodes int32, numProcPerNode intstr.IntOrString, runtimeRefName string, _ []string) (string, string) { + recipe := constants.TorchTuneFullFinetuneDistributed if numNodes == 1 && numProcPerNode.Type == intstr.Int && numProcPerNode.IntVal == 1 { recipe = constants.TorchTuneFullFinetuneSingleDevice } - return recipe -} -// getConfigFromArgs extracts the config from distributed parameters, recipe and runtime reference name. -func getConfigFileFromArgs(numNodes int32, recipe, runtimeRefName string) string { // Determine the config file name based on the recipe and number of nodes. 
var suffix string switch recipe { @@ -256,7 +247,7 @@ func getConfigFileFromArgs(numNodes int32, recipe, runtimeRefName string) string suffix = constants.TorchTuneFullFinetuneSingleDeviceConfigSuffix } - return fmt.Sprintf("%s%s.yaml", getModelFromRuntimeRef(runtimeRefName), suffix) + return recipe, fmt.Sprintf("%s%s.yaml", getModelFromRuntimeRef(runtimeRefName), suffix) } func getModelFromRuntimeRef(runtimeRefName string) string { diff --git a/pkg/runtime/framework/plugins/torch/torch_test.go b/pkg/runtime/framework/plugins/torch/torch_test.go index 15c3b6c1a2..d127bb9f49 100644 --- a/pkg/runtime/framework/plugins/torch/torch_test.go +++ b/pkg/runtime/framework/plugins/torch/torch_test.go @@ -1183,14 +1183,12 @@ func TestTorch(t *testing.T) { SinglePodRequests: make(corev1.ResourceList), Containers: []runtime.Container{{ Name: constants.Node, - Args: []string{ + Command: []string{ + "tune", + "run", fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), constants.TorchTuneFullFinetuneDistributed, "--config llama3_2/1B_full.yaml", - "dtype=fp16", - "batch_size=32", - "epochs=10", - "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", }, Env: []corev1ac.EnvVarApplyConfiguration{ { @@ -1276,14 +1274,12 @@ func TestTorch(t *testing.T) { SinglePodRequests: make(corev1.ResourceList), Containers: []runtime.Container{{ Name: constants.Node, - Args: []string{ + Command: []string{ + "tune", + "run", fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), constants.TorchTuneFullFinetuneSingleDevice, "--config llama3_2/1B_full_single_device.yaml", - "dtype=fp16", - "batch_size=32", - "epochs=10", - "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", }, Env: []corev1ac.EnvVarApplyConfiguration{ { @@ -1369,14 +1365,12 @@ func TestTorch(t *testing.T) { SinglePodRequests: make(corev1.ResourceList), Containers: []runtime.Container{{ Name: constants.Node, - Args: []string{ + 
Command: []string{ + "tune", + "run", fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), constants.TorchTuneFullFinetuneDistributed, "--config llama3_3/70B_full_multinode.yaml", - "dtype=fp16", - "batch_size=32", - "epochs=10", - "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", }, Env: []corev1ac.EnvVarApplyConfiguration{ { @@ -1606,64 +1600,6 @@ func TestValidate(t *testing.T) { ), }, }, - "reserved environment variable for torchtune": { - info: runtime.NewInfo( - runtime.WithMLPolicySource(utiltesting.MakeMLPolicyWrapper(). - WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper(). - TorchPolicy(ptr.To(intstr.FromString("auto")), nil). - Obj(), - ). - Obj(), - ), - ), - newObj: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test"). - Trainer(utiltesting.MakeTrainJobTrainerWrapper(). - NumProcPerNode(intstr.FromString("auto")). - Container( - "ghcr.io/kubeflow/trainer/torchtune-trainer", - []string{"tune", "run"}, - nil, corev1.ResourceList{}, - ). - Env( - []corev1.EnvVar{ - { - Name: "test", - Value: "value", - }, - { - Name: constants.TorchEnvNumProcPerNode, - Value: "value", - }, - }..., - ). - Obj(), - ). - RuntimeRef( - trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), - "torchtune-llama3.2-1b", - ). - Obj(), - wantError: field.ErrorList{ - field.Invalid( - field.NewPath("spec").Child("trainer").Child("env"), - []corev1.EnvVar{ - { - Name: "test", - Value: "value", - }, - { - Name: constants.TorchEnvNumProcPerNode, - Value: "value", - }, - }, - fmt.Sprintf("must not have reserved envs, invalid envs configured: %v", func() []string { - torchEnvs := sets.New[string]() - torchEnvs.Insert(constants.TorchEnvNumProcPerNode) - return sets.List(torchEnvs) - }()), - ), - }, - }, "unsupported pretrained model": { info: runtime.NewInfo( runtime.WithMLPolicySource(utiltesting.MakeMLPolicyWrapper(). 
diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 80af27a324..867bbaa3e6 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -80,7 +80,7 @@ type PodSet struct { type Container struct { Name string - Args []string + Command []string Env []corev1ac.EnvVarApplyConfiguration Ports []corev1ac.ContainerPortApplyConfiguration VolumeMounts []corev1ac.VolumeMountApplyConfiguration @@ -158,6 +158,7 @@ func toPodSetContainer(containerApply ...corev1ac.ContainerApplyConfiguration) i for _, cApply := range containerApply { container := Container{ Name: ptr.Deref(cApply.Name, ""), + Command: cApply.Command, Env: cApply.Env, Ports: cApply.Ports, VolumeMounts: cApply.VolumeMounts, From 79816eb3f6b8479ce3d5a960ab5a122cb0eee269 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Tue, 22 Apr 2025 02:53:52 +0000 Subject: [PATCH 21/26] test(plugins): fix UT error in torch plugin. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/runtime/framework/plugins/torch/torch_test.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pkg/runtime/framework/plugins/torch/torch_test.go b/pkg/runtime/framework/plugins/torch/torch_test.go index d127bb9f49..1523e62b21 100644 --- a/pkg/runtime/framework/plugins/torch/torch_test.go +++ b/pkg/runtime/framework/plugins/torch/torch_test.go @@ -1184,8 +1184,6 @@ func TestTorch(t *testing.T) { Containers: []runtime.Container{{ Name: constants.Node, Command: []string{ - "tune", - "run", fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), constants.TorchTuneFullFinetuneDistributed, "--config llama3_2/1B_full.yaml", @@ -1275,8 +1273,6 @@ func TestTorch(t *testing.T) { Containers: []runtime.Container{{ Name: constants.Node, Command: []string{ - "tune", - "run", fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), constants.TorchTuneFullFinetuneSingleDevice, "--config 
llama3_2/1B_full_single_device.yaml", @@ -1366,8 +1362,6 @@ func TestTorch(t *testing.T) { Containers: []runtime.Container{{ Name: constants.Node, Command: []string{ - "tune", - "run", fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), constants.TorchTuneFullFinetuneDistributed, "--config llama3_3/70B_full_multinode.yaml", From 71b4b5b2d546700197863319f37f6eec90349bd6 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Sun, 27 Apr 2025 12:18:31 +0000 Subject: [PATCH 22/26] fix: fix UT and e2e tests error. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/apply/apply.go | 4 - pkg/runtime/core/trainingruntime.go | 4 - pkg/runtime/core/trainingruntime_test.go | 101 ++++++++++++++++++ pkg/runtime/framework/plugins/torch/torch.go | 17 +-- .../framework/plugins/torch/torch_test.go | 15 --- pkg/runtime/runtime.go | 4 +- 6 files changed, 106 insertions(+), 39 deletions(-) diff --git a/pkg/apply/apply.go b/pkg/apply/apply.go index ebeed46ba2..4f4655fcb1 100644 --- a/pkg/apply/apply.go +++ b/pkg/apply/apply.go @@ -34,10 +34,6 @@ var ( errorRequestedFieldPathNotFound = errors.New("requested field path not found") ) -func UpsertCommand(command *[]string, cmd ...string) { - *command = append(*command, cmd...) -} - func UpsertEnvVar(envVars *[]corev1ac.EnvVarApplyConfiguration, envVar ...corev1ac.EnvVarApplyConfiguration) { for _, e := range envVar { upsert(envVars, e, byEnvVarName) diff --git a/pkg/runtime/core/trainingruntime.go b/pkg/runtime/core/trainingruntime.go index fac68c452f..e118d89729 100644 --- a/pkg/runtime/core/trainingruntime.go +++ b/pkg/runtime/core/trainingruntime.go @@ -183,10 +183,6 @@ func syncPodSets(info *runtime.Info) { } apply.UpsertVolumes(&jsSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Volumes, ps.Volumes...) 
for containerIdx, container := range ps.Containers { - apply.UpsertCommand( - &jsSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Containers[containerIdx].Command, - container.Command..., - ) apply.UpsertEnvVar( &jsSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Containers[containerIdx].Env, container.Env..., diff --git a/pkg/runtime/core/trainingruntime_test.go b/pkg/runtime/core/trainingruntime_test.go index b486ef8148..b01cebb202 100644 --- a/pkg/runtime/core/trainingruntime_test.go +++ b/pkg/runtime/core/trainingruntime_test.go @@ -486,6 +486,106 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { Obj(), }, }, + "succeeded to build JobSet with TorchTune values from the TrainJob": { + trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "torchtune-llama3.3-70b").RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "torchtune-llama3.3-70b").Spec). + WithMLPolicy( + testingutil.MakeMLPolicyWrapper(). + WithNumNodes(100). + WithMLPolicySource(*testingutil.MakeMLPolicySourceWrapper(). + TorchPolicy(ptr.To(intstr.FromString("auto")), nil). + Obj(), + ). + Obj(), + ). + JobSetSpec( + testingutil.MakeJobSetWrapper("", ""). + DependsOn(constants.Node, + []jobsetv1alpha2.DependsOn{ + { + Name: constants.DatasetInitializer, + Status: jobsetv1alpha2.DependencyComplete, + }, + { + Name: constants.ModelInitializer, + Status: jobsetv1alpha2.DependencyComplete, + }, + }..., + ). + Obj(). + Spec, + ). + Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests). + Obj(), + ).Obj(), + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). + UID("uid"). + RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "torchtune-llama3.3-70b"). + Trainer( + testingutil.MakeTrainJobTrainerWrapper(). 
+ Container("test:trainjob", []string{"tune", "run"}, []string{"runtime"}, resRequests). + NumNodes(30). + NumProcPerNode(intstr.FromInt32(3)). + Obj(), + ). + Obj(), + wantObjs: []runtime.Object{ + testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). + ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid"). + Replicas(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Node, constants.Launcher). + Parallelism(1, constants.DatasetInitializer, constants.ModelInitializer). + Completions(1, constants.DatasetInitializer, constants.ModelInitializer). + NumNodes(30). + Container( + constants.Node, + constants.Node, + "test:trainjob", + []string{ + "tune", + "run", + fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "test-job-node-0-0.test-job:29500"), + constants.TorchTuneFullFinetuneDistributed, + "--config llama3_3/70B_full_multinode.yaml", + }, + []string{"runtime"}, + resRequests, + ). + ContainerTrainerPorts([]corev1.ContainerPort{{ContainerPort: constants.ContainerTrainerPort}}). + Env(constants.Node, constants.Node, + []corev1.EnvVar{ + { + Name: constants.TorchEnvNumNodes, + Value: "30", + }, + { + Name: constants.TorchEnvNumProcPerNode, + Value: "3", + }, + { + Name: constants.TorchEnvNodeRank, + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: constants.JobCompletionIndexFieldPath, + }, + }, + }, + }..., + ). + DependsOn(constants.Node, + []jobsetv1alpha2.DependsOn{ + { + Name: constants.DatasetInitializer, + Status: jobsetv1alpha2.DependencyComplete, + }, + { + Name: constants.ModelInitializer, + Status: jobsetv1alpha2.DependencyComplete, + }, + }..., + ). 
+ Obj(), + }, + }, "succeeded to build JobSet with OpenMPI values from the TrainJob": { ObjCmpOpts: cmp.Options{ cmp.Comparer(testingutil.MPISecretDataComparer), @@ -658,6 +758,7 @@ test-job-node-0-1.test-job slots=8 } for name, tc := range cases { t.Run(name, func(t *testing.T) { + fmt.Println(name) _, ctx := ktesting.NewTestContext(t) var cancel func() ctx, cancel = context.WithCancel(ctx) diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index ee511823a3..a2e7e9aef7 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -201,7 +201,7 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) recipe, config := getRecipeAndConfig(numNodes, numProcPerNode, trainJob.Spec.RuntimeRef.Name, trainJob.Spec.Trainer.Args) newCommand = append(newCommand, recipe, fmt.Sprintf("--config %s", config)) - trainerContainer.Command = newCommand + trainJob.Spec.Trainer.Command = append(trainJob.Spec.Trainer.Command, newCommand...) } // Add container port for the headless service. apply.UpsertPort(&trainerContainer.Ports, *corev1ac.ContainerPort().WithContainerPort(constants.ContainerTrainerPort)) @@ -230,21 +230,12 @@ func calculateNumProcPerNode( // number of processes per node, runtime reference name, and command line arguments. func getRecipeAndConfig(numNodes int32, numProcPerNode intstr.IntOrString, runtimeRefName string, _ []string) (string, string) { recipe := constants.TorchTuneFullFinetuneDistributed + suffix := constants.TorchTuneFullFinetuneMultiDevicesConfigSuffix if numNodes == 1 && numProcPerNode.Type == intstr.Int && numProcPerNode.IntVal == 1 { recipe = constants.TorchTuneFullFinetuneSingleDevice - } - - // Determine the config file name based on the recipe and number of nodes. 
- var suffix string - switch recipe { - case constants.TorchTuneFullFinetuneDistributed: - if numNodes == 1 { - suffix = constants.TorchTuneFullFinetuneMultiDevicesConfigSuffix - } else { - suffix = constants.TorchTuneFullFinetuneMultiNodesConfigSuffix - } - case constants.TorchTuneFullFinetuneSingleDevice: suffix = constants.TorchTuneFullFinetuneSingleDeviceConfigSuffix + } else if numNodes > 1 { + suffix = constants.TorchTuneFullFinetuneMultiNodesConfigSuffix } return recipe, fmt.Sprintf("%s%s.yaml", getModelFromRuntimeRef(runtimeRefName), suffix) diff --git a/pkg/runtime/framework/plugins/torch/torch_test.go b/pkg/runtime/framework/plugins/torch/torch_test.go index 1523e62b21..ade0e2590e 100644 --- a/pkg/runtime/framework/plugins/torch/torch_test.go +++ b/pkg/runtime/framework/plugins/torch/torch_test.go @@ -1183,11 +1183,6 @@ func TestTorch(t *testing.T) { SinglePodRequests: make(corev1.ResourceList), Containers: []runtime.Container{{ Name: constants.Node, - Command: []string{ - fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), - constants.TorchTuneFullFinetuneDistributed, - "--config llama3_2/1B_full.yaml", - }, Env: []corev1ac.EnvVarApplyConfiguration{ { Name: ptr.To(constants.TorchEnvNumNodes), @@ -1272,11 +1267,6 @@ func TestTorch(t *testing.T) { SinglePodRequests: make(corev1.ResourceList), Containers: []runtime.Container{{ Name: constants.Node, - Command: []string{ - fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "torchtune-job-node-0-0.torchtune-job:29500"), - constants.TorchTuneFullFinetuneSingleDevice, - "--config llama3_2/1B_full_single_device.yaml", - }, Env: []corev1ac.EnvVarApplyConfiguration{ { Name: ptr.To(constants.TorchEnvNumNodes), @@ -1361,11 +1351,6 @@ func TestTorch(t *testing.T) { SinglePodRequests: make(corev1.ResourceList), Containers: []runtime.Container{{ Name: constants.Node, - Command: []string{ - fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, 
"torchtune-job-node-0-0.torchtune-job:29500"), - constants.TorchTuneFullFinetuneDistributed, - "--config llama3_3/70B_full_multinode.yaml", - }, Env: []corev1ac.EnvVarApplyConfiguration{ { Name: ptr.To(constants.TorchEnvNumNodes), diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 867bbaa3e6..93fe301387 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -80,7 +80,6 @@ type PodSet struct { type Container struct { Name string - Command []string Env []corev1ac.EnvVarApplyConfiguration Ports []corev1ac.ContainerPortApplyConfiguration VolumeMounts []corev1ac.VolumeMountApplyConfiguration @@ -135,7 +134,7 @@ func WithTemplateSpecObjApply(objApply any) InfoOption { } // WithPodSet construct Info.TemplateSpec.PodSet from PodSpec. -// The third argument, 'typedPodSpec' is used only to calculate requested resources. +// The fourth argument, 'typedPodSpec' is used only to calculate requested resources. func WithPodSet( psName string, ancestor *string, count int32, typedPodSpec corev1.PodSpec, podSpecApply *corev1ac.PodSpecApplyConfiguration, ) InfoOption { @@ -158,7 +157,6 @@ func toPodSetContainer(containerApply ...corev1ac.ContainerApplyConfiguration) i for _, cApply := range containerApply { container := Container{ Name: ptr.Deref(cApply.Name, ""), - Command: cApply.Command, Env: cApply.Env, Ports: cApply.Ports, VolumeMounts: cApply.VolumeMounts, From 47321b79d154e1be6a7cf403a7a89b38e424b313 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Sun, 27 Apr 2025 12:21:57 +0000 Subject: [PATCH 23/26] fix: remove debug info. 
Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/runtime/core/trainingruntime_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/runtime/core/trainingruntime_test.go b/pkg/runtime/core/trainingruntime_test.go index b01cebb202..485cf5ac58 100644 --- a/pkg/runtime/core/trainingruntime_test.go +++ b/pkg/runtime/core/trainingruntime_test.go @@ -758,7 +758,6 @@ test-job-node-0-1.test-job slots=8 } for name, tc := range cases { t.Run(name, func(t *testing.T) { - fmt.Println(name) _, ctx := ktesting.NewTestContext(t) var cancel func() ctx, cancel = context.WithCancel(ctx) From 40f411995f64d8f4b02a25710b0b11962033f8ad Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Sun, 27 Apr 2025 12:33:21 +0000 Subject: [PATCH 24/26] fix(test): add args in UTs related to torchtune. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/runtime/core/trainingruntime_test.go | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/pkg/runtime/core/trainingruntime_test.go b/pkg/runtime/core/trainingruntime_test.go index 485cf5ac58..07aae1ea64 100644 --- a/pkg/runtime/core/trainingruntime_test.go +++ b/pkg/runtime/core/trainingruntime_test.go @@ -523,7 +523,19 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "torchtune-llama3.3-70b"). Trainer( testingutil.MakeTrainJobTrainerWrapper(). - Container("test:trainjob", []string{"tune", "run"}, []string{"runtime"}, resRequests). + Container( + "test:trainjob", + []string{ + "tune", + "run", + }, + []string{ + "dtype=fp16", + "batch_size=10", + "epochs=1", + "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", + }, + resRequests). NumNodes(30). NumProcPerNode(intstr.FromInt32(3)). 
Obj(), @@ -547,7 +559,12 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { constants.TorchTuneFullFinetuneDistributed, "--config llama3_3/70B_full_multinode.yaml", }, - []string{"runtime"}, + []string{ + "dtype=fp16", + "batch_size=10", + "epochs=1", + "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", + }, resRequests, ). ContainerTrainerPorts([]corev1.ContainerPort{{ContainerPort: constants.ContainerTrainerPort}}). From 52287a9365626e34fd51303cf1abcd176614b8e0 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Sun, 27 Apr 2025 12:37:01 +0000 Subject: [PATCH 25/26] fix(test): update torchtune related args. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/runtime/core/trainingruntime_test.go | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/pkg/runtime/core/trainingruntime_test.go b/pkg/runtime/core/trainingruntime_test.go index 07aae1ea64..5cad4a29d5 100644 --- a/pkg/runtime/core/trainingruntime_test.go +++ b/pkg/runtime/core/trainingruntime_test.go @@ -525,14 +525,11 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { testingutil.MakeTrainJobTrainerWrapper(). Container( "test:trainjob", - []string{ - "tune", - "run", - }, + []string{"tune", "run"}, []string{ "dtype=fp16", - "batch_size=10", - "epochs=1", + "batch_size=32", + "epochs=10", "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", }, resRequests). @@ -561,8 +558,8 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { }, []string{ "dtype=fp16", - "batch_size=10", - "epochs=1", + "batch_size=32", + "epochs=10", "loss=torchtune.modules.loss.CEWithChunkedOutputLoss", }, resRequests, From 87f68ff453f815baceb387b8da4d5417ce6da329 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Sun, 27 Apr 2025 13:28:03 +0000 Subject: [PATCH 26/26] fix(test): Add a UT for multi-node mode check in torch plugin. 
Signed-off-by: Electronic-Waste <2690692950@qq.com> --- pkg/runtime/framework/plugins/torch/torch.go | 6 ++++ .../framework/plugins/torch/torch_test.go | 35 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/pkg/runtime/framework/plugins/torch/torch.go b/pkg/runtime/framework/plugins/torch/torch.go index a2e7e9aef7..b1c13e38e3 100644 --- a/pkg/runtime/framework/plugins/torch/torch.go +++ b/pkg/runtime/framework/plugins/torch/torch.go @@ -93,6 +93,12 @@ func (t *Torch) Validate(runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob) if !constants.TorchTuneSupportedPretrainedModels.Has(model) { allErrs = append(allErrs, field.Invalid(runtimeRefNamePath, newObj.Spec.RuntimeRef.Name, fmt.Sprintf("must have a supported pretrained model, invalid model configured: %v", model))) } + + numNodesRefPath := specPath.Child("trainer").Child("numNodes") + numNodes := *newObj.Spec.Trainer.NumNodes + if numNodes > 1 && model != constants.TORCHTUNE_MODEL_LLAMA3_3_70B { + allErrs = append(allErrs, field.Invalid(numNodesRefPath, numNodes, fmt.Sprintf("must be 1 for %v model", model))) + } } } diff --git a/pkg/runtime/framework/plugins/torch/torch_test.go b/pkg/runtime/framework/plugins/torch/torch_test.go index ade0e2590e..f1811e5066 100644 --- a/pkg/runtime/framework/plugins/torch/torch_test.go +++ b/pkg/runtime/framework/plugins/torch/torch_test.go @@ -1592,6 +1592,7 @@ func TestValidate(t *testing.T) { newObj: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test"). Trainer(utiltesting.MakeTrainJobTrainerWrapper(). NumProcPerNode(intstr.FromString("auto")). + NumNodes(int32(1)). Container( "ghcr.io/kubeflow/trainer/torchtune-trainer", []string{"tune", "run"}, @@ -1612,6 +1613,40 @@ func TestValidate(t *testing.T) { ), }, }, + "multi-node mode is only enabled for Llama-3.3-70b": { + info: runtime.NewInfo( + runtime.WithMLPolicySource(utiltesting.MakeMLPolicyWrapper(). + WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper(). 
+ TorchPolicy(ptr.To(intstr.FromString("auto")), nil). + Obj(), + ). + Obj(), + ), + ), + newObj: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test"). + Trainer(utiltesting.MakeTrainJobTrainerWrapper(). + NumProcPerNode(intstr.FromString("auto")). + NumNodes(int32(2)). + Container( + "ghcr.io/kubeflow/trainer/torchtune-trainer", + []string{"tune", "run"}, + nil, corev1.ResourceList{}, + ). + Obj(), + ). + RuntimeRef( + trainer.SchemeGroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), + "torchtune-llama3.2-7b", + ). + Obj(), + wantError: field.ErrorList{ + field.Invalid( + field.NewPath("spec").Child("trainer").Child("numNodes"), + int32(2), + fmt.Sprintf("must be 1 for %v model", "llama3_2/7B"), + ), + }, + }, } for name, tc := range cases { t.Run(name, func(t *testing.T) {