@@ -45,8 +45,8 @@ type Framework struct {
4545 watchExtensionPlugins []framework.WatchExtensionPlugin
4646 podNetworkPlugins []framework.PodNetworkPlugin
4747 componentBuilderPlugins []framework.ComponentBuilderPlugin
48- terminalConditionPlugins [] framework.TerminalConditionPlugin
49- jobsStatusPlugins [] framework.JobsStatusPlugin
48+ terminalConditionPlugin framework.TerminalConditionPlugin
49+ jobsStatusPlugin framework.JobsStatusPlugin
5050}
5151
5252func New (ctx context.Context , c client.Client , r fwkplugins.Registry , indexer client.FieldIndexer ) (* Framework , error ) {
@@ -80,10 +80,16 @@ func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl
8080 f .componentBuilderPlugins = append (f .componentBuilderPlugins , p )
8181 }
8282 if p , ok := plugin .(framework.TerminalConditionPlugin ); ok {
83- f .terminalConditionPlugins = append (f .terminalConditionPlugins , p )
83+ if f .terminalConditionPlugin != nil {
84+ return nil , errorTooManyTerminalConditionPlugin
85+ }
86+ f .terminalConditionPlugin = p
8487 }
8588 if p , ok := plugin .(framework.JobsStatusPlugin ); ok {
86- f .jobsStatusPlugins = append (f .jobsStatusPlugins , p )
89+ if f .jobsStatusPlugin != nil {
90+ return nil , errorTooManyJobsStatusPlugin
91+ }
92+ f .jobsStatusPlugin = p
8793 }
8894 }
8995 f .plugins = plugins
@@ -145,22 +151,15 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtim
145151}
146152
147153func (f * Framework ) RunTerminalConditionPlugins (ctx context.Context , trainJob * trainer.TrainJob ) (* metav1.Condition , error ) {
148- // TODO (tenzen-y): Once we provide the Configuration API, we should validate which plugin should have terminalCondition execution points.
149- if len (f .terminalConditionPlugins ) > 1 {
150- return nil , errorTooManyTerminalConditionPlugin
151- }
152- if len (f .terminalConditionPlugins ) != 0 {
153- return f .terminalConditionPlugins [0 ].TerminalCondition (ctx , trainJob )
154+ if f .terminalConditionPlugin != nil {
155+ return f .terminalConditionPlugin .TerminalCondition (ctx , trainJob )
154156 }
155157 return nil , nil
156158}
157159
158160func (f * Framework ) RunJobsStatusPlugins (ctx context.Context , trainJob * trainer.TrainJob ) ([]trainer.JobStatus , error ) {
159- if len (f .jobsStatusPlugins ) > 1 {
160- return nil , errorTooManyJobsStatusPlugin
161- }
162- if len (f .jobsStatusPlugins ) != 0 {
163- return f .jobsStatusPlugins [0 ].JobsStatus (ctx , trainJob )
161+ if f .jobsStatusPlugin != nil {
162+ return f .jobsStatusPlugin .JobsStatus (ctx , trainJob )
164163 }
165164 return nil , nil
166165}
0 commit comments