Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
886 changes: 475 additions & 411 deletions examples/deepspeed/text-summarization/T5-Fine-Tuning.ipynb

Large diffs are not rendered by default.

744 changes: 346 additions & 398 deletions examples/mlx/image-classification/mnist.ipynb

Large diffs are not rendered by default.

325 changes: 143 additions & 182 deletions examples/mlx/language-modeling/fine-tune-llama.ipynb

Large diffs are not rendered by default.

628 changes: 626 additions & 2 deletions pkg/runtime/framework/core/framework_test.go

Large diffs are not rendered by default.

34 changes: 34 additions & 0 deletions pkg/runtime/framework/plugins/jobset/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,43 @@ func (b *Builder) Initializer(trainJob *trainer.TrainJob) *Builder {
return b
}

// isRunLauncherAsNode returns true if runLauncherAsNode is set to true in the MPI policy.
func (b *Builder) isRunLauncherAsNode(info *runtime.Info) bool {
return info.RuntimePolicy.MLPolicySource != nil &&
info.RuntimePolicy.MLPolicySource.MPI != nil &&
info.RuntimePolicy.MLPolicySource.MPI.RunLauncherAsNode != nil &&
*info.RuntimePolicy.MLPolicySource.MPI.RunLauncherAsNode
}

// Trainer updates JobSet values for the trainer Job.
func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Builder {
for i, rJob := range b.Spec.ReplicatedJobs {
// TODO (andreyvelich): For MPI we should apply container resources to the Node ReplicatedJob also.
// Eventually, we should find better way to propagate resources from TrainJob to JobSet.
if b.isRunLauncherAsNode(info) && *rJob.Name == constants.Node {
for j, container := range rJob.Template.Spec.Template.Spec.Containers {
if *container.Name == constants.Node {
if jobTrainer := trainJob.Spec.Trainer; jobTrainer != nil {
if resourcesPerNode := jobTrainer.ResourcesPerNode; resourcesPerNode != nil &&
(resourcesPerNode.Limits != nil || resourcesPerNode.Requests != nil) {
requirements := corev1ac.ResourceRequirements()
if limits := resourcesPerNode.Limits; limits != nil {
requirements.WithLimits(limits)
}
if requests := resourcesPerNode.Requests; requests != nil {
requirements.WithRequests(requests)
}
b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].
WithResources(requirements)
}
apply.UpsertEnvVars(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to propagate the environment variables as well?

Copy link
Copy Markdown
Member

@andreyvelich andreyvelich Oct 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking that env should be propagated as well, for now.
We need to investigate whether mpirun can read env from the Worker nodes.

&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Env,
apply.EnvVars(jobTrainer.Env...)...,
)
}
}
}
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe something like this to avoid duplicating the resources logic:

ancestor := ""

jobMetadata := rJob.Template.ObjectMetaApplyConfiguration
if jobMetadata != nil && jobMetadata.Labels != nil {
	ancestor, _ = jobMetadata.Labels[constants.LabelTrainJobAncestor]
}

if ancestor == constants.AncestorTrainer {
	// TODO: Support multiple replicas ('.template.spec.replicatedJobs[*].replicas') for replicated Jobs.
	// REF: https://github.com/kubeflow/trainer/issues/2318
	b.Spec.ReplicatedJobs[i].Replicas = ptr.To[int32](1)
	// Update the Parallelism and Completions values for the Trainer Job.
	b.Spec.ReplicatedJobs[i].Template.Spec.Parallelism = info.FindPodSetByAncestor(constants.AncestorTrainer).Count
	b.Spec.ReplicatedJobs[i].Template.Spec.Completions = info.FindPodSetByAncestor(constants.AncestorTrainer).Count

	// Update values for the Trainer container.
	for j, container := range rJob.Template.Spec.Template.Spec.Containers {
		if *container.Name == constants.Node {
			// Update values from the TrainJob trainer.
			if jobTrainer := trainJob.Spec.Trainer; jobTrainer != nil {
				if image := jobTrainer.Image; image != nil {
					b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Image = image
				}
				if command := jobTrainer.Command; command != nil {
					b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Command = command
				}
				if args := jobTrainer.Args; args != nil {
					b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Args = args
				}
				apply.UpsertEnvVars(
					&b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].Env,
					apply.EnvVars(jobTrainer.Env...)...,
				)
			}
		}
	}
}

// Apply trainer configuration to node containers.
if ancestor == constants.AncestorTrainer ||
	if b.isRunLauncherAsNode(info) && *rJob.Name == constants.Node {
	for j, container := range rJob.Template.Spec.Template.Spec.Containers {
		if *container.Name == constants.Node {
			// Update values from the TrainJob trainer.
			if jobTrainer := trainJob.Spec.Trainer; jobTrainer != nil {
				if resourcesPerNode := jobTrainer.ResourcesPerNode; resourcesPerNode != nil &&
					(resourcesPerNode.Limits != nil || resourcesPerNode.Requests != nil) {
					requirements := corev1ac.ResourceRequirements()
					if limits := resourcesPerNode.Limits; limits != nil {
						requirements.WithLimits(limits)
					}
					if requests := resourcesPerNode.Requests; requests != nil {
						requirements.WithRequests(requests)
					}
					b.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.Containers[j].
						WithResources(requirements)
				}
			}
		}
	}
}

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

jobMetadata := rJob.Template.ObjectMetaApplyConfiguration
if jobMetadata == nil || jobMetadata.Labels == nil {
continue
Expand Down
10 changes: 10 additions & 0 deletions pkg/runtime/framework/plugins/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er

if trainJob.Spec.Trainer != nil && trainJob.Spec.Trainer.NumProcPerNode != nil {
info.RuntimePolicy.MLPolicySource.MPI.NumProcPerNode = ptr.To(int32(trainJob.Spec.Trainer.NumProcPerNode.IntValue()))
// If numProcPerNode is set to 1 in runtime, we make it equal to number of GPUs.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tenzen-y @astefanutti @Electronic-Waste I auto set number of slots for MPI plugin equal to number of GPUs, if TrainJob doesn't set NumProcPerNode and NumProcPerNode = 1 (which is default value in our MPI runtimes).

This will help users to use DeepSpeed runtime more easily without modifying the numProcPerNode.
Let me know if that sounds good to you.
/assign @tenzen-y @astefanutti @Electronic-Waste

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic SGTM

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if TrainJob doesn't set NumProcPerNode, would that make sense to set it to the number of GPUs if NumProcPerNode < num GPUs?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@astefanutti I would suggest that we always set NumProcPerNode == num GPUs if the default value: 1 is set in NumProcPerNode.
If users manually override this value in the Runtime or in the TrainJob, we won't override it.
WDYT @astefanutti ?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andreyvelich right, better not override user-defined values.

} else if *info.RuntimePolicy.MLPolicySource.MPI.NumProcPerNode == 1 {
resourcesPerNode := ptr.Deref(runtime.ExtractResourcePerNodeFromRuntime(info), corev1.ResourceRequirements{})
if jobTrainer := trainJob.Spec.Trainer; jobTrainer != nil && jobTrainer.ResourcesPerNode != nil {
resourcesPerNode = ptr.Deref(jobTrainer.ResourcesPerNode, corev1.ResourceRequirements{})
}
gpuQ := runtime.GetNumGPUPerNode(&resourcesPerNode)
if gpuQ > 1 {
info.RuntimePolicy.MLPolicySource.MPI.NumProcPerNode = ptr.To(int32(gpuQ))
}
}

// Add Secret and ConfigMap volumes to the Info object
Expand Down
138 changes: 138 additions & 0 deletions pkg/runtime/framework/plugins/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
gocmp "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
apiruntime "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/intstr"
Expand Down Expand Up @@ -350,6 +351,143 @@ trainJob-node-1-1.trainJob slots=1
utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault).
WithData(map[string]string{
constants.MPIHostfileName: `trainJob-node-1-0.trainJob slots=2
`,
}).
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob").
Obj(),
},
},
"numProcPerNode is set to number of GPUs in TrainJob": {
info: &runtime.Info{
Labels: make(map[string]string),
Annotations: make(map[string]string),
TemplateSpec: runtime.TemplateSpec{
PodSets: []runtime.PodSet{
{
Name: constants.Launcher,
Count: ptr.To[int32](1),
Endpoints: func(yield func(string) bool) {
yield("trainJob-launcher-0-0.trainJob")
},
},
{
Name: constants.Node,
Ancestor: ptr.To(constants.AncestorTrainer),
Count: ptr.To[int32](1),
Endpoints: func(yield func(string) bool) {
yield("trainJob-node-1-0.trainJob")
},
},
},
},
RuntimePolicy: runtime.RuntimePolicy{
MLPolicySource: utiltesting.MakeMLPolicySourceWrapper().
MPIPolicy(ptr.To[int32](1), trainer.MPIImplementationOpenMPI, ptr.To("/root/.ssh"), nil).
Obj(),
},
Scheduler: &runtime.Scheduler{
PodLabels: make(map[string]string),
},
},
trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "trainJob").
UID("trainJob").
Trainer(
utiltesting.MakeTrainJobTrainerWrapper().
NumNodes(1).
Container("test:trainjob", []string{"trainjob"}, []string{"trainjob"}, corev1.ResourceList{
"custom.com/gpu": resource.MustParse("5"),
}).
Obj()).
Obj(),
wantInfo: &runtime.Info{
Labels: make(map[string]string),
Annotations: make(map[string]string),
TemplateSpec: runtime.TemplateSpec{
PodSets: []runtime.PodSet{
{
Name: constants.Launcher,
Count: ptr.To[int32](1),
Volumes: []corev1ac.VolumeApplyConfiguration{
*corev1ac.Volume().
WithName(constants.MPISSHAuthVolumeName).
WithSecret(corev1ac.SecretVolumeSource().
WithSecretName(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix)).
WithItems(
corev1ac.KeyToPath().
WithKey(corev1.SSHAuthPrivateKey).
WithPath(constants.MPISSHPrivateKeyFile),
corev1ac.KeyToPath().
WithKey(constants.MPISSHPublicKey).
WithPath(constants.MPISSHPublicKeyFile),
corev1ac.KeyToPath().
WithKey(constants.MPISSHPublicKey).
WithPath(constants.MPISSHAuthorizedKeys),
),
),
*corev1ac.Volume().
WithName(constants.MPIHostfileVolumeName).
WithConfigMap(corev1ac.ConfigMapVolumeSource().
WithName(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix)).
WithItems(
corev1ac.KeyToPath().
WithKey(constants.MPIHostfileName).
WithPath(constants.MPIHostfileName).
WithMode(0444),
),
),
},
Endpoints: func(yield func(string) bool) {
yield("trainJob-launcher-0-0.trainJob")
},
},
{
Name: constants.Node,
Ancestor: ptr.To(constants.AncestorTrainer),
Count: ptr.To[int32](1),
Volumes: []corev1ac.VolumeApplyConfiguration{
*corev1ac.Volume().
WithName(constants.MPISSHAuthVolumeName).
WithSecret(corev1ac.SecretVolumeSource().
WithSecretName(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix)).
WithItems(
corev1ac.KeyToPath().
WithKey(corev1.SSHAuthPrivateKey).
WithPath(constants.MPISSHPrivateKeyFile),
corev1ac.KeyToPath().
WithKey(constants.MPISSHPublicKey).
WithPath(constants.MPISSHPublicKeyFile),
corev1ac.KeyToPath().
WithKey(constants.MPISSHPublicKey).
WithPath(constants.MPISSHAuthorizedKeys),
),
),
},
Endpoints: func(yield func(string) bool) {
yield("trainJob-node-1-0.trainJob")
},
},
},
},
RuntimePolicy: runtime.RuntimePolicy{
MLPolicySource: utiltesting.MakeMLPolicySourceWrapper().
MPIPolicy(ptr.To[int32](5), trainer.MPIImplementationOpenMPI, ptr.To("/root/.ssh"), nil).
Obj(),
},
Scheduler: &runtime.Scheduler{PodLabels: make(map[string]string)},
},
wantObjs: []apiruntime.Object{
utiltesting.MakeSecretWrapper(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault).
WithImmutable(true).
WithType(corev1.SecretTypeSSHAuth).
WithData(map[string][]byte{
constants.MPISSHPublicKey: []byte("EXIST"),
corev1.SSHAuthPrivateKey: []byte("EXIST"),
}).
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob").
Obj(),
utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault).
WithData(map[string]string{
constants.MPIHostfileName: `trainJob-node-1-0.trainJob slots=5
`,
}).
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob").
Expand Down
53 changes: 2 additions & 51 deletions pkg/runtime/framework/plugins/torch/torch.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"context"
"fmt"
"slices"
"strings"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/intstr"
Expand All @@ -30,7 +29,6 @@ import (
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2"

trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/v2/pkg/apply"
Expand Down Expand Up @@ -113,11 +111,11 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob)
}

