Skip to content

Commit d9ada32

Browse files
committed
Fail-fast when multiple terminal condition and JobsStatus plugins exist
Signed-off-by: Antonin Stefanutti <antonin@stefanutti.fr>
1 parent 389023c commit d9ada32

1 file changed

Lines changed: 14 additions & 15 deletions

File tree

pkg/runtime/framework/core/framework.go

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ type Framework struct {
4545
watchExtensionPlugins []framework.WatchExtensionPlugin
4646
podNetworkPlugins []framework.PodNetworkPlugin
4747
componentBuilderPlugins []framework.ComponentBuilderPlugin
48-
terminalConditionPlugins []framework.TerminalConditionPlugin
49-
jobsStatusPlugins []framework.JobsStatusPlugin
48+
terminalConditionPlugin framework.TerminalConditionPlugin
49+
jobsStatusPlugin framework.JobsStatusPlugin
5050
}
5151

5252
func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) {
@@ -80,10 +80,16 @@ func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl
8080
f.componentBuilderPlugins = append(f.componentBuilderPlugins, p)
8181
}
8282
if p, ok := plugin.(framework.TerminalConditionPlugin); ok {
83-
f.terminalConditionPlugins = append(f.terminalConditionPlugins, p)
83+
if f.terminalConditionPlugin != nil {
84+
return nil, errorTooManyTerminalConditionPlugin
85+
}
86+
f.terminalConditionPlugin = p
8487
}
8588
if p, ok := plugin.(framework.JobsStatusPlugin); ok {
86-
f.jobsStatusPlugins = append(f.jobsStatusPlugins, p)
89+
if f.jobsStatusPlugin != nil {
90+
return nil, errorTooManyJobsStatusPlugin
91+
}
92+
f.jobsStatusPlugin = p
8793
}
8894
}
8995
f.plugins = plugins
@@ -145,22 +151,15 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtim
145151
}
146152

147153
func (f *Framework) RunTerminalConditionPlugins(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
148-
// TODO (tenzen-y): Once we provide the Configuration API, we should validate which plugin should have terminalCondition execution points.
149-
if len(f.terminalConditionPlugins) > 1 {
150-
return nil, errorTooManyTerminalConditionPlugin
151-
}
152-
if len(f.terminalConditionPlugins) != 0 {
153-
return f.terminalConditionPlugins[0].TerminalCondition(ctx, trainJob)
154+
if f.terminalConditionPlugin != nil {
155+
return f.terminalConditionPlugin.TerminalCondition(ctx, trainJob)
154156
}
155157
return nil, nil
156158
}
157159

158160
func (f *Framework) RunJobsStatusPlugins(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) {
159-
if len(f.jobsStatusPlugins) > 1 {
160-
return nil, errorTooManyJobsStatusPlugin
161-
}
162-
if len(f.jobsStatusPlugins) != 0 {
163-
return f.jobsStatusPlugins[0].JobsStatus(ctx, trainJob)
161+
if f.jobsStatusPlugin != nil {
162+
return f.jobsStatusPlugin.JobsStatus(ctx, trainJob)
164163
}
165164
return nil, nil
166165
}

0 commit comments

Comments
 (0)