Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
7babe4e
chore(plugin): Add torchtune-related constants & update current torch…
Electronic-Waste Apr 8, 2025
e6a9c54
chore(plugin): Add EnforceMLPolicy for torchtune.
Electronic-Waste Apr 8, 2025
afa49c4
chore(plugin): Add UTs in torch plugin.
Electronic-Waste Apr 8, 2025
ada86f8
fix(test): fix error in torch plugin UTs.
Electronic-Waste Apr 8, 2025
c9e340e
chore(plugin): Choose recipe according to numNodes & numProcPerNode &…
Electronic-Waste Apr 9, 2025
8ecaa5b
chore(sdk): Add PretrainedModel enum type.
Electronic-Waste Apr 9, 2025
06b555d
chore(plugin): Add torchtune config arg.
Electronic-Waste Apr 9, 2025
ba55d4c
chore(test): add UT for single-device full fine-tuning with torchtune.
Electronic-Waste Apr 9, 2025
206822e
chore(test): Add test for multi-nodes full fine-tuning with torchtune.
Electronic-Waste Apr 9, 2025
a29dddc
chore(test): Update torch validate UTs.
Electronic-Waste Apr 9, 2025
a9f993a
fix(lint): fix lint error.
Electronic-Waste Apr 9, 2025
7eee7a3
fix(sdk): remove pretrained model enum type in sdk.
Electronic-Waste Apr 10, 2025
7236732
fix(plugin): retrieve model name from runtimeRef.
Electronic-Waste Apr 10, 2025
cabd145
fix(lint): fix typo.
Electronic-Waste Apr 10, 2025
824cb25
fix(plugin): make some adjustments according to the review.
Electronic-Waste Apr 11, 2025
0d1d3a2
fix(sdk): remove runtime in get_trainer_crd_from_builtin_trainer.
Electronic-Waste Apr 12, 2025
d9c8b7c
fix(plugin): pass PET_ env variables in torch plugin for torchtune.
Electronic-Waste Apr 16, 2025
3b22cf7
fix(plugin): add env validation for torchtune.
Electronic-Waste Apr 16, 2025
ccbbefb
fix(plugin): update comments.
Electronic-Waste Apr 16, 2025
d5d347a
fix(plugins): fix the implementation according to the review.
Electronic-Waste Apr 22, 2025
79816eb
test(plugins): fix UT error in torch plugin.
Electronic-Waste Apr 22, 2025
71b4b5b
fix: fix UT and e2e tests error.
Electronic-Waste Apr 27, 2025
47321b7
fix: remove debug info.
Electronic-Waste Apr 27, 2025
40f4119
fix(test): add args in UTs related to torchtune.
Electronic-Waste Apr 27, 2025
52287a9
fix(test): update torchtune related args.
Electronic-Waste Apr 27, 2025
87f68ff
fix(test): Add a UT for multi-node mode check in torch plugin.
Electronic-Waste Apr 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,47 @@ const (

// TorchEnvMasterPort is the env name for the master node port.
TorchEnvMasterPort string = "PET_MASTER_PORT"

// TochTuneArgNumNodes is the arg anme for the number of training nodes.
TorchTuneArgNumNodes string = "--nnodes"

// TorchTuneArgNumProcPerNode is the arg name for the number of procs per node (e.g. number of GPUs per Pod).
TorchTuneArgNumProcPerNode string = "--nproc_per_node"

// TorchTuneArgRdzvId is the arg name for the rendezvous ID.
TorchTuneArgRdzvId string = "--rdzv_id"

// TorchTuneArgRdzvEndpoint is the arg name for the rendezvous endpoint.
TorchTuneArgRdzvEndpoint string = "--rdzv_endpoint"
Comment thread
Electronic-Waste marked this conversation as resolved.
Outdated

// TorchTuneFullFinetuneSingleDevice Recipe is the recipe for the single device full finetune.
TorchTuneFullFinetuneSingleDevice string = "full_finetune_single_device"
Comment thread
Electronic-Waste marked this conversation as resolved.

// TorchTuneFullFinetuneSingleDeviceConfigSuffix is the config suffix for the single device full finetune.
TorchTuneFullFinetuneSingleDeviceConfigSuffix string = "_full_single_device"

// TorchTuneFullFinetuneDistributed Recipe is the recipe for the distributed full finetune.
TorchTuneFullFinetuneDistributed string = "full_finetune_distributed"
Comment thread
Electronic-Waste marked this conversation as resolved.

// TorchTuneFullFinetuneMultiDevicesConfigSuffix is the config suffix for the single node distributed full finetune.
TorchTuneFullFinetuneMultiDevicesConfigSuffix string = "_full"

// TorchTuneFullFinetuneMultiNodesConfigSuffix is the config suffix for the multi node distributed full finetune.
TorchTuneFullFinetuneMultiNodesConfigSuffix string = "_full_multinode"
Comment thread
Electronic-Waste marked this conversation as resolved.

// TorchTuneDefaultRecipe is the default recipe for the torchtune.
TorchTuneDefaultRecipe string = TorchTuneFullFinetuneDistributed
)

