Skip to content

Commit c1bc94d

Browse files
authored
feat(api): Sync TrainJob JobsStatus from JobSet ReplicatedJobsStatus (#2802)
* feat(api): Sync TrainJob JobsStatus from JobSet ReplicatedJobsStatus Signed-off-by: Antonin Stefanutti <antonin@stefanutti.fr> * Add integration tests Signed-off-by: Antonin Stefanutti <antonin@stefanutti.fr> * Update e2e tests Signed-off-by: Antonin Stefanutti <antonin@stefanutti.fr> * Remove extra check Signed-off-by: Antonin Stefanutti <antonin@stefanutti.fr> * Sort JobsStatus in e2e tests Signed-off-by: Antonin Stefanutti <antonin@stefanutti.fr> * Fix e2e test for MPI job Signed-off-by: Antonin Stefanutti <antonin@stefanutti.fr> * Fail-fast when multiple terminal condition and JobsStatus plugins exist Signed-off-by: Antonin Stefanutti <antonin@stefanutti.fr> * Fold TerminalCondition and JobsStatus plugins Signed-off-by: Antonin Stefanutti <antonin@stefanutti.fr> --------- Signed-off-by: Antonin Stefanutti <antonin@stefanutti.fr>
1 parent 25022bd commit c1bc94d

12 files changed

Lines changed: 707 additions & 103 deletions

File tree

pkg/controller/trainjob_controller.go

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,6 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
104104
log := ctrl.LoggerFrom(ctx).WithValues("trainJob", klog.KObj(&trainJob))
105105
ctx = ctrl.LoggerInto(ctx, log)
106106
log.V(2).Info("Reconciling TrainJob")
107-
if isTrainJobFinished(&trainJob) {
108-
log.V(5).Info("TrainJob has already been finished")
109-
return ctrl.Result{}, nil
110-
}
111107

112108
var err error
113109
// Keep track of the origin TrainJob status
@@ -138,8 +134,9 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
138134
}
139135

140136
setSuspendedCondition(&trainJob)
141-
if terminalCondErr := setTerminalCondition(ctx, runtime, &trainJob); terminalCondErr != nil {
142-
err = errors.Join(err, terminalCondErr)
137+
138+
if statusErr := setTrainJobStatus(ctx, runtime, &trainJob); statusErr != nil {
139+
err = errors.Join(err, statusErr)
143140
}
144141

145142
if !equality.Semantic.DeepEqual(&trainJob.Status, originStatus) {
@@ -256,22 +253,17 @@ func removeFailedCondition(trainJob *trainer.TrainJob) {
256253
meta.RemoveStatusCondition(&trainJob.Status.Conditions, trainer.TrainJobFailed)
257254
}
258255

259-
func setTerminalCondition(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) error {
260-
terminalCond, err := runtime.TerminalCondition(ctx, trainJob)
256+
func setTrainJobStatus(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) error {
257+
status, err := runtime.TrainJobStatus(ctx, trainJob)
261258
if err != nil {
262259
return err
263260
}
264-
if terminalCond != nil {
265-
meta.SetStatusCondition(&trainJob.Status.Conditions, *terminalCond)
261+
if status != nil {
262+
trainJob.Status = *status
266263
}
267264
return nil
268265
}
269266

270-
func isTrainJobFinished(trainJob *trainer.TrainJob) bool {
271-
return meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobComplete) ||
272-
meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobFailed)
273-
}
274-
275267
func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, options controller.Options) error {
276268
b := builder.TypedControllerManagedBy[reconcile.Request](mgr).
277269
Named("trainjob_controller").

pkg/runtime/core/clustertrainingruntime.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"errors"
2222
"fmt"
2323

24-
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2524
"k8s.io/apimachinery/pkg/runtime/schema"
2625
"k8s.io/apimachinery/pkg/util/validation/field"
2726
"sigs.k8s.io/controller-runtime/pkg/client"
@@ -71,8 +70,8 @@ func (r *ClusterTrainingRuntime) RuntimeInfo(
7170
return r.TrainingRuntime.RuntimeInfo(trainJob, runtimeTemplateSpec, mlPolicy, podGroupPolicy)
7271
}
7372

74-
func (r *ClusterTrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
75-
return r.TrainingRuntime.TerminalCondition(ctx, trainJob)
73+
func (r *ClusterTrainingRuntime) TrainJobStatus(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) {
74+
return r.TrainingRuntime.TrainJobStatus(ctx, trainJob)
7675
}
7776

7877
func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {

pkg/runtime/core/trainingruntime.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,8 @@ func syncPodSets(info *runtime.Info) {
246246
}
247247
}
248248

249-
func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
250-
return r.framework.RunTerminalConditionPlugins(ctx, trainJob)
249+
func (r *TrainingRuntime) TrainJobStatus(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) {
250+
return r.framework.RunTrainJobStatusPlugin(ctx, trainJob)
251251
}
252252

253253
func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {

pkg/runtime/framework/core/framework.go

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
"context"
2121
"errors"
2222

23-
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2423
"k8s.io/apimachinery/pkg/util/validation/field"
2524
"sigs.k8s.io/controller-runtime/pkg/client"
2625
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
@@ -32,7 +31,7 @@ import (
3231
index "github.com/kubeflow/trainer/v2/pkg/runtime/indexer"
3332
)
3433

35-
var errorTooManyTerminalConditionPlugin = errors.New("too many TerminalCondition plugins are registered")
34+
var errorTooManyTrainJobStatusPlugin = errors.New("too many TrainJobStatus plugins are registered")
3635

3736
type Framework struct {
3837
registry fwkplugins.Registry
@@ -43,7 +42,7 @@ type Framework struct {
4342
watchExtensionPlugins []framework.WatchExtensionPlugin
4443
podNetworkPlugins []framework.PodNetworkPlugin
4544
componentBuilderPlugins []framework.ComponentBuilderPlugin
46-
terminalConditionPlugins []framework.TerminalConditionPlugin
45+
trainJobStatusPlugin framework.TrainJobStatusPlugin
4746
}
4847

4948
func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) {
@@ -79,8 +78,11 @@ func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl
7978
if p, ok := plugin.(framework.ComponentBuilderPlugin); ok {
8079
f.componentBuilderPlugins = append(f.componentBuilderPlugins, p)
8180
}
82-
if p, ok := plugin.(framework.TerminalConditionPlugin); ok {
83-
f.terminalConditionPlugins = append(f.terminalConditionPlugins, p)
81+
if p, ok := plugin.(framework.TrainJobStatusPlugin); ok {
82+
if f.trainJobStatusPlugin != nil {
83+
return nil, errorTooManyTrainJobStatusPlugin
84+
}
85+
f.trainJobStatusPlugin = p
8486
}
8587
}
8688
f.plugins = plugins
@@ -141,13 +143,9 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtim
141143
return objs, nil
142144
}
143145

144-
func (f *Framework) RunTerminalConditionPlugins(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
145-
// TODO (tenzen-y): Once we provide the Configuration API, we should validate which plugin should have terminalCondition execution points.
146-
if len(f.terminalConditionPlugins) > 1 {
147-
return nil, errorTooManyTerminalConditionPlugin
148-
}
149-
if len(f.terminalConditionPlugins) != 0 {
150-
return f.terminalConditionPlugins[0].TerminalCondition(ctx, trainJob)
146+
func (f *Framework) RunTrainJobStatusPlugin(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) {
147+
if f.trainJobStatusPlugin != nil {
148+
return f.trainJobStatusPlugin.Status(ctx, trainJob)
151149
}
152150
return nil, nil
153151
}

0 commit comments

Comments
 (0)