Skip to content

Commit c25b925

Browse files
committed
Fail fast when multiple TerminalCondition or JobsStatus plugins exist
Signed-off-by: Antonin Stefanutti <antonin@stefanutti.fr>
1 parent 389023c commit c25b925

2 files changed

Lines changed: 24 additions & 23 deletions

File tree

pkg/runtime/framework/core/framework.go

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ type Framework struct {
4545
watchExtensionPlugins []framework.WatchExtensionPlugin
4646
podNetworkPlugins []framework.PodNetworkPlugin
4747
componentBuilderPlugins []framework.ComponentBuilderPlugin
48-
terminalConditionPlugins []framework.TerminalConditionPlugin
49-
jobsStatusPlugins []framework.JobsStatusPlugin
48+
terminalConditionPlugin framework.TerminalConditionPlugin
49+
jobsStatusPlugin framework.JobsStatusPlugin
5050
}
5151

5252
func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) {
@@ -80,10 +80,16 @@ func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl
8080
f.componentBuilderPlugins = append(f.componentBuilderPlugins, p)
8181
}
8282
if p, ok := plugin.(framework.TerminalConditionPlugin); ok {
83-
f.terminalConditionPlugins = append(f.terminalConditionPlugins, p)
83+
if f.terminalConditionPlugin != nil {
84+
return nil, errorTooManyTerminalConditionPlugin
85+
}
86+
f.terminalConditionPlugin = p
8487
}
8588
if p, ok := plugin.(framework.JobsStatusPlugin); ok {
86-
f.jobsStatusPlugins = append(f.jobsStatusPlugins, p)
89+
if f.jobsStatusPlugin != nil {
90+
return nil, errorTooManyJobsStatusPlugin
91+
}
92+
f.jobsStatusPlugin = p
8793
}
8894
}
8995
f.plugins = plugins
@@ -145,22 +151,15 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtim
145151
}
146152

147153
func (f *Framework) RunTerminalConditionPlugins(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
148-
// TODO (tenzen-y): Once we provide the Configuration API, we should validate which plugin should have terminalCondition execution points.
149-
if len(f.terminalConditionPlugins) > 1 {
150-
return nil, errorTooManyTerminalConditionPlugin
151-
}
152-
if len(f.terminalConditionPlugins) != 0 {
153-
return f.terminalConditionPlugins[0].TerminalCondition(ctx, trainJob)
154+
if f.terminalConditionPlugin != nil {
155+
return f.terminalConditionPlugin.TerminalCondition(ctx, trainJob)
154156
}
155157
return nil, nil
156158
}
157159

158160
func (f *Framework) RunJobsStatusPlugins(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) {
159-
if len(f.jobsStatusPlugins) > 1 {
160-
return nil, errorTooManyJobsStatusPlugin
161-
}
162-
if len(f.jobsStatusPlugins) != 0 {
163-
return f.jobsStatusPlugins[0].JobsStatus(ctx, trainJob)
161+
if f.jobsStatusPlugin != nil {
162+
return f.jobsStatusPlugin.JobsStatus(ctx, trainJob)
164163
}
165164
return nil, nil
166165
}

pkg/runtime/framework/core/framework_test.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,8 @@ func TestNew(t *testing.T) {
103103
&jobset.JobSet{},
104104
&mpi.MPI{},
105105
},
106-
terminalConditionPlugins: []framework.TerminalConditionPlugin{
107-
&jobset.JobSet{},
108-
},
109-
jobsStatusPlugins: []framework.JobsStatusPlugin{
110-
&jobset.JobSet{},
111-
},
106+
terminalConditionPlugin: &jobset.JobSet{},
107+
jobsStatusPlugin: &jobset.JobSet{},
112108
},
113109
},
114110
"indexer key for trainingRuntime and runtimeClass is an empty": {
@@ -1052,7 +1048,10 @@ func TestTerminalConditionPlugins(t *testing.T) {
10521048

10531049
fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder))
10541050
if err != nil {
1055-
t.Fatal(err)
1051+
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
1052+
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
1053+
}
1054+
return
10561055
}
10571056

10581057
gotCond, gotErr := fwk.RunTerminalConditionPlugins(ctx, tc.trainJob)
@@ -1206,7 +1205,10 @@ func TestJobsStatusPlugins(t *testing.T) {
12061205

12071206
fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder))
12081207
if err != nil {
1209-
t.Fatal(err)
1208+
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
1209+
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
1210+
}
1211+
return
12101212
}
12111213

12121214
gotStatuses, gotErr := fwk.RunJobsStatusPlugins(ctx, tc.trainJob)

0 commit comments

Comments
 (0)