Skip to content

Commit 0627204

Browse files
committed
Add header propagation predicate support to message return value handlers
Signed-off-by: 김준환 <musoyou1085@gmail.com>
1 parent 7917ae5 commit 0627204

File tree

4 files changed

+189
-7
lines changed

4 files changed

+189
-7
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import org.springframework.util.PropertyPlaceholderHelper;
4646
import org.springframework.util.PropertyPlaceholderHelper.PlaceholderResolver;
4747
import org.springframework.util.StringUtils;
48+
import java.util.function.Predicate;
4849

4950
/**
5051
* A {@link HandlerMethodReturnValueHandler} for sending to destinations specified in a
@@ -73,6 +74,8 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
7374

7475
private @Nullable MessageHeaderInitializer headerInitializer;
7576

77+
private @Nullable Predicate<String> headerFilter;
78+
7679

7780
public SendToMethodReturnValueHandler(SimpMessageSendingOperations messagingTemplate, boolean annotationRequired) {
7881
Assert.notNull(messagingTemplate, "'messagingTemplate' must not be null");
@@ -133,6 +136,27 @@ public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitia
133136
return this.headerInitializer;
134137
}
135138

139+
/**
140+
* Add a filter to determine which headers from the input message should be propagated to the output message.
141+
* Multiple filters are combined with logical OR.
142+
* <p>If not set, no input headers are propagated (default behavior).</p>
143+
*/
144+
public void addHeaderFilter(Predicate<String> filter) {
145+
Assert.notNull(filter, "Filter predicate must not be null");
146+
if (this.headerFilter == null) {
147+
this.headerFilter = filter;
148+
} else {
149+
this.headerFilter = this.headerFilter.or(filter);
150+
}
151+
}
152+
153+
/**
154+
* Return the configured header filter.
155+
*/
156+
public @Nullable Predicate<String> getHeaderFilter() {
157+
return this.headerFilter;
158+
}
159+
136160

137161
@Override
138162
public boolean supportsReturnType(MethodParameter returnType) {
@@ -171,11 +195,11 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu
171195
destination = destinationHelper.expandTemplateVars(destination);
172196
if (broadcast) {
173197
this.messagingTemplate.convertAndSendToUser(
174-
user, destination, returnValue, createHeaders(null, returnType));
198+
user, destination, returnValue, createHeaders(null, returnType, message));
175199
}
176200
else {
177201
this.messagingTemplate.convertAndSendToUser(
178-
user, destination, returnValue, createHeaders(sessionId, returnType));
202+
user, destination, returnValue, createHeaders(sessionId, returnType, message));
179203
}
180204
}
181205
}
@@ -185,7 +209,7 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu
185209
String[] destinations = getTargetDestinations(sendTo, message, this.defaultDestinationPrefix);
186210
for (String destination : destinations) {
187211
destination = destinationHelper.expandTemplateVars(destination);
188-
this.messagingTemplate.convertAndSend(destination, returnValue, createHeaders(sessionId, returnType));
212+
this.messagingTemplate.convertAndSend(destination, returnValue, createHeaders(sessionId, returnType, message));
189213
}
190214
}
191215
}
@@ -234,11 +258,22 @@ protected String[] getTargetDestinations(@Nullable Annotation annotation, Messag
234258
new String[] {defaultPrefix + destination} : new String[] {defaultPrefix + '/' + destination});
235259
}
236260

