Skip to content

Commit 8d77ba9

Browse files
committed
Validate all producer consumer plugins
1 parent 5df0acd commit 8d77ba9

File tree

3 files changed

+69
-15
lines changed

3 files changed

+69
-15
lines changed

pkg/epp/requestcontrol/dag.go

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,35 +20,39 @@ import (
2020
"errors"
2121
"slices"
2222

23+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
24+
2325
fwk "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
2426
)
2527

2628
// buildDAG builds a dependency graph among data preparation plugins based on their
2729
// produced and consumed data keys.
28-
func buildDAG(plugins []fwk.PrepareDataPlugin) (map[string][]string, error) {
30+
func buildDAG(producers map[string]plugin.ProducerPlugin, consumers map[string]plugin.ConsumerPlugin) (map[string][]string, error) {
2931
dag := make(map[string][]string)
30-
for _, plugin := range plugins {
31-
dag[plugin.TypedName().String()] = []string{}
32-
}
3332
// Create dependency graph as a DAG.
34-
for i := range plugins {
35-
for j := range plugins {
36-
if i == j {
33+
for _, producer := range producers {
34+
dag[producer.TypedName().String()] = []string{}
35+
}
36+
for _, consumer := range consumers {
37+
dag[consumer.TypedName().String()] = []string{}
38+
}
39+
for pName, producer := range producers {
40+
for cName, consumer := range consumers {
41+
if pName == cName {
3742
continue
3843
}
3944
// Check whether plugin[i] produces something consumed by plugin[j]. In that case, j depends on i.
40-
if plugins[i].Produces() != nil && plugins[j].Consumes() != nil {
41-
for producedKey, producedData := range plugins[i].Produces() {
45+
if producer.Produces() != nil && consumer.Consumes() != nil {
46+
for producedKey, producedData := range producer.Produces() {
4247
// If plugin j consumes the produced key, then j depends on i. We can break after the first match.
43-
if consumedData, ok := plugins[j].Consumes()[producedKey]; ok {
48+
if consumedData, ok := consumer.Consumes()[producedKey]; ok {
4449
// Check types are same. Reflection is avoided here for simplicity.
4550
// TODO(#1985): Document this detail in IGW docs.
4651
if producedData != consumedData {
4752
return nil, errors.New("data type mismatch between produced and consumed data for key: " + producedKey)
4853
}
49-
iPluginName := plugins[i].TypedName().String()
50-
jPluginName := plugins[j].TypedName().String()
51-
dag[jPluginName] = append(dag[jPluginName], iPluginName)
54+
// Consumer depends on producer, so add an edge from consumer to producer.
55+
dag[cName] = append(dag[cName], pName)
5256
break
5357
}
5458
}

pkg/epp/requestcontrol/dag_test.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,17 @@ func TestPrepareDataGraph(t *testing.T) {
124124

125125
for _, tc := range testCases {
126126
t.Run(tc.name, func(t *testing.T) {
127-
dag, err := buildDAG(tc.plugins)
127+
producers := make(map[string]fwkplugin.ProducerPlugin)
128+
consumers := make(map[string]fwkplugin.ConsumerPlugin)
129+
for _, p := range tc.plugins {
130+
if pp, ok := p.(fwkplugin.ProducerPlugin); ok {
131+
producers[p.TypedName().String()] = pp
132+
}
133+
if cp, ok := p.(fwkplugin.ConsumerPlugin); ok {
134+
consumers[p.TypedName().String()] = cp
135+
}
136+
}
137+
dag, err := buildDAG(producers, consumers)
128138
if err != nil {
129139
if tc.expectError {
130140
assert.Error(t, err)

pkg/epp/requestcontrol/request_control_config.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,54 @@ func (c *Config) AddPlugins(pluginObjects ...plugin.Plugin) {
106106
}
107107
}
108108

109+
// ProducerConsumerPlugins iterates through all registered plugins and returns two slices:
110+
// one containing the names of plugins that implement the ProducerPlugin interface,
111+
// and another for plugins that implement the ConsumerPlugin interface.
112+
func (c *Config) ProducerConsumerPlugins() (map[string]plugin.ProducerPlugin, map[string]plugin.ConsumerPlugin) {
113+
var producers map[string]plugin.ProducerPlugin
114+
var consumers map[string]plugin.ConsumerPlugin
115+
116+
// Collect all unique plugins from the config.
117+
allPlugins := make(map[string]plugin.Plugin)
118+
for _, p := range c.admissionPlugins {
119+
allPlugins[p.TypedName().String()] = p
120+
}
121+
for _, p := range c.prepareDataPlugins {
122+
allPlugins[p.TypedName().String()] = p
123+
}
124+
for _, p := range c.preRequestPlugins {
125+
allPlugins[p.TypedName().String()] = p
126+
}
127+
for _, p := range c.responseReceivedPlugins {
128+
allPlugins[p.TypedName().String()] = p
129+
}
130+
for _, p := range c.responseStreamingPlugins {
131+
allPlugins[p.TypedName().String()] = p
132+
}
133+
for _, p := range c.responseCompletePlugins {
134+
allPlugins[p.TypedName().String()] = p
135+
}
136+
137+
for name, p := range allPlugins {
138+
if producer, ok := p.(plugin.ProducerPlugin); ok {
139+
producers[name] = producer
140+
}
141+
if consumer, ok := p.(plugin.ConsumerPlugin); ok {
142+
consumers[name] = consumer
143+
}
144+
}
145+
return producers, consumers
146+
}
147+
109148
// PrepareDataPluginGraph creates data dependency graph and sorts the plugins in topological order.
110149
// If a cycle is detected, it returns an error.
111150
func (c *Config) PrepareDataPluginGraph() error {
151+
producers, consumers := c.ProducerConsumerPlugins()
112152
// TODO(#1988): Add all producer and consumer plugins to the graph.
113153
if len(c.prepareDataPlugins) == 0 {
114154
return nil
115155
}
116-
dag, err := buildDAG(c.prepareDataPlugins)
156+
dag, err := buildDAG(producers, consumers)
117157
if err != nil {
118158
return err
119159
}

0 commit comments

Comments
 (0)