// Determine numProcPerNode based on the resourcesPerNode.
resourcesPerNode := ptr.Deref(extractResourcePerNodeFromRuntime(info), corev1.ResourceRequirements{})
resourcesPerNode := ptr.Deref(runtime.ExtractResourcePerNodeFromRuntime(info), corev1.ResourceRequirements{})
if jobTrainer := trainJob.Spec.Trainer; jobTrainer != nil && jobTrainer.ResourcesPerNode != nil {
resourcesPerNode = ptr.Deref(jobTrainer.ResourcesPerNode, corev1.ResourceRequirements{})
}
gpuQ := getNumGPUPerNode(&resourcesPerNode)
gpuQ := runtime.GetNumGPUPerNode(&resourcesPerNode)
// If numProcPerNode is "cpu" or no GPU is set in resource, we calculate numProcPerNode based on CPU.
if numProcPerNode.String() == "cpu" || numProcPerNode.String() == "auto" && gpuQ == 0 {
numProcPerNode = intstr.FromInt(max(1, getNumCPUPerNode(&resourcesPerNode)))
Expand Down Expand Up @@ -204,50 +202,3 @@ func getNumCPUPerNode(res *corev1.ResourceRequirements) int {
}
return int(requestCpuQ.Value())
}

// getNumGPUPerNode returns the GPU count if found.
func getNumGPUPerNode(res *corev1.ResourceRequirements) int {
if res == nil {
return 0
}
gpuQ := numGPU(res.Requests)
if limitGpuQ := numGPU(res.Limits); gpuQ == 0 && limitGpuQ > 0 {
gpuQ = limitGpuQ
}
return gpuQ
}

