@@ -25,6 +25,7 @@ import (
2525 gocmp "github.com/google/go-cmp/cmp"
2626 "github.com/google/go-cmp/cmp/cmpopts"
2727 corev1 "k8s.io/api/core/v1"
28+ "k8s.io/apimachinery/pkg/api/resource"
2829 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2930 apiruntime "k8s.io/apimachinery/pkg/runtime"
3031 "k8s.io/klog/v2/ktesting"
@@ -33,10 +34,10 @@ import (
3334 "sigs.k8s.io/controller-runtime/pkg/client/interceptor"
3435 schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
3536
36- trainerv1alpha1 "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
37- "github.com/kubeflow/trainer/pkg/runtime"
38- "github.com/kubeflow/trainer/pkg/runtime/framework"
39- utiltesting "github.com/kubeflow/trainer/pkg/util/testing"
37+ trainerv1alpha1 "github.com/kubeflow/trainer/v2/ pkg/apis/trainer/v1alpha1"
38+ "github.com/kubeflow/trainer/v2/ pkg/runtime"
39+ "github.com/kubeflow/trainer/v2/ pkg/runtime/framework"
40+ utiltesting "github.com/kubeflow/trainer/v2/ pkg/util/testing"
4041)
4142
4243func TestCoScheduling (t * testing.T ) {
@@ -184,6 +185,168 @@ func TestCoScheduling(t *testing.T) {
184185 Obj (),
185186 },
186187 },
188+ "succeeded to build PodGroup with multiple PodSets" : {
189+ info : & runtime.Info {
190+ Scheduler : & runtime.Scheduler {},
191+ RuntimePolicy : runtime.RuntimePolicy {
192+ PodGroupPolicy : & trainerv1alpha1.PodGroupPolicy {
193+ PodGroupPolicySource : trainerv1alpha1.PodGroupPolicySource {
194+ Coscheduling : & trainerv1alpha1.CoschedulingPodGroupPolicySource {
195+ ScheduleTimeoutSeconds : ptr.To [int32 ](30 ),
196+ },
197+ },
198+ },
199+ },
200+ TemplateSpec : runtime.TemplateSpec {
201+ PodSets : []runtime.PodSet {
202+ {
203+ Name : "node" ,
204+ Count : ptr.To [int32 ](2 ),
205+ SinglePodRequests : corev1.ResourceList {
206+ corev1 .ResourceCPU : resource .MustParse ("500m" ),
207+ corev1 .ResourceMemory : resource .MustParse ("1Gi" ),
208+ },
209+ },
210+ {
211+ Name : "dataset-initializer" ,
212+ Count : ptr.To [int32 ](1 ),
213+ SinglePodRequests : corev1.ResourceList {
214+ corev1 .ResourceCPU : resource .MustParse ("250m" ),
215+ corev1 .ResourceMemory : resource .MustParse ("512Mi" ),
216+ },
217+ },
218+ },
219+ },
220+ },
221+ trainJob : utiltesting .MakeTrainJobWrapper (metav1 .NamespaceDefault , "trainJob" ).
222+ UID ("trainJob" ).
223+ Trainer (
224+ utiltesting .MakeTrainJobTrainerWrapper ().
225+ NumNodes (2 ).
226+ Obj ()).
227+ Obj (),
228+ wantInfo : & runtime.Info {
229+ Scheduler : & runtime.Scheduler {
230+ PodLabels : map [string ]string {
231+ "scheduling.x-k8s.io/pod-group" : "trainJob" ,
232+ },
233+ },
234+ RuntimePolicy : runtime.RuntimePolicy {
235+ PodGroupPolicy : & trainerv1alpha1.PodGroupPolicy {
236+ PodGroupPolicySource : trainerv1alpha1.PodGroupPolicySource {
237+ Coscheduling : & trainerv1alpha1.CoschedulingPodGroupPolicySource {
238+ ScheduleTimeoutSeconds : ptr.To [int32 ](30 ),
239+ },
240+ },
241+ },
242+ },
243+ TemplateSpec : runtime.TemplateSpec {
244+ PodSets : []runtime.PodSet {
245+ {
246+ Name : "node" ,
247+ Count : ptr.To [int32 ](2 ),
248+ SinglePodRequests : corev1.ResourceList {
249+ corev1 .ResourceCPU : resource .MustParse ("500m" ),
250+ corev1 .ResourceMemory : resource .MustParse ("1Gi" ),
251+ },
252+ },
253+ {
254+ Name : "dataset-initializer" ,
255+ Count : ptr.To [int32 ](1 ),
256+ SinglePodRequests : corev1.ResourceList {
257+ corev1 .ResourceCPU : resource .MustParse ("250m" ),
258+ corev1 .ResourceMemory : resource .MustParse ("512Mi" ),
259+ },
260+ },
261+ },
262+ },
263+ },
264+ objs : []client.Object {}, // Simulate no existing PodGroup
265+ wantObjs : []apiruntime.Object {
266+ utiltesting .MakeSchedulerPluginsPodGroup (metav1 .NamespaceDefault , "trainJob" ).
267+ MinMember (3 ).
268+ MinResources (corev1.ResourceList {
269+ corev1 .ResourceCPU : resource .MustParse ("1.25" ),
270+ corev1 .ResourceMemory : resource .MustParse ("2.5Gi" ),
271+ }).
272+ SchedulingTimeout (30 ).
273+ ControllerReference (trainerv1alpha1 .GroupVersion .WithKind (trainerv1alpha1 .TrainJobKind ), "trainJob" , "trainJob" ).
274+ Obj (),
275+ },
276+ },
277+ "succeeded to build PodGroup with MinResources" : {
278+ info : & runtime.Info {
279+ Scheduler : & runtime.Scheduler {},
280+ RuntimePolicy : runtime.RuntimePolicy {
281+ PodGroupPolicy : & trainerv1alpha1.PodGroupPolicy {
282+ PodGroupPolicySource : trainerv1alpha1.PodGroupPolicySource {
283+ Coscheduling : & trainerv1alpha1.CoschedulingPodGroupPolicySource {
284+ ScheduleTimeoutSeconds : ptr.To [int32 ](30 ),
285+ },
286+ },
287+ },
288+ },
289+ TemplateSpec : runtime.TemplateSpec {
290+ PodSets : []runtime.PodSet {
291+ {
292+ Name : "node" ,
293+ Count : ptr.To [int32 ](2 ),
294+ SinglePodRequests : corev1.ResourceList {
295+ corev1 .ResourceCPU : resource .MustParse ("500m" ),
296+ corev1 .ResourceMemory : resource .MustParse ("1Gi" ),
297+ },
298+ },
299+ },
300+ },
301+ },
302+ trainJob : utiltesting .MakeTrainJobWrapper (metav1 .NamespaceDefault , "trainJob" ).
303+ UID ("trainJob" ).
304+ Trainer (
305+ utiltesting .MakeTrainJobTrainerWrapper ().
306+ NumNodes (2 ).
307+ Obj ()).
308+ Obj (),
309+ wantInfo : & runtime.Info {
310+ Scheduler : & runtime.Scheduler {
311+ PodLabels : map [string ]string {
312+ "scheduling.x-k8s.io/pod-group" : "trainJob" ,
313+ },
314+ },
315+ RuntimePolicy : runtime.RuntimePolicy {
316+ PodGroupPolicy : & trainerv1alpha1.PodGroupPolicy {
317+ PodGroupPolicySource : trainerv1alpha1.PodGroupPolicySource {
318+ Coscheduling : & trainerv1alpha1.CoschedulingPodGroupPolicySource {
319+ ScheduleTimeoutSeconds : ptr.To [int32 ](30 ),
320+ },
321+ },
322+ },
323+ },
324+ TemplateSpec : runtime.TemplateSpec {
325+ PodSets : []runtime.PodSet {
326+ {
327+ Name : "node" ,
328+ Count : ptr.To [int32 ](2 ),
329+ SinglePodRequests : corev1.ResourceList {
330+ corev1 .ResourceCPU : resource .MustParse ("500m" ),
331+ corev1 .ResourceMemory : resource .MustParse ("1Gi" ),
332+ },
333+ },
334+ },
335+ },
336+ },
337+ objs : []client.Object {}, // Simulate no existing PodGroup
338+ wantObjs : []apiruntime.Object {
339+ utiltesting .MakeSchedulerPluginsPodGroup (metav1 .NamespaceDefault , "trainJob" ).
340+ MinMember (2 ).
341+ MinResources (corev1.ResourceList {
342+ corev1 .ResourceCPU : resource .MustParse ("1" ),
343+ corev1 .ResourceMemory : resource .MustParse ("2Gi" ),
344+ }).
345+ SchedulingTimeout (30 ).
346+ ControllerReference (trainerv1alpha1 .GroupVersion .WithKind (trainerv1alpha1 .TrainJobKind ), "trainJob" , "trainJob" ).
347+ Obj (),
348+ },
349+ },
187350 "failed to get PodGroup due to API error" : {
188351 info : & runtime.Info {
189352 Scheduler : & runtime.Scheduler {},
@@ -275,6 +438,86 @@ func TestCoScheduling(t *testing.T) {
275438 wantPodGroupPolicyError : nil ,
276439 wantBuildError : nil ,
277440 },
441+ "no action when TrainJob is suspended" : {
442+ info : & runtime.Info {
443+ Scheduler : & runtime.Scheduler {},
444+ RuntimePolicy : runtime.RuntimePolicy {
445+ PodGroupPolicy : & trainerv1alpha1.PodGroupPolicy {
446+ PodGroupPolicySource : trainerv1alpha1.PodGroupPolicySource {
447+ Coscheduling : & trainerv1alpha1.CoschedulingPodGroupPolicySource {
448+ ScheduleTimeoutSeconds : ptr.To [int32 ](30 ),
449+ },
450+ },
451+ },
452+ },
453+ TemplateSpec : runtime.TemplateSpec {
454+ PodSets : []runtime.PodSet {
455+ {
456+ Name : "node" ,
457+ Count : ptr.To [int32 ](2 ),
458+ SinglePodRequests : corev1.ResourceList {
459+ corev1 .ResourceCPU : resource .MustParse ("500m" ),
460+ corev1 .ResourceMemory : resource .MustParse ("1Gi" ),
461+ },
462+ },
463+ },
464+ },
465+ },
466+ trainJob : & trainerv1alpha1.TrainJob {
467+ ObjectMeta : metav1.ObjectMeta {
468+ Name : "existingTrainJob" ,
469+ Namespace : metav1 .NamespaceDefault ,
470+ },
471+ Spec : trainerv1alpha1.TrainJobSpec {
472+ Suspend : ptr .To (false ),
473+ },
474+ },
475+ wantInfo : & runtime.Info {
476+ Scheduler : & runtime.Scheduler {
477+ PodLabels : map [string ]string {
478+ "scheduling.x-k8s.io/pod-group" : "existingTrainJob" ,
479+ },
480+ },
481+ RuntimePolicy : runtime.RuntimePolicy {
482+ PodGroupPolicy : & trainerv1alpha1.PodGroupPolicy {
483+ PodGroupPolicySource : trainerv1alpha1.PodGroupPolicySource {
484+ Coscheduling : & trainerv1alpha1.CoschedulingPodGroupPolicySource {
485+ ScheduleTimeoutSeconds : ptr.To [int32 ](30 ),
486+ },
487+ },
488+ },
489+ },
490+ TemplateSpec : runtime.TemplateSpec {
491+ PodSets : []runtime.PodSet {
492+ {
493+ Name : "node" ,
494+ Count : ptr.To [int32 ](2 ),
495+ SinglePodRequests : corev1.ResourceList {
496+ corev1 .ResourceCPU : resource .MustParse ("500m" ),
497+ corev1 .ResourceMemory : resource .MustParse ("1Gi" ),
498+ },
499+ },
500+ },
501+ },
502+ },
503+ objs : []client.Object {
504+ & schedulerpluginsv1alpha1.PodGroup {
505+ ObjectMeta : metav1.ObjectMeta {
506+ Name : "existingTrainJob" ,
507+ Namespace : metav1 .NamespaceDefault ,
508+ },
509+ Spec : schedulerpluginsv1alpha1.PodGroupSpec {
510+ MinMember : 2 ,
511+ MinResources : corev1.ResourceList {
512+ corev1 .ResourceCPU : resource .MustParse ("1" ),
513+ corev1 .ResourceMemory : resource .MustParse ("2Gi" ),
514+ },
515+ ScheduleTimeoutSeconds : ptr.To [int32 ](30 ),
516+ },
517+ },
518+ },
519+ wantObjs : nil , // No new objects should be created
520+ },
278521 }
279522
280523 for name , tc := range cases {
@@ -286,6 +529,28 @@ func TestCoScheduling(t *testing.T) {
286529 clientBuilder := utiltesting .NewClientBuilder ().WithObjects (tc .objs ... )
287530 clientBuilder .WithInterceptorFuncs (interceptor.Funcs {
288531 Get : func (ctx context.Context , client client.WithWatch , key client.ObjectKey , obj client.Object , opts ... client.GetOption ) error {
532+ if podGroup , ok := obj .(* schedulerpluginsv1alpha1.PodGroup ); ok {
533+ // Check if the key matches the expected PodGroup
534+ if key .Name == "existingTrainJob" && key .Namespace == metav1 .NamespaceDefault {
535+ // Simulate finding the PodGroup by copying the expected object into the provided obj
536+ * podGroup = schedulerpluginsv1alpha1.PodGroup {
537+ ObjectMeta : metav1.ObjectMeta {
538+ Name : "existingTrainJob" ,
539+ Namespace : metav1 .NamespaceDefault ,
540+ },
541+ Spec : schedulerpluginsv1alpha1.PodGroupSpec {
542+ MinMember : 2 ,
543+ MinResources : corev1.ResourceList {
544+ corev1 .ResourceCPU : resource .MustParse ("1" ),
545+ corev1 .ResourceMemory : resource .MustParse ("2Gi" ),
546+ },
547+ ScheduleTimeoutSeconds : ptr.To [int32 ](30 ),
548+ },
549+ }
550+ return nil
551+ }
552+ }
553+
289554 if _ , ok := obj .(* schedulerpluginsv1alpha1.PodGroup ); ok && errors .Is (tc .wantBuildError , errorGetPodGroup ) {
290555 return errorGetPodGroup
291556 }
0 commit comments