Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
7babe4e
chore(plugin): Add torchtune-related constants & update current torch…
Electronic-Waste Apr 8, 2025
e6a9c54
chore(plugin): Add EnforceMLPolicy for torchtune.
Electronic-Waste Apr 8, 2025
afa49c4
chore(plugin): Add UTs in torch plugin.
Electronic-Waste Apr 8, 2025
ada86f8
fix(test): fix error in torch plugin UTs.
Electronic-Waste Apr 8, 2025
c9e340e
chore(plugin): Choose recipe according to numNodes & numProcPerNode &…
Electronic-Waste Apr 9, 2025
8ecaa5b
chore(sdk): Add PretrainedModel enum type.
Electronic-Waste Apr 9, 2025
06b555d
chore(plugin): Add torchtune config arg.
Electronic-Waste Apr 9, 2025
ba55d4c
chore(test): add UT for single-device full fine-tuning with torchtune.
Electronic-Waste Apr 9, 2025
206822e
chore(test): Add test for multi-nodes full fine-tuning with torchtune.
Electronic-Waste Apr 9, 2025
a29dddc
chore(test): Update torch validate UTs.
Electronic-Waste Apr 9, 2025
a9f993a
fix(lint): fix lint error.
Electronic-Waste Apr 9, 2025
7eee7a3
fix(sdk): remove pretrained model enum type in sdk.
Electronic-Waste Apr 10, 2025
7236732
fix(plugin): retrieve model name from runtimeRef.
Electronic-Waste Apr 10, 2025
cabd145
fix(lint): fix typo.
Electronic-Waste Apr 10, 2025
824cb25
fix(plugin): make some adjustments according to the review.
Electronic-Waste Apr 11, 2025
0d1d3a2
fix(sdk): remove runtime in get_trainer_crd_from_builtin_trainer.
Electronic-Waste Apr 12, 2025
d9c8b7c
fix(plugin): pass PET_ env variables in torch plugin for torchtune.
Electronic-Waste Apr 16, 2025
3b22cf7
fix(plugin): add env validation for torchtune.
Electronic-Waste Apr 16, 2025
ccbbefb
fix(plugin): update comments.
Electronic-Waste Apr 16, 2025
d5d347a
fix(plugins): fix the implementation according to the review.
Electronic-Waste Apr 22, 2025
79816eb
test(plugins): fix UT error in torch plugin.
Electronic-Waste Apr 22, 2025
71b4b5b
fix: fix UT and e2e tests error.
Electronic-Waste Apr 27, 2025
47321b7
fix: remove debug info.
Electronic-Waste Apr 27, 2025
40f4119
fix(test): add args in UTs related to torchtune.
Electronic-Waste Apr 27, 2025
52287a9
fix(test): update torchtune related args.
Electronic-Waste Apr 27, 2025
87f68ff
fix(test): Add a UT for multi-node mode check in torch plugin.
Electronic-Waste Apr 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,35 @@ const (

// TorchEnvMasterPort is the env name for the master node port.
TorchEnvMasterPort string = "PET_MASTER_PORT"

// TorchTuneArgRdzvEndpoint is the arg name for the rendezvous endpoint.
TorchTuneArgRdzvEndpoint string = "--rdzv_endpoint"

// TorchTuneFullFinetuneSingleDevice Recipe is the recipe for the single device full finetune.
TorchTuneFullFinetuneSingleDevice string = "full_finetune_single_device"
Comment thread
Electronic-Waste marked this conversation as resolved.

// TorchTuneFullFinetuneSingleDeviceConfigSuffix is the config suffix for the single device full finetune.
TorchTuneFullFinetuneSingleDeviceConfigSuffix string = "_full_single_device"

// TorchTuneFullFinetuneDistributed Recipe is the recipe for the distributed full finetune.
TorchTuneFullFinetuneDistributed string = "full_finetune_distributed"
Comment thread
Electronic-Waste marked this conversation as resolved.

// TorchTuneFullFinetuneMultiDevicesConfigSuffix is the config suffix for the single node distributed full finetune.
TorchTuneFullFinetuneMultiDevicesConfigSuffix string = "_full"

// TorchTuneFullFinetuneMultiNodesConfigSuffix is the config suffix for the multi node distributed full finetune.
TorchTuneFullFinetuneMultiNodesConfigSuffix string = "_full_multinode"
Comment thread
Electronic-Waste marked this conversation as resolved.
)

