Skip to content

Commit 8931933

Browse files
committed
Refactor header propagation to support multiple header filters
- Rename to headerFilter and switch to addHeaderFilter - Allow multiple filters combined with Predicate#or - Apply consistently to SendToMethodReturnValueHandler and SubscriptionMethodReturnValueHandler Signed-off-by: 김준환 <musoyou1085@gmail.com>
1 parent ad5e26d commit 8931933

File tree

4 files changed

+128
-74
lines changed

4 files changed

+128
-74
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,7 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
7474

7575
private @Nullable MessageHeaderInitializer headerInitializer;
7676

77-
78-
/**
79-
* Predicate to determine which header names from the input message should be propagated.
80-
* If null, no headers are propagated (default behavior).
81-
*/
82-
private @Nullable Predicate<String> headerPropagationPredicate;
77+
private @Nullable Predicate<String> headerFilter;
8378

8479

8580
public SendToMethodReturnValueHandler(SimpMessageSendingOperations messagingTemplate, boolean annotationRequired) {
@@ -142,21 +137,24 @@ public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitia
142137
}
143138

144139
/**
145-
* Set a predicate to filter which headers from the input message should be propagated to the output message.
146-
* <p><b>Warning:</b> The predicate should avoid propagating or overwriting well-known protocol headers
147-
* (e.g. headers starting with "simp", "content-type", etc.) to prevent breaking internal messaging semantics.
148-
* </p>
140+
* Add a filter to determine which headers from the input message should be propagated to the output message.
141+
* Multiple filters are combined with logical OR.
149142
* <p>If not set, no input headers are propagated (default behavior).</p>
150143
*/
151-
public void setHeaderPropagationPredicate(@Nullable Predicate<String> predicate) {
152-
this.headerPropagationPredicate = predicate;
144+
public void addHeaderFilter(Predicate<String> filter) {
145+
Assert.notNull(filter, "Filter predicate must not be null");
146+
if (this.headerFilter == null) {
147+
this.headerFilter = filter;
148+
} else {
149+
this.headerFilter = this.headerFilter.or(filter);
150+
}
153151
}
154152

155153
/**
156-
* Return the configured header propagation predicate.
154+
* Return the configured header filter.
157155
*/
158-
public @Nullable Predicate<String> getHeaderPropagationPredicate() {
159-
return this.headerPropagationPredicate;
156+
public @Nullable Predicate<String> getHeaderFilter() {
157+
return this.headerFilter;
160158
}
161159

162160

@@ -265,22 +263,21 @@ private MessageHeaders createHeaders(@Nullable String sessionId, MethodParameter
265263
if (getHeaderInitializer() != null) {
266264
getHeaderInitializer().initHeaders(headerAccessor);
267265
}
268-
if (sessionId != null) {
269-
headerAccessor.setSessionId(sessionId);
270-
}
271-
headerAccessor.setHeader(AbstractMessageSendingTemplate.CONVERSION_HINT_HEADER, returnType);
272266

273-
// Header propagation policy
274-
if (inputMessage != null && headerPropagationPredicate != null) {
267+
if (inputMessage != null && headerFilter != null) {
275268
Map<String, Object> inputHeaders = inputMessage.getHeaders();
276269
for (Map.Entry<String, Object> entry : inputHeaders.entrySet()) {
277270
String name = entry.getKey();
278-
if (headerPropagationPredicate.test(name)) {
271+
if (headerFilter.test(name)) {
279272
headerAccessor.setHeader(name, entry.getValue());
280273
}
281274
}
282275
}
283276

277+
if (sessionId != null) {
278+
headerAccessor.setSessionId(sessionId);
279+
}
280+
headerAccessor.setHeader(AbstractMessageSendingTemplate.CONVERSION_HINT_HEADER, returnType);
284281
headerAccessor.setLeaveMutable(true);
285282
return headerAccessor.getMessageHeaders();
286283
}

spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,7 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
6868

6969
private @Nullable MessageHeaderInitializer headerInitializer;
7070

71-
/**
72-
* Predicate to determine which header names from the input message should be propagated.
73-
* If null, no headers are propagated (default behavior).
74-
*/
75-
private @Nullable Predicate<String> headerPropagationPredicate;
71+
private @Nullable Predicate<String> headerFilter;
7672

7773

7874
/**
@@ -103,21 +99,24 @@ public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitia
10399
}
104100

105101
/**
106-
* Set a predicate to filter which headers from the input message should be propagated to the output message.
107-
* <p><b>Warning:</b> The predicate should avoid propagating or overwriting well-known protocol headers
108-
* (e.g. headers starting with "simp", "content-type", etc.) to prevent breaking internal messaging semantics.
109-
* </p>
102+
* Add a filter to determine which headers from the input message should be propagated to the output message.
103+
* Multiple filters are combined with logical OR.
110104
* <p>If not set, no input headers are propagated (default behavior).</p>
111105
*/
112-
public void setHeaderPropagationPredicate(@Nullable Predicate<String> predicate) {
113-
this.headerPropagationPredicate = predicate;
106+
public void addHeaderFilter(Predicate<String> filter) {
107+
Assert.notNull(filter, "Filter predicate must not be null");
108+
if (this.headerFilter == null) {
109+
this.headerFilter = filter;
110+
} else {
111+
this.headerFilter = this.headerFilter.or(filter);
112+
}
114113
}
115114

116115
/**
117-
* Return the configured header propagation predicate.
116+
* Return the configured header filter.
118117
*/
119-
public @Nullable Predicate<String> getHeaderPropagationPredicate() {
120-
return this.headerPropagationPredicate;
118+
public @Nullable Predicate<String> getHeaderFilter() {
119+
return this.headerFilter;
121120
}
122121

123122

@@ -162,22 +161,22 @@ private MessageHeaders createHeaders(@Nullable String sessionId, String subscrip
162161
if (getHeaderInitializer() != null) {
163162
getHeaderInitializer().initHeaders(accessor);
164163
}
165-
if (sessionId != null) {
166-
accessor.setSessionId(sessionId);
167-
}
168-
accessor.setSubscriptionId(subscriptionId);
169-
accessor.setHeader(AbstractMessageSendingTemplate.CONVERSION_HINT_HEADER, returnType);
170164

171-
if (inputMessage != null && headerPropagationPredicate != null) {
165+
if (inputMessage != null && headerFilter != null) {
172166
Map<String, Object> inputHeaders = inputMessage.getHeaders();
173167
for (Map.Entry<String, Object> entry : inputHeaders.entrySet()) {
174168
String name = entry.getKey();
175-
if (headerPropagationPredicate.test(name)) {
169+
if (headerFilter.test(name)) {
176170
accessor.setHeader(name, entry.getValue());
177171
}
178172
}
179173
}
180174

175+
if (sessionId != null) {
176+
accessor.setSessionId(sessionId);
177+
}
178+
accessor.setSubscriptionId(subscriptionId);
179+
accessor.setHeader(AbstractMessageSendingTemplate.CONVERSION_HINT_HEADER, returnType);
181180
accessor.setLeaveMutable(true);
182181
return accessor.getMessageHeaders();
183182
}

spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -294,19 +294,19 @@ public void sendToUserWithSendToOverride() throws Exception {
294294
}
295295

296296
@Test
297-
void sendToWithHeaderPropagationPredicate() throws Exception {
297+
void sendToWithHeaderFilterSinglePredicate() throws Exception {
298298
given(this.messageChannel.send(any(Message.class))).willReturn(true);
299299

300300
String sessionId = "sess1";
301301
String customHeaderName = "x-custom-header";
302302
String customHeaderValue = "custom-value";
303303
Message<?> inputMessage = createMessage(sessionId, "sub1", null, null, null);
304304
inputMessage = MessageBuilder.fromMessage(inputMessage)
305-
.setHeader(customHeaderName, customHeaderValue)
306-
.build();
305+
.setHeader(customHeaderName, customHeaderValue)
306+
.build();
307307

308308
SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true);
309-
handler.setHeaderPropagationPredicate(name -> name.equals(customHeaderName));
309+
handler.addHeaderFilter(name -> name.equals(customHeaderName));
310310

311311
handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage);
312312

@@ -317,9 +317,36 @@ void sendToWithHeaderPropagationPredicate() throws Exception {
317317
}
318318
}
319319

320+
@Test
321+
void sendToWithHeaderFilterMultiplePredicates() throws Exception {
322+
given(this.messageChannel.send(any(Message.class))).willReturn(true);
323+
324+
String sessionId = "sess1";
325+
String headerA = "x-header-a";
326+
String headerB = "x-header-b";
327+
Message<?> inputMessage = createMessage(sessionId, "sub1", null, null, null);
328+
inputMessage = MessageBuilder.fromMessage(inputMessage)
329+
.setHeader(headerA, "A-value")
330+
.setHeader(headerB, "B-value")
331+
.build();
332+
333+
SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true);
334+
handler.addHeaderFilter(name -> name.equals(headerA));
335+
handler.addHeaderFilter(name -> name.equals(headerB));
336+
337+
handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage);
338+
339+
verify(this.messageChannel, times(2)).send(this.messageCaptor.capture());
340+
for (Message<?> sent : this.messageCaptor.getAllValues()) {
341+
MessageHeaders headers = sent.getHeaders();
342+
assertThat(headers.get(headerA)).isEqualTo("A-value");
343+
assertThat(headers.get(headerB)).isEqualTo("B-value");
344+
}
345+
}
346+
320347

321348
private void assertResponse(MethodParameter methodParameter, String sessionId,
322-
int index, String destination) {
349+
int index, String destination) {
323350

324351
SimpMessageHeaderAccessor accessor = getCapturedAccessor(index);
325352
assertThat(accessor.getSessionId()).isEqualTo(sessionId);

spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java

Lines changed: 57 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -187,32 +187,63 @@ void testJsonView() throws Exception {
187187
}
188188

189189
@Test
190-
void testHeaderPropagationPredicate() throws Exception {
191-
String sessionId = "sess1";
192-
String subscriptionId = "subs1";
193-
String destination = "/dest";
194-
String customHeaderName = "x-custom-header";
195-
String customHeaderValue = "custom-value";
196-
Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD)
197-
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
198-
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
199-
.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, destination)
200-
.setHeader(customHeaderName, customHeaderValue)
201-
.build();
202-
203-
MessageSendingOperations messagingTemplate = mock();
204-
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate);
205-
206-
handler.setHeaderPropagationPredicate(name -> name.equals(customHeaderName));
207-
208-
handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);
209-
210-
ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class);
211-
verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
212-
213-
MessageHeaders sentHeaders = captor.getValue();
214-
assertThat(sentHeaders.get(customHeaderName)).isEqualTo(customHeaderValue);
215-
}
190+
void testHeaderFilterSinglePredicate() throws Exception {
191+
String sessionId = "sess1";
192+
String subscriptionId = "subs1";
193+
String destination = "/dest";
194+
String customHeaderName = "x-custom-header";
195+
String customHeaderValue = "custom-value";
196+
Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD)
197+
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
198+
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
199+
.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, destination)
200+
.setHeader(customHeaderName, customHeaderValue)
201+
.build();
202+
203+
MessageSendingOperations messagingTemplate = mock();
204+
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate);
205+
206+
handler.addHeaderFilter(name -> name.equals(customHeaderName));
207+
208+
handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);
209+
210+
ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class);
211+
verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
212+
213+
MessageHeaders sentHeaders = captor.getValue();
214+
assertThat(sentHeaders.get(customHeaderName)).isEqualTo(customHeaderValue);
215+
}
216+
217+
@Test
218+
void testHeaderFilterMultiplePredicates() throws Exception {
219+
String sessionId = "sess1";
220+
String subscriptionId = "subs1";
221+
String destination = "/dest";
222+
String headerA = "x-header-a";
223+
String headerB = "x-header-b";
224+
Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD)
225+
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
226+
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
227+
.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, destination)
228+
.setHeader(headerA, "A-value")
229+
.setHeader(headerB, "B-value")
230+
.build();
231+
232+
MessageSendingOperations messagingTemplate = mock();
233+
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate);
234+
235+
handler.addHeaderFilter(name -> name.equals(headerA));
236+
handler.addHeaderFilter(name -> name.equals(headerB));
237+
238+
handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);
239+
240+
ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class);
241+
verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
242+
243+
MessageHeaders sentHeaders = captor.getValue();
244+
assertThat(sentHeaders.get(headerA)).isEqualTo("A-value");
245+
assertThat(sentHeaders.get(headerB)).isEqualTo("B-value");
246+
}
216247

217248

218249
private Message<?> createInputMessage(String sessId, String subsId, String dest, Principal principal) {

0 commit comments

Comments
 (0)