Skip to content

Commit 45afa3c

Browse files
authored
support for grpc static_metadata, tweak http static_headers behaviour to match the doc (#1076)
1 parent 724bb05 commit 45afa3c

File tree

4 files changed

+227
-6
lines changed

4 files changed

+227
-6
lines changed

internal/configtypes/types.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,10 @@ type ProxyCommonGRPC struct {
633633
CredentialsValue string `mapstructure:"credentials_value" json:"credentials_value" envconfig:"credentials_value" yaml:"credentials_value" toml:"credentials_value"`
634634
// Compression enables compression for outgoing calls (gzip).
635635
Compression bool `mapstructure:"compression" json:"compression" envconfig:"compression" yaml:"compression" toml:"compression"`
636+
// StaticMetadata is a static set of key/value pairs to attach to GRPC proxy request as
637+
// metadata. Headers received from HTTP client request or metadata from GRPC client request
638+
// both have priority over values set in StaticMetadata map (but only if explicitly allowed).
639+
StaticMetadata MapStringString `mapstructure:"static_metadata" default:"{}" json:"static_metadata" envconfig:"static_metadata" yaml:"static_metadata" toml:"static_metadata"`
636640
}
637641

638642
type ProxyCommon struct {

internal/proxy/grpc.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,18 @@ func getDialOpts(name string, p Config) ([]grpc.DialOption, error) {
7575
}
7676

7777
func grpcRequestContext(ctx context.Context, proxy Config) context.Context {
78-
md := requestMetadata(ctx, proxy.HttpHeaders, proxy.GrpcMetadata)
78+
md := requestMetadata(ctx, proxy.HttpHeaders, proxy.GrpcMetadata, proxy.GRPC.StaticMetadata)
7979
return metadata.NewOutgoingContext(ctx, md)
8080
}
8181

82-
func requestMetadata(ctx context.Context, allowedHeaders []string, allowedMetaKeys []string) metadata.MD {
82+
func requestMetadata(ctx context.Context, allowedHeaders []string, allowedMetaKeys []string, staticMetadata map[string]string) metadata.MD {
8383
requestMD := metadata.MD{}
8484

85+
// Set static metadata first, so that dynamic metadata can override it.
86+
for k, v := range staticMetadata {
87+
requestMD.Set(k, v)
88+
}
89+
8590
emulatedHeaders, _ := clientcontext.GetEmulatedHeadersFromContext(ctx)
8691
for k, v := range emulatedHeaders {
8792
if slices.Contains(allowedHeaders, strings.ToLower(k)) {

internal/proxy/http.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,11 @@ func httpRequestHeaders(ctx context.Context, proxy Config) http.Header {
131131
func requestHeaders(ctx context.Context, allowedHeaders, allowedMetaKeys []string, staticHeaders map[string]string) http.Header {
132132
headers := http.Header{}
133133

134+
// Set static headers first, so that dynamic headers can override them.
135+
for k, v := range staticHeaders {
136+
headers.Set(k, v)
137+
}
138+
134139
emulatedHeaders, _ := clientcontext.GetEmulatedHeadersFromContext(ctx)
135140
for k, v := range emulatedHeaders {
136141
if slices.Contains(allowedHeaders, strings.ToLower(k)) {
@@ -154,10 +159,6 @@ func requestHeaders(ctx context.Context, allowedHeaders, allowedMetaKeys []strin
154159
}
155160
}
156161

157-
for k, v := range staticHeaders {
158-
headers.Set(k, v)
159-
}
160-
161162
headers.Set("Content-Type", "application/json")
162163

163164
return headers

internal/proxy/proxy_test.go

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
package proxy
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"testing"
7+
8+
"github.com/centrifugal/centrifugo/v6/internal/middleware"
9+
10+
"github.com/stretchr/testify/require"
11+
"google.golang.org/grpc/metadata"
12+
)
13+
14+
// TestRequestHeaders_StaticHeadersOverride tests that static headers are set
15+
// but can be overridden by client headers only when the header key is explicitly
16+
// allowed in the configuration.
17+
func TestRequestHeaders_StaticHeadersOverride(t *testing.T) {
18+
tests := []struct {
19+
name string
20+
staticHeaders map[string]string
21+
clientHeaders http.Header
22+
allowedHeaders []string
23+
expectedResult map[string]string
24+
}{
25+
{
26+
name: "static header not overridden when key not allowed",
27+
staticHeaders: map[string]string{
28+
"X-Static-Header": "static-value",
29+
},
30+
clientHeaders: http.Header{
31+
"X-Static-Header": []string{"client-value"},
32+
},
33+
allowedHeaders: []string{}, // No headers allowed
34+
expectedResult: map[string]string{
35+
"X-Static-Header": "static-value",
36+
"Content-Type": "application/json",
37+
},
38+
},
39+
{
40+
name: "static header overridden when key is allowed",
41+
staticHeaders: map[string]string{
42+
"X-Static-Header": "static-value",
43+
},
44+
clientHeaders: http.Header{
45+
"X-Static-Header": []string{"client-value"},
46+
},
47+
allowedHeaders: []string{"x-static-header"}, // Header is allowed
48+
expectedResult: map[string]string{
49+
"X-Static-Header": "client-value",
50+
"Content-Type": "application/json",
51+
},
52+
},
53+
{
54+
name: "multiple static headers with partial override",
55+
staticHeaders: map[string]string{
56+
"X-Static-1": "static-1",
57+
"X-Static-2": "static-2",
58+
},
59+
clientHeaders: http.Header{
60+
"X-Static-1": []string{"client-1"},
61+
"X-Static-2": []string{"client-2"},
62+
},
63+
allowedHeaders: []string{"x-static-1"}, // Only first header is allowed
64+
expectedResult: map[string]string{
65+
"X-Static-1": "client-1", // Overridden (allowed)
66+
"X-Static-2": "static-2", // Not overridden (not allowed)
67+
"Content-Type": "application/json",
68+
},
69+
},
70+
{
71+
name: "client header not in static set but allowed",
72+
staticHeaders: map[string]string{
73+
"X-Static-Header": "static-value",
74+
},
75+
clientHeaders: http.Header{
76+
"X-Client-Only": []string{"client-only-value"},
77+
},
78+
allowedHeaders: []string{"x-client-only"},
79+
expectedResult: map[string]string{
80+
"X-Static-Header": "static-value",
81+
"X-Client-Only": "client-only-value",
82+
"Content-Type": "application/json",
83+
},
84+
},
85+
{
86+
name: "no static headers",
87+
staticHeaders: map[string]string{},
88+
clientHeaders: http.Header{
89+
"X-Client-Header": []string{"client-value"},
90+
},
91+
allowedHeaders: []string{"x-client-header"},
92+
expectedResult: map[string]string{
93+
"X-Client-Header": "client-value",
94+
"Content-Type": "application/json",
95+
},
96+
},
97+
}
98+
99+
for _, tt := range tests {
100+
t.Run(tt.name, func(t *testing.T) {
101+
ctx := context.Background()
102+
ctx = middleware.SetHeadersToContext(ctx, tt.clientHeaders)
103+
104+
result := requestHeaders(ctx, tt.allowedHeaders, []string{}, tt.staticHeaders)
105+
106+
require.Equal(t, len(tt.expectedResult), len(result), "number of headers should match")
107+
for expectedKey, expectedValue := range tt.expectedResult {
108+
require.Equal(t, expectedValue, result.Get(expectedKey),
109+
"header %s should have value %s", expectedKey, expectedValue)
110+
}
111+
})
112+
}
113+
}
114+
115+
// TestRequestMetadata_StaticMetadataOverride tests that static metadata is set
116+
// but can be overridden by client metadata only when the metadata key is explicitly
117+
// allowed in the configuration.
118+
func TestRequestMetadata_StaticMetadataOverride(t *testing.T) {
119+
tests := []struct {
120+
name string
121+
staticMetadata map[string]string
122+
clientMetadata metadata.MD
123+
allowedMetaKeys []string
124+
expectedResult map[string]string
125+
}{
126+
{
127+
name: "static metadata not overridden when key not allowed",
128+
staticMetadata: map[string]string{
129+
"x-static-meta": "static-value",
130+
},
131+
clientMetadata: metadata.MD{
132+
"x-static-meta": []string{"client-value"},
133+
},
134+
allowedMetaKeys: []string{}, // No metadata keys allowed
135+
expectedResult: map[string]string{
136+
"x-static-meta": "static-value",
137+
},
138+
},
139+
{
140+
name: "static metadata overridden when key is allowed",
141+
staticMetadata: map[string]string{
142+
"x-static-meta": "static-value",
143+
},
144+
clientMetadata: metadata.MD{
145+
"x-static-meta": []string{"client-value"},
146+
},
147+
allowedMetaKeys: []string{"x-static-meta"}, // Metadata key is allowed
148+
expectedResult: map[string]string{
149+
"x-static-meta": "client-value",
150+
},
151+
},
152+
{
153+
name: "multiple static metadata with partial override",
154+
staticMetadata: map[string]string{
155+
"x-static-1": "static-1",
156+
"x-static-2": "static-2",
157+
},
158+
clientMetadata: metadata.MD{
159+
"x-static-1": []string{"client-1"},
160+
"x-static-2": []string{"client-2"},
161+
},
162+
allowedMetaKeys: []string{"x-static-1"}, // Only first key is allowed
163+
expectedResult: map[string]string{
164+
"x-static-1": "client-1", // Overridden (allowed)
165+
"x-static-2": "static-2", // Not overridden (not allowed)
166+
},
167+
},
168+
{
169+
name: "client metadata not in static set but allowed",
170+
staticMetadata: map[string]string{
171+
"x-static-meta": "static-value",
172+
},
173+
clientMetadata: metadata.MD{
174+
"x-client-only": []string{"client-only-value"},
175+
},
176+
allowedMetaKeys: []string{"x-client-only"},
177+
expectedResult: map[string]string{
178+
"x-static-meta": "static-value",
179+
"x-client-only": "client-only-value",
180+
},
181+
},
182+
{
183+
name: "no static metadata",
184+
staticMetadata: map[string]string{},
185+
clientMetadata: metadata.MD{
186+
"x-client-meta": []string{"client-value"},
187+
},
188+
allowedMetaKeys: []string{"x-client-meta"},
189+
expectedResult: map[string]string{
190+
"x-client-meta": "client-value",
191+
},
192+
},
193+
}
194+
195+
for _, tt := range tests {
196+
t.Run(tt.name, func(t *testing.T) {
197+
ctx := context.Background()
198+
ctx = metadata.NewIncomingContext(ctx, tt.clientMetadata)
199+
200+
result := requestMetadata(ctx, []string{}, tt.allowedMetaKeys, tt.staticMetadata)
201+
202+
require.Equal(t, len(tt.expectedResult), len(result), "number of metadata entries should match")
203+
for expectedKey, expectedValue := range tt.expectedResult {
204+
values := result.Get(expectedKey)
205+
require.Len(t, values, 1, "metadata key %s should have exactly one value", expectedKey)
206+
require.Equal(t, expectedValue, values[0],
207+
"metadata %s should have value %s", expectedKey, expectedValue)
208+
}
209+
})
210+
}
211+
}

0 commit comments

Comments
 (0)