const (
// TORCHTUNE_MODEL_LLAMA3_2_1B is the model name for the Llama3.2 1B Instruct model.
TORCHTUNE_MODEL_LLAMA3_2_1B = "llama3_2/1B"
Comment thread
Electronic-Waste marked this conversation as resolved.

// TORCHTUNE_MODEL_LLAMA3_2_7B is the model name for the Llama3.2 7B Instruct model.
TORCHTUNE_MODEL_LLAMA3_2_7B = "llama3_2/7B"

// TORCHTUNE_MODEL_LLAMA3_3_70B is the model name for the Llama3.3 70B Instruct model.
TORCHTUNE_MODEL_LLAMA3_3_70B = "llama3_3/70B"
)

var (
Expand All @@ -142,4 +171,10 @@ var (

// Torchrun reserved env names
TorchRunReservedEnvNames = sets.New(TorchEnvNumNodes, TorchEnvNumProcPerNode, TorchEnvNodeRank, TorchEnvMasterAddr, TorchEnvMasterPort)

// Currently supported pretrained models for TorchTune Trainer.
TorchTuneSupportedPretrainedModels = sets.New(TORCHTUNE_MODEL_LLAMA3_2_1B, TORCHTUNE_MODEL_LLAMA3_2_7B, TORCHTUNE_MODEL_LLAMA3_3_70B)

// TorchTuneEntrypoint is the entrypoint for the torchtune.
TorchTuneEntrypoint = []string{"tune", "run"}
)
114 changes: 114 additions & 0 deletions pkg/runtime/core/trainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,120 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Obj(),
},
},
"succeeded to build JobSet with TorchTune values from the TrainJob": {
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "torchtune-llama3.3-70b").RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "torchtune-llama3.3-70b").Spec).
WithMLPolicy(
testingutil.MakeMLPolicyWrapper().
WithNumNodes(100).
WithMLPolicySource(*testingutil.MakeMLPolicySourceWrapper().
TorchPolicy(ptr.To(intstr.FromString("auto")), nil).
Obj(),
).
Obj(),
).
JobSetSpec(
testingutil.MakeJobSetWrapper("", "").
DependsOn(constants.Node,
[]jobsetv1alpha2.DependsOn{
{
Name: constants.DatasetInitializer,
Status: jobsetv1alpha2.DependencyComplete,
},
{
Name: constants.ModelInitializer,
Status: jobsetv1alpha2.DependencyComplete,
},
}...,
).
Obj().
Spec,
).
Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Obj(),
).Obj(),
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("uid").
RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "torchtune-llama3.3-70b").
Trainer(
testingutil.MakeTrainJobTrainerWrapper().
Container(
"test:trainjob",
[]string{"tune", "run"},
[]string{
"dtype=fp16",
"batch_size=32",
"epochs=10",
"loss=torchtune.modules.loss.CEWithChunkedOutputLoss",
},
resRequests).
NumNodes(30).
NumProcPerNode(intstr.FromInt32(3)).
Obj(),
).
Obj(),
wantObjs: []runtime.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
Replicas(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Node, constants.Launcher).
Parallelism(1, constants.DatasetInitializer, constants.ModelInitializer).
Completions(1, constants.DatasetInitializer, constants.ModelInitializer).
NumNodes(30).
Container(
constants.Node,
constants.Node,
"test:trainjob",
[]string{
"tune",
"run",
fmt.Sprintf("%s %s", constants.TorchTuneArgRdzvEndpoint, "test-job-node-0-0.test-job:29500"),
constants.TorchTuneFullFinetuneDistributed,
"--config llama3_3/70B_full_multinode.yaml",
},
[]string{
"dtype=fp16",
"batch_size=32",
"epochs=10",
"loss=torchtune.modules.loss.CEWithChunkedOutputLoss",
},
resRequests,
).
ContainerTrainerPorts([]corev1.ContainerPort{{ContainerPort: constants.ContainerTrainerPort}}).
Env(constants.Node, constants.Node,
[]corev1.EnvVar{
{
Name: constants.TorchEnvNumNodes,
Value: "30",
},
{
Name: constants.TorchEnvNumProcPerNode,
Value: "3",
},
{
Name: constants.TorchEnvNodeRank,
ValueFrom: &corev1.EnvVarSource{
FieldRef: &corev1.ObjectFieldSelector{
FieldPath: constants.JobCompletionIndexFieldPath,
},
},
},
}...,
).
DependsOn(constants.Node,
[]jobsetv1alpha2.DependsOn{
{
Name: constants.DatasetInitializer,
Status: jobsetv1alpha2.DependencyComplete,
},
{
Name: constants.ModelInitializer,
Status: jobsetv1alpha2.DependencyComplete,
},
}...,
).
Obj(),
},
},
"succeeded to build JobSet with OpenMPI values from the TrainJob": {
ObjCmpOpts: cmp.Options{
cmp.Comparer(testingutil.MPISecretDataComparer),
Expand Down
85 changes: 76 additions & 9 deletions pkg/runtime/framework/plugins/torch/torch.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package torch
import (
"context"
"fmt"
"slices"
"strings"

corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -70,6 +71,7 @@ func (t *Torch) Validate(runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob)
}
}

// Check reserved envs.
Comment thread
Electronic-Waste marked this conversation as resolved.
torchEnvs := sets.New[string]()
for _, env := range newObj.Spec.Trainer.Env {
if constants.TorchRunReservedEnvNames.Has(env.Name) {
Expand All @@ -81,6 +83,23 @@ func (t *Torch) Validate(runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob)
trainerEnvsPath := specPath.Child("trainer").Child("env")
allErrs = append(allErrs, field.Invalid(trainerEnvsPath, newObj.Spec.Trainer.Env, fmt.Sprintf("must not have reserved envs, invalid envs configured: %v", sets.List(torchEnvs))))
}

// Check supported pretrained models for torchtune.
// TODO(Electronic-Waste): Add more validation for torchtune when we support more arguments.
if slices.Equal(newObj.Spec.Trainer.Command, constants.TorchTuneEntrypoint) {
runtimeRefNamePath := specPath.Child("runtimeRef").Child("name")
model := getModelFromRuntimeRef(newObj.Spec.RuntimeRef.Name)

if !constants.TorchTuneSupportedPretrainedModels.Has(model) {
allErrs = append(allErrs, field.Invalid(runtimeRefNamePath, newObj.Spec.RuntimeRef.Name, fmt.Sprintf("must have a supported pretrained model, invalid model configured: %v", model)))
}

numNodesRefPath := specPath.Child("trainer").Child("numNodes")
numNodes := *newObj.Spec.Trainer.NumNodes
if numNodes > 1 && model != constants.TORCHTUNE_MODEL_LLAMA3_3_70B {
allErrs = append(allErrs, field.Invalid(numNodesRefPath, numNodes, fmt.Sprintf("must be 1 for %v model", model)))
}
}
}

return nil, allErrs
Expand Down Expand Up @@ -137,16 +156,16 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob)
}

// Update envs for Info object.
// Add PyTorch distributed "PET_" values for torchrun
// TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other.
// Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940
var trainerContainer *runtime.Container
if trainJob.Spec.Trainer != nil {
if trainerContainer = info.FindContainerByPodSetAncestorContainerName(constants.AncestorTrainer, constants.Node); trainerContainer != nil {
apply.UpsertEnvVars(&trainerContainer.Env, apply.EnvVars(trainJob.Spec.Trainer.Env...)...)
}
}
if trainerContainer != nil {
// Add PyTorch distributed "PET_" values for torchrun and torchtune.
// TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other.
// Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940
apply.UpsertEnvVar(&trainerContainer.Env,
*corev1ac.EnvVar().
WithName(constants.TorchEnvNumNodes).
Expand All @@ -159,13 +178,37 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob)
WithValueFrom(corev1ac.EnvVarSource().
WithFieldRef(corev1ac.ObjectFieldSelector().
WithFieldPath(constants.JobCompletionIndexFieldPath))),
*corev1ac.EnvVar().
WithName(constants.TorchEnvMasterAddr).
WithValue(fmt.Sprintf("%s-%s-0-0.%s", trainJob.Name, constants.Node, trainJob.Name)),
*corev1ac.EnvVar().
WithName(constants.TorchEnvMasterPort).
WithValue(fmt.Sprintf("%d", constants.ContainerTrainerPort)),
)

if !slices.Equal(trainJob.Spec.Trainer.Command, constants.TorchTuneEntrypoint) {
// Add PET_MASTER_ADDR and PET_MASTER_PORT envs for torchrun.
apply.UpsertEnvVar(&trainerContainer.Env,
*corev1ac.EnvVar().
WithName(constants.TorchEnvMasterAddr).
WithValue(fmt.Sprintf("%s-%s-0-0.%s", trainJob.Name, constants.Node, trainJob.Name)),
*corev1ac.EnvVar().
WithName(constants.TorchEnvMasterPort).
WithValue(fmt.Sprintf("%d", constants.ContainerTrainerPort)),
)
} else {
// Mutate trainer command for torchtune.
// Ref: https://github.com/kubeflow/trainer/tree/master/docs/proposals/2401-llm-trainer-v2#complement-torch-plugin
// 1. Add rendezvous backend arg for torchtune.
var newCommand []string
newCommand = append(newCommand,
fmt.Sprintf("%s %s-%s-0-0.%s:%d",
constants.TorchTuneArgRdzvEndpoint,
trainJob.Name, constants.Node, trainJob.Name, constants.ContainerTrainerPort,
),
)

// 2. Get the recipe and config from old args and append them to newCommand.
numNodes := ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1)
recipe, config := getRecipeAndConfig(numNodes, numProcPerNode, trainJob.Spec.RuntimeRef.Name, trainJob.Spec.Trainer.Args)
newCommand = append(newCommand, recipe, fmt.Sprintf("--config %s", config))

trainJob.Spec.Trainer.Command = append(trainJob.Spec.Trainer.Command, newCommand...)
}
// Add container port for the headless service.
apply.UpsertPort(&trainerContainer.Ports, *corev1ac.ContainerPort().WithContainerPort(constants.ContainerTrainerPort))
}
Expand All @@ -188,3 +231,27 @@ func calculateNumProcPerNode(
}
return intstr.FromInt32(defaultCPU), false
}

// getRecipeAndConfig returns the recipe and config file name based on the number of nodes,
// number of processes per node, runtime reference name, and command line arguments.
func getRecipeAndConfig(numNodes int32, numProcPerNode intstr.IntOrString, runtimeRefName string, _ []string) (string, string) {
recipe := constants.TorchTuneFullFinetuneDistributed
suffix := constants.TorchTuneFullFinetuneMultiDevicesConfigSuffix
if numNodes == 1 && numProcPerNode.Type == intstr.Int && numProcPerNode.IntVal == 1 {
recipe = constants.TorchTuneFullFinetuneSingleDevice
suffix = constants.TorchTuneFullFinetuneSingleDeviceConfigSuffix
} else if numNodes > 1 {
suffix = constants.TorchTuneFullFinetuneMultiNodesConfigSuffix
}

return recipe, fmt.Sprintf("%s%s.yaml", getModelFromRuntimeRef(runtimeRefName), suffix)
}

func getModelFromRuntimeRef(runtimeRefName string) string {
fields := strings.Split(runtimeRefName, "-")
if len(fields) != 3 {
Comment thread
Electronic-Waste marked this conversation as resolved.
return ""
}

return fmt.Sprintf("%s/%s", strings.ReplaceAll(fields[1], ".", "_"), strings.ToUpper(fields[2]))
}
Loading