const (
// MODEL_LLAMA3_2_1B is the model name for the Llama3.2 1B Instruct model.
MODEL_LLAMA3_2_1B = "llama3_2/1B"
Comment thread
Electronic-Waste marked this conversation as resolved.
Outdated

// MODEL_LLAMA3_2_7B is the model name for the Llama3.2 7B Instruct model.
MODEL_LLAMA3_2_7B = "llama3_2/7B"

// MODEL_LLAMA3_3_70B is the model name for the Llama3.3 70B Instruct model.
MODEL_LLAMA3_3_70B = "llama3_3/70B"
)

var (
Expand All @@ -142,4 +183,13 @@ var (

// Torchrun reserved env names
TorchRunReservedEnvNames = sets.New(TorchEnvNumNodes, TorchEnvNumProcPerNode, TorchEnvNodeRank, TorchEnvMasterAddr, TorchEnvMasterPort)

// Currently supported TorchTune recipes.
TorchTuneSupportedRecipes = sets.New(TorchTuneFullFinetuneSingleDevice, TorchTuneFullFinetuneDistributed)
Comment thread
Electronic-Waste marked this conversation as resolved.
Outdated

// Currently supported pretrained models for TorchTuen Trainer.
Comment thread
Electronic-Waste marked this conversation as resolved.
Outdated
TorchTuneSupportedPretrainedModels = sets.New(MODEL_LLAMA3_2_1B, MODEL_LLAMA3_2_7B, MODEL_LLAMA3_3_70B)
Comment thread
Electronic-Waste marked this conversation as resolved.
Outdated

// TorchTuneEntrypoint is the entrypoint for the torchtune.
TorchTuneEntrypoint = []string{"tune", "run"}
)
154 changes: 124 additions & 30 deletions pkg/runtime/framework/plugins/torch/torch.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package torch
import (
"context"
"fmt"
"slices"
"strings"

corev1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -70,16 +71,30 @@ func (t *Torch) Validate(runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob)
}
}