func numGPU(resourcePerNode corev1.ResourceList) int {
for resName, resQ := range resourcePerNode {
if strings.Contains(strings.ToLower(resName.String()), "gpu") {
return int(resQ.Value())
}
}
return 0
}

// extractResourcePerNodeFromRuntime extracts the resource per node from the Trainer Node.
func extractResourcePerNodeFromRuntime(info *runtime.Info) *corev1.ResourceRequirements {
if jobSetSpec, ok := runtime.TemplateSpecApply[jobsetv1alpha2ac.JobSetSpecApplyConfiguration](info); ok {
for _, rJob := range jobSetSpec.ReplicatedJobs {
if rJob.Name != nil && *rJob.Name == constants.Node || rJob.Template.Labels[constants.LabelTrainJobAncestor] == constants.AncestorTrainer {
for _, container := range rJob.Template.Spec.Template.Spec.Containers {
if container.Name != nil && *container.Name == constants.Node && container.Resources != nil {
res := &corev1.ResourceRequirements{
Limits: corev1.ResourceList{},
Requests: corev1.ResourceList{},
}
if container.Resources.Limits != nil {
res.Limits = *container.Resources.Limits
}
if container.Resources.Requests != nil {
res.Requests = *container.Resources.Requests
}
return res
}
}
}
}
}
return nil
}
4 changes: 2 additions & 2 deletions pkg/runtime/framework/plugins/torch/torchtune.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ func validateTorchTune(runtimeInfo *runtime.Info, newObj *trainer.TrainJob) (adm

numProcPerNodeRefPath := specPath.Child("trainer").Child("numProcPerNode")
numProcPerNode := *newObj.Spec.Trainer.NumProcPerNode
resourcesPerNode := ptr.Deref(extractResourcePerNodeFromRuntime(runtimeInfo), corev1.ResourceRequirements{})
resourcesPerNode := ptr.Deref(runtime.ExtractResourcePerNodeFromRuntime(runtimeInfo), corev1.ResourceRequirements{})
if jobTrainer := newObj.Spec.Trainer; jobTrainer != nil && jobTrainer.ResourcesPerNode != nil {
resourcesPerNode = ptr.Deref(jobTrainer.ResourcesPerNode, corev1.ResourceRequirements{})
}
_, config := getRecipeAndConfig(numNodes, numProcPerNode, getNumGPUPerNode(&resourcesPerNode), newObj)
_, config := getRecipeAndConfig(numNodes, numProcPerNode, runtime.GetNumGPUPerNode(&resourcesPerNode), newObj)
if strings.Contains(config, constants.TorchTuneQLoRAFinetuneDistributedConfigSuffix) {
if model == constants.TORCHTUNE_MODEL_QWEN2_5_1_5B {
allErrs = append(allErrs, field.Invalid(runtimeRefNamePath, newObj.Spec.RuntimeRef.Name, fmt.Sprintf("QLoRA is not supported for %v model", model)))
Expand Down
50 changes: 50 additions & 0 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@ import (
"iter"
"maps"
"slices"
"strings"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
corev1ac "k8s.io/client-go/applyconfigurations/core/v1"
resourcehelpers "k8s.io/component-helpers/resource"
"k8s.io/utils/ptr"
jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2"

trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/v2/pkg/constants"
)

var (
Expand Down Expand Up @@ -240,3 +243,50 @@ func RuntimeRefToRuntimeRegistryKey(runtimeRef trainer.RuntimeRef) string {
Kind: ptr.Deref(runtimeRef.Kind, ""),
}.String()
}

// ExtractResourcePerNodeFromRuntime extracts the Trainer resource per node from the Info object.
func ExtractResourcePerNodeFromRuntime(info *Info) *corev1.ResourceRequirements {
if jobSetSpec, ok := TemplateSpecApply[jobsetv1alpha2ac.JobSetSpecApplyConfiguration](info); ok {
for _, rJob := range jobSetSpec.ReplicatedJobs {
if rJob.Name != nil && *rJob.Name == constants.Node || rJob.Template.Labels[constants.LabelTrainJobAncestor] == constants.AncestorTrainer {
for _, container := range rJob.Template.Spec.Template.Spec.Containers {
if container.Name != nil && *container.Name == constants.Node && container.Resources != nil {
res := &corev1.ResourceRequirements{
Limits: corev1.ResourceList{},
Requests: corev1.ResourceList{},
}
if container.Resources.Limits != nil {
res.Limits = *container.Resources.Limits
}
if container.Resources.Requests != nil {
res.Requests = *container.Resources.Requests
}
return res
}
}
}
}
}
return nil
}

// GetNumGPUPerNode returns the GPU count if found in container resources.
func GetNumGPUPerNode(res *corev1.ResourceRequirements) int {
if res == nil {
return 0
}
gpuQ := numGPU(res.Requests)
if limitGpuQ := numGPU(res.Limits); gpuQ == 0 && limitGpuQ > 0 {
gpuQ = limitGpuQ
}
return gpuQ
}

func numGPU(resourcePerNode corev1.ResourceList) int {
for resName, resQ := range resourcePerNode {
if strings.Contains(strings.ToLower(resName.String()), "gpu") {
return int(resQ.Value())
}
}
return 0
}
7 changes: 6 additions & 1 deletion pkg/util/testing/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ import (
)

var (
PodSetEndpointsCmpOpts = cmp.Transformer("Seq", func(a iter.Seq[string]) []string { return slices.Collect(a) })
PodSetEndpointsCmpOpts = cmp.Transformer("Seq", func(a iter.Seq[string]) []string {
if a == nil {
return nil
}
return slices.Collect(a)
})
TrainJobUpdateReconcileRequestCmpOpts = cmp.Transformer("SeqTrainJobUpdateReconcileRequest",
func(req iter.Seq[types.NamespacedName]) []types.NamespacedName {
if req == nil {
Expand Down
Loading