237-
private MessageHeaders createHeaders(@Nullable String sessionId, MethodParameter returnType) {
261+
private MessageHeaders createHeaders(@Nullable String sessionId, MethodParameter returnType, @Nullable Message<?> inputMessage) {
238262
SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
239263
if (getHeaderInitializer() != null) {
240264
getHeaderInitializer().initHeaders(headerAccessor);
241265
}
266+
267+
if (inputMessage != null && headerFilter != null) {
268+
Map<String, Object> inputHeaders = inputMessage.getHeaders();
269+
for (Map.Entry<String, Object> entry : inputHeaders.entrySet()) {
270+
String name = entry.getKey();
271+
if (headerFilter.test(name)) {
272+
headerAccessor.setHeader(name, entry.getValue());
273+
}
274+
}
275+
}
276+
242277
if (sessionId != null) {
243278
headerAccessor.setSessionId(sessionId);
244279
}

spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
package org.springframework.messaging.simp.annotation.support;
1818

19+
import java.util.Map;
20+
import java.util.function.Predicate;
21+
1922
import org.apache.commons.logging.Log;
2023
import org.jspecify.annotations.Nullable;
2124

@@ -65,6 +68,8 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
6568

6669
private @Nullable MessageHeaderInitializer headerInitializer;
6770

71+
private @Nullable Predicate<String> headerFilter;
72+
6873

6974
/**
7075
* Construct a new SubscriptionMethodReturnValueHandler.
@@ -93,6 +98,27 @@ public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitia
9398
return this.headerInitializer;
9499
}
95100

101+
/**
102+
* Add a filter to determine which headers from the input message should be propagated to the output message.
103+
* Multiple filters are combined with logical OR.
104+
* <p>If not set, no input headers are propagated (default behavior).</p>
105+
*/
106+
public void addHeaderFilter(Predicate<String> filter) {
107+
Assert.notNull(filter, "Filter predicate must not be null");
108+
if (this.headerFilter == null) {
109+
this.headerFilter = filter;
110+
} else {
111+
this.headerFilter = this.headerFilter.or(filter);
112+
}
113+
}
114+
115+
/**
116+
* Return the configured header filter.
117+
*/
118+
public @Nullable Predicate<String> getHeaderFilter() {
119+
return this.headerFilter;
120+
}
121+
96122

97123
@Override
98124
public boolean supportsReturnType(MethodParameter returnType) {
@@ -126,15 +152,26 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu
126152
if (logger.isDebugEnabled()) {
127153
logger.debug("Reply to @SubscribeMapping: " + returnValue);
128154
}
129-
MessageHeaders headersToSend = createHeaders(sessionId, subscriptionId, returnType);
155+
MessageHeaders headersToSend = createHeaders(sessionId, subscriptionId, returnType, message);
130156
this.messagingTemplate.convertAndSend(destination, returnValue, headersToSend);
131157
}
132158

133-
private MessageHeaders createHeaders(@Nullable String sessionId, String subscriptionId, MethodParameter returnType) {
159+
private MessageHeaders createHeaders(@Nullable String sessionId, String subscriptionId, MethodParameter returnType, @Nullable Message<?> inputMessage) {
134160
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
135161
if (getHeaderInitializer() != null) {
136162
getHeaderInitializer().initHeaders(accessor);
137163
}
164+
165+
if (inputMessage != null && headerFilter != null) {
166+
Map<String, Object> inputHeaders = inputMessage.getHeaders();
167+
for (Map.Entry<String, Object> entry : inputHeaders.entrySet()) {
168+
String name = entry.getKey();
169+
if (headerFilter.test(name)) {
170+
accessor.setHeader(name, entry.getValue());
171+
}
172+
}
173+
}
174+
138175
if (sessionId != null) {
139176
accessor.setSessionId(sessionId);
140177
}

spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,60 @@ public void sendToUserWithSendToOverride() throws Exception {
293293
assertResponse(parameter, sessionId, 1, "/dest4");
294294
}
295295

296+
@Test
297+
void sendToWithHeaderFilterSinglePredicate() throws Exception {
298+
given(this.messageChannel.send(any(Message.class))).willReturn(true);
299+
300+
String sessionId = "sess1";
301+
String customHeaderName = "x-custom-header";
302+
String customHeaderValue = "custom-value";
303+
Message<?> inputMessage = createMessage(sessionId, "sub1", null, null, null);
304+
inputMessage = MessageBuilder.fromMessage(inputMessage)
305+
.setHeader(customHeaderName, customHeaderValue)
306+
.build();
307+
308+
SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true);
309+
handler.addHeaderFilter(name -> name.equals(customHeaderName));
310+
311+
handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage);
312+
313+
verify(this.messageChannel, times(2)).send(this.messageCaptor.capture());
314+
for (Message<?> sent : this.messageCaptor.getAllValues()) {
315+
MessageHeaders headers = sent.getHeaders();
316+
assertThat(headers.get(customHeaderName)).isEqualTo(customHeaderValue);
317+
}
318+
}
319+
320+
@Test
321+
void sendToWithHeaderFilterMultiplePredicates() throws Exception {
322+
given(this.messageChannel.send(any(Message.class))).willReturn(true);
323+
324+
String sessionId = "sess1";
325+
String headerA = "x-header-a";
326+
String headerB = "x-header-b";
327+
Message<?> inputMessage = createMessage(sessionId, "sub1", null, null, null);
328+
inputMessage = MessageBuilder.fromMessage(inputMessage)
329+
.setHeader(headerA, "A-value")
330+
.setHeader(headerB, "B-value")
331+
.build();
332+
333+
SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true);
334+
handler.addHeaderFilter(name -> name.equals(headerA));
335+
handler.addHeaderFilter(name -> name.equals(headerB));
336+
337+
handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage);
338+
339+
verify(this.messageChannel, times(2)).send(this.messageCaptor.capture());
340+
for (Message<?> sent : this.messageCaptor.getAllValues()) {
341+
MessageHeaders headers = sent.getHeaders();
342+
assertThat(headers.get(headerA)).isEqualTo("A-value");
343+
assertThat(headers.get(headerB)).isEqualTo("B-value");
344+
}
345+
}
346+
296347

297348
private void assertResponse(MethodParameter methodParameter, String sessionId,
298-
int index, String destination) {
349+
int index, String destination) {
299350

300351
SimpMessageHeaderAccessor accessor = getCapturedAccessor(index);
301352
assertThat(accessor.getSessionId()).isEqualTo(sessionId);

spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,65 @@ void testJsonView() throws Exception {
186186
assertThat(new String((byte[]) message.getPayload(), StandardCharsets.UTF_8)).isEqualTo("{\"withView1\":\"with\"}");
187187
}
188188

189+
@Test
190+
void testHeaderFilterSinglePredicate() throws Exception {
191+
String sessionId = "sess1";
192+
String subscriptionId = "subs1";
193+
String destination = "/dest";
194+
String customHeaderName = "x-custom-header";
195+
String customHeaderValue = "custom-value";
196+
Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD)
197+
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
198+
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
199+
.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, destination)
200+
.setHeader(customHeaderName, customHeaderValue)
201+
.build();
202+
203+
MessageSendingOperations messagingTemplate = mock();
204+
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate);
205+
206+
handler.addHeaderFilter(name -> name.equals(customHeaderName));
207+
208+
handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);
209+
210+
ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class);
211+
verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
212+
213+
MessageHeaders sentHeaders = captor.getValue();
214+
assertThat(sentHeaders.get(customHeaderName)).isEqualTo(customHeaderValue);
215+
}
216+
217+
@Test
218+
void testHeaderFilterMultiplePredicates() throws Exception {
219+
String sessionId = "sess1";
220+
String subscriptionId = "subs1";
221+
String destination = "/dest";
222+
String headerA = "x-header-a";
223+
String headerB = "x-header-b";
224+
Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD)
225+
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
226+
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
227+
.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, destination)
228+
.setHeader(headerA, "A-value")
229+
.setHeader(headerB, "B-value")
230+
.build();
231+
232+
MessageSendingOperations messagingTemplate = mock();
233+
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate);
234+
235+
handler.addHeaderFilter(name -> name.equals(headerA));
236+
handler.addHeaderFilter(name -> name.equals(headerB));
237+
238+
handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);
239+
240+
ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class);
241+
verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
242+
243+
MessageHeaders sentHeaders = captor.getValue();
244+
assertThat(sentHeaders.get(headerA)).isEqualTo("A-value");
245+
assertThat(sentHeaders.get(headerB)).isEqualTo("B-value");
246+
}
247+
189248

190249
private Message<?> createInputMessage(String sessId, String subsId, String dest, Principal principal) {
191250
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create();

0 commit comments

Comments
 (0)