Skip to content

Commit 5803013

Browse files
committed
Validate all producer consumer plugins
1 parent 5df0acd commit 5803013

File tree

3 files changed

+68
-15
lines changed

3 files changed

+68
-15
lines changed

pkg/epp/requestcontrol/dag.go

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,35 +20,39 @@ import (
2020
"errors"
2121
"slices"
2222

23+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
24+
2325
fwk "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
2426
)
2527

2628
// buildDAG builds a dependency graph among data preparation plugins based on their
2729
// produced and consumed data keys.
28-
func buildDAG(plugins []fwk.PrepareDataPlugin) (map[string][]string, error) {
30+
func buildDAG(producers map[string]plugin.ProducerPlugin, consumers map[string]plugin.ConsumerPlugin) (map[string][]string, error) {
2931
dag := make(map[string][]string)
30-
for _, plugin := range plugins {
31-
dag[plugin.TypedName().String()] = []string{}
32-
}
3332
// Create dependency graph as a DAG.
34-
for i := range plugins {
35-
for j := range plugins {
36-
if i == j {
33+
for _, producer := range producers {
34+
dag[producer.TypedName().String()] = []string{}
35+
}
36+
for _, consumer := range consumers {
37+
dag[consumer.TypedName().String()] = []string{}
38+
}
39+
for pName, producer := range producers {
40+
for cName, consumer := range consumers {
41+
if pName == cName {
3742
continue
3843
}
3944
// Check whether plugin[i] produces something consumed by plugin[j]. In that case, j depends on i.
40-
if plugins[i].Produces() != nil && plugins[j].Consumes() != nil {
41-
for producedKey, producedData := range plugins[i].Produces() {
45+
if producer.Produces() != nil && consumer.Consumes() != nil {
46+
for producedKey, producedData := range producer.Produces() {
4247
// If plugin j consumes the produced key, then j depends on i. We can break after the first match.
43-
if consumedData, ok := plugins[j].Consumes()[producedKey]; ok {
48+
if consumedData, ok := consumer.Consumes()[producedKey]; ok {
4449
// Check types are same. Reflection is avoided here for simplicity.
4550
// TODO(#1985): Document this detail in IGW docs.
4651
if producedData != consumedData {
4752
return nil, errors.New("data type mismatch between produced and consumed data for key: " + producedKey)
4853
}
49-
iPluginName := plugins[i].TypedName().String()
50-
jPluginName := plugins[j].TypedName().String()
51-
dag[jPluginName] = append(dag[jPluginName], iPluginName)
54+
// Consumer depends on producer, so add an edge from consumer to producer.
55+
dag[cName] = append(dag[cName], pName)
5256
break
5357
}
5458
}

pkg/epp/requestcontrol/dag_test.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,17 @@ func TestPrepareDataGraph(t *testing.T) {
124124

125125
for _, tc := range testCases {
126126
t.Run(tc.name, func(t *testing.T) {
127-
dag, err := buildDAG(tc.plugins)
127+
producers := make(map[string]fwkplugin.ProducerPlugin)
128+
consumers := make(map[string]fwkplugin.ConsumerPlugin)
129+
for _, p := range tc.plugins {
130+
if pp, ok := p.(fwkplugin.ProducerPlugin); ok {
131+
producers[p.TypedName().String()] = pp
132+
}
133+
if cp, ok := p.(fwkplugin.ConsumerPlugin); ok {
134+
consumers[p.TypedName().String()] = cp
135+
}
136+
}
137+
dag, err := buildDAG(producers, consumers)
128138
if err != nil {
129139
if tc.expectError {
130140
assert.Error(t, err)

pkg/epp/requestcontrol/request_control_config.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,53 @@ func (c *Config) AddPlugins(pluginObjects ...plugin.Plugin) {
106106
}
107107
}
108108

109+
// ProducerConsumerPlugins iterates through all registered plugins and returns two slices:
110+
// one containing the names of plugins that implement the ProducerPlugin interface,
111+
// and another for plugins that implement the ConsumerPlugin interface.
112+
func (c *Config) ProducerConsumerPlugins() (map[string]plugin.ProducerPlugin, map[string]plugin.ConsumerPlugin) {
113+
producers := make(map[string]plugin.ProducerPlugin)
114+
consumers := make(map[string]plugin.ConsumerPlugin)
115+
// Collect all unique plugins from the config.
116+
allPlugins := make(map[string]plugin.Plugin)
117+
for _, p := range c.admissionPlugins {
118+
allPlugins[p.TypedName().String()] = p
119+
}
120+
for _, p := range c.prepareDataPlugins {
121+
allPlugins[p.TypedName().String()] = p
122+
}
123+
for _, p := range c.preRequestPlugins {
124+
allPlugins[p.TypedName().String()] = p
125+
}
126+
for _, p := range c.responseReceivedPlugins {
127+
allPlugins[p.TypedName().String()] = p
128+
}
129+
for _, p := range c.responseStreamingPlugins {
130+
allPlugins[p.TypedName().String()] = p
131+
}
132+
for _, p := range c.responseCompletePlugins {
133+
allPlugins[p.TypedName().String()] = p
134+
}
135+
136+
for name, p := range allPlugins {
137+
if producer, ok := p.(plugin.ProducerPlugin); ok {
138+
producers[name] = producer
139+
}
140+
if consumer, ok := p.(plugin.ConsumerPlugin); ok {
141+
consumers[name] = consumer
142+
}
143+
}
144+
return producers, consumers
145+
}
146+
109147
// PrepareDataPluginGraph creates data dependency graph and sorts the plugins in topological order.
110148
// If a cycle is detected, it returns an error.
111149
func (c *Config) PrepareDataPluginGraph() error {
150+
producers, consumers := c.ProducerConsumerPlugins()
112151
// TODO(#1988): Add all producer and consumer plugins to the graph.
113152
if len(c.prepareDataPlugins) == 0 {
114153
return nil
115154
}
116-
dag, err := buildDAG(c.prepareDataPlugins)
155+
dag, err := buildDAG(producers, consumers)
117156
if err != nil {
118157
return err
119158
}

0 commit comments

Comments
 (0)