torchEnvs := sets.New[string]()
for _, env := range newObj.Spec.Trainer.Env {
if constants.TorchRunReservedEnvNames.Has(env.Name) {
torchEnvs.Insert(env.Name)
if !slices.Equal(newObj.Spec.Trainer.Command, constants.TorchTuneEntrypoint) {
Comment thread
Electronic-Waste marked this conversation as resolved.
Outdated
// Check reserved envs for torchrun.
torchEnvs := sets.New[string]()
for _, env := range newObj.Spec.Trainer.Env {
if constants.TorchRunReservedEnvNames.Has(env.Name) {
torchEnvs.Insert(env.Name)
}
}
}

if torchEnvs.Len() > 0 {
trainerEnvsPath := specPath.Child("trainer").Child("env")
allErrs = append(allErrs, field.Invalid(trainerEnvsPath, newObj.Spec.Trainer.Env, fmt.Sprintf("must not have reserved envs, invalid envs configured: %v", sets.List(torchEnvs))))
if torchEnvs.Len() > 0 {
trainerEnvsPath := specPath.Child("trainer").Child("env")
allErrs = append(allErrs, field.Invalid(trainerEnvsPath, newObj.Spec.Trainer.Env, fmt.Sprintf("must not have reserved envs, invalid envs configured: %v", sets.List(torchEnvs))))
}
} else {
// Check supported pretrained models for torchtune.
// TODO(Electronic-Waste): Add more validation for torchtune when we support more arguments.
argPath := specPath.Child("trainer").Child("args")
model := getModelFromArgs(newObj.Spec.Trainer.Args)

if model == nil {
allErrs = append(allErrs, field.Invalid(argPath, newObj.Spec.Trainer.Args, "must specify a pretrained model"))
} else if !constants.TorchTuneSupportedPretrainedModels.Has(*model) {
allErrs = append(allErrs, field.Invalid(argPath, newObj.Spec.Trainer.Args, fmt.Sprintf("must have a supported pretrained model, invalid model configured: %v", *model)))
}
}
}

Expand Down Expand Up @@ -137,35 +152,75 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob)
}

// Update envs for Info object.
// Add PyTorch distributed "PET_" values for torchrun
// TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other.
// Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940
var trainerContainer *runtime.Container
if trainJob.Spec.Trainer != nil {
if trainerContainer = info.FindContainerByPodSetAncestorContainerName(constants.AncestorTrainer, constants.Node); trainerContainer != nil {
apply.UpsertEnvVars(&trainerContainer.Env, apply.EnvVars(trainJob.Spec.Trainer.Env...)...)
}
}
if trainerContainer != nil {
apply.UpsertEnvVar(&trainerContainer.Env,
*corev1ac.EnvVar().
WithName(constants.TorchEnvNumNodes).
WithValue(fmt.Sprintf("%d", ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1))),
*corev1ac.EnvVar().
WithName(constants.TorchEnvNumProcPerNode).
WithValue(numProcPerNode.String()),
*corev1ac.EnvVar().
WithName(constants.TorchEnvNodeRank).
WithValueFrom(corev1ac.EnvVarSource().
WithFieldRef(corev1ac.ObjectFieldSelector().
WithFieldPath(constants.JobCompletionIndexFieldPath))),
*corev1ac.EnvVar().
WithName(constants.TorchEnvMasterAddr).
WithValue(fmt.Sprintf("%s-%s-0-0.%s", trainJob.Name, constants.Node, trainJob.Name)),
*corev1ac.EnvVar().
WithName(constants.TorchEnvMasterPort).
WithValue(fmt.Sprintf("%d", constants.ContainerTrainerPort)),
)
if !slices.Equal(trainJob.Spec.Trainer.Command, constants.TorchTuneEntrypoint) {
// Add PyTorch distributed "PET_" values for torchrun.
// TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other.
// Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940
apply.UpsertEnvVar(&trainerContainer.Env,
*corev1ac.EnvVar().
WithName(constants.TorchEnvNumNodes).
WithValue(fmt.Sprintf("%d", ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1))),
*corev1ac.EnvVar().
WithName(constants.TorchEnvNumProcPerNode).
WithValue(numProcPerNode.String()),
*corev1ac.EnvVar().
WithName(constants.TorchEnvNodeRank).
WithValueFrom(corev1ac.EnvVarSource().
WithFieldRef(corev1ac.ObjectFieldSelector().
WithFieldPath(constants.JobCompletionIndexFieldPath))),
*corev1ac.EnvVar().
WithName(constants.TorchEnvMasterAddr).
WithValue(fmt.Sprintf("%s-%s-0-0.%s", trainJob.Name, constants.Node, trainJob.Name)),
*corev1ac.EnvVar().
WithName(constants.TorchEnvMasterPort).
WithValue(fmt.Sprintf("%d", constants.ContainerTrainerPort)),
)
} else {
// Mutate command line args for torchtune.
// Ref: https://github.com/kubeflow/trainer/tree/master/docs/proposals/2401-llm-trainer-v2#complement-torch-plugin
oldArgs, newArgs := trainJob.Spec.Trainer.Args, []string{}
Comment thread
Electronic-Waste marked this conversation as resolved.
Outdated

// 1. Add PyTorch distributed command line args for torchtune.
// TODO(Electronic-Waste): Add more args for torchtune if required.
numNodes := ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1)
newArgs = append(newArgs,
fmt.Sprintf("%s %d",
constants.TorchTuneArgNumNodes,
numNodes,
),
fmt.Sprintf("%s %s",
constants.TorchTuneArgNumProcPerNode,
numProcPerNode.String(),
),
fmt.Sprintf("%s %s",
constants.TorchTuneArgRdzvId,
trainJob.Name,
),
fmt.Sprintf("%s %s-%s-0-0.%s:%d",
constants.TorchTuneArgRdzvEndpoint,
trainJob.Name, constants.Node, trainJob.Name, constants.ContainerTrainerPort,
),
)
Comment thread
Electronic-Waste marked this conversation as resolved.

// 2. Get the recipe and config from old args and append them to new args.
recipe := getRecipeFromArgs(numNodes, numProcPerNode, oldArgs)
Comment thread
Electronic-Waste marked this conversation as resolved.
Outdated
config := getConfigFileFromArgs(numNodes, recipe, oldArgs)
newArgs = append(newArgs, recipe, fmt.Sprintf("--config %s", config))
Comment thread
Electronic-Waste marked this conversation as resolved.
Outdated

// 3. Reserve old arguments to override corresponding items in the config file.
newArgs = append(newArgs, slices.DeleteFunc(oldArgs, func(arg string) bool {
return strings.HasPrefix(arg, "model")
})...)
Comment thread
Electronic-Waste marked this conversation as resolved.
Outdated

trainerContainer.Args = newArgs
}
// Add container port for the headless service.
apply.UpsertPort(&trainerContainer.Ports, *corev1ac.ContainerPort().WithContainerPort(constants.ContainerTrainerPort))
}
Expand All @@ -188,3 +243,42 @@ func calculateNumProcPerNode(
}
return intstr.FromInt32(defaultCPU), false
}

// getRecipeFromArgs extracts the recipe from the distributed parameters and command line arguments.
// TODO(Electronic-Waste): Add support for more recipes.
func getRecipeFromArgs(numNodes int32, numProcPerNode intstr.IntOrString, _ []string) string {
recipe := constants.TorchTuneDefaultRecipe
Comment thread
Electronic-Waste marked this conversation as resolved.
Outdated
if numNodes == 1 && numProcPerNode.Type == intstr.Int && numProcPerNode.IntVal == 1 {
recipe = constants.TorchTuneFullFinetuneSingleDevice
}
return recipe
}

// getConfigFromArgs extracts the config from distributed parameters, recipe and command line arguments.
func getConfigFileFromArgs(numNodes int32, recipe string, args []string) string {
// Determine the config file name based on the recipe and number of nodes.
var suffix string
switch recipe {
Comment thread
Electronic-Waste marked this conversation as resolved.
Outdated
case constants.TorchTuneFullFinetuneDistributed:
if numNodes == 1 {
suffix = constants.TorchTuneFullFinetuneMultiDevicesConfigSuffix
} else {
suffix = constants.TorchTuneFullFinetuneMultiNodesConfigSuffix
}
case constants.TorchTuneFullFinetuneSingleDevice:
suffix = constants.TorchTuneFullFinetuneSingleDeviceConfigSuffix
}

return fmt.Sprintf("%s%s.yaml", *getModelFromArgs(args), suffix)
}

func getModelFromArgs(args []string) *string {
var model *string
for _, arg := range args {
if strings.HasPrefix(arg, "model") {
model = &strings.Split(arg, "=")[1]
break
}
}
return model
}
Loading