Skip to content

Commit ad5e26d

Browse files
committed
Add header propagation predicate support to message return value handlers
Signed-off-by: 김준환 <musoyou1085@gmail.com>
1 parent 7917ae5 commit ad5e26d

File tree

4 files changed

+134
-6
lines changed

4 files changed

+134
-6
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandler.java

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import org.springframework.util.PropertyPlaceholderHelper;
4646
import org.springframework.util.PropertyPlaceholderHelper.PlaceholderResolver;
4747
import org.springframework.util.StringUtils;
48+
import java.util.function.Predicate;
4849

4950
/**
5051
* A {@link HandlerMethodReturnValueHandler} for sending to destinations specified in a
@@ -74,6 +75,13 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
7475
private @Nullable MessageHeaderInitializer headerInitializer;
7576

7677

78+
/**
79+
* Predicate to determine which header names from the input message should be propagated.
80+
* If null, no headers are propagated (default behavior).
81+
*/
82+
private @Nullable Predicate<String> headerPropagationPredicate;
83+
84+
7785
public SendToMethodReturnValueHandler(SimpMessageSendingOperations messagingTemplate, boolean annotationRequired) {
7886
Assert.notNull(messagingTemplate, "'messagingTemplate' must not be null");
7987
this.messagingTemplate = messagingTemplate;
@@ -133,6 +141,24 @@ public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitia
133141
return this.headerInitializer;
134142
}
135143

144+
/**
145+
* Set a predicate to filter which headers from the input message should be propagated to the output message.
146+
* <p><b>Warning:</b> The predicate should avoid propagating or overwriting well-known protocol headers
147+
* (e.g. headers starting with "simp", "content-type", etc.) to prevent breaking internal messaging semantics.
148+
* </p>
149+
* <p>If not set, no input headers are propagated (default behavior).</p>
150+
*/
151+
public void setHeaderPropagationPredicate(@Nullable Predicate<String> predicate) {
152+
this.headerPropagationPredicate = predicate;
153+
}
154+
155+
/**
156+
* Return the configured header propagation predicate.
157+
*/
158+
public @Nullable Predicate<String> getHeaderPropagationPredicate() {
159+
return this.headerPropagationPredicate;
160+
}
161+
136162

137163
@Override
138164
public boolean supportsReturnType(MethodParameter returnType) {
@@ -171,11 +197,11 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu
171197
destination = destinationHelper.expandTemplateVars(destination);
172198
if (broadcast) {
173199
this.messagingTemplate.convertAndSendToUser(
174-
user, destination, returnValue, createHeaders(null, returnType));
200+
user, destination, returnValue, createHeaders(null, returnType, message));
175201
}
176202
else {
177203
this.messagingTemplate.convertAndSendToUser(
178-
user, destination, returnValue, createHeaders(sessionId, returnType));
204+
user, destination, returnValue, createHeaders(sessionId, returnType, message));
179205
}
180206
}
181207
}
@@ -185,7 +211,7 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu
185211
String[] destinations = getTargetDestinations(sendTo, message, this.defaultDestinationPrefix);
186212
for (String destination : destinations) {
187213
destination = destinationHelper.expandTemplateVars(destination);
188-
this.messagingTemplate.convertAndSend(destination, returnValue, createHeaders(sessionId, returnType));
214+
this.messagingTemplate.convertAndSend(destination, returnValue, createHeaders(sessionId, returnType, message));
189215
}
190216
}
191217
}
@@ -234,7 +260,7 @@ protected String[] getTargetDestinations(@Nullable Annotation annotation, Messag
234260
new String[] {defaultPrefix + destination} : new String[] {defaultPrefix + '/' + destination});
235261
}
236262

237-
private MessageHeaders createHeaders(@Nullable String sessionId, MethodParameter returnType) {
263+
private MessageHeaders createHeaders(@Nullable String sessionId, MethodParameter returnType, @Nullable Message<?> inputMessage) {
238264
SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
239265
if (getHeaderInitializer() != null) {
240266
getHeaderInitializer().initHeaders(headerAccessor);
@@ -243,6 +269,18 @@ private MessageHeaders createHeaders(@Nullable String sessionId, MethodParameter
243269
headerAccessor.setSessionId(sessionId);
244270
}
245271
headerAccessor.setHeader(AbstractMessageSendingTemplate.CONVERSION_HINT_HEADER, returnType);
272+
273+
// Header propagation policy
274+
if (inputMessage != null && headerPropagationPredicate != null) {
275+
Map<String, Object> inputHeaders = inputMessage.getHeaders();
276+
for (Map.Entry<String, Object> entry : inputHeaders.entrySet()) {
277+
String name = entry.getKey();
278+
if (headerPropagationPredicate.test(name)) {
279+
headerAccessor.setHeader(name, entry.getValue());
280+
}
281+
}
282+
}
283+
246284
headerAccessor.setLeaveMutable(true);
247285
return headerAccessor.getMessageHeaders();
248286
}

spring-messaging/src/main/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandler.java

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
package org.springframework.messaging.simp.annotation.support;
1818

19+
import java.util.Map;
20+
import java.util.function.Predicate;
21+
1922
import org.apache.commons.logging.Log;
2023
import org.jspecify.annotations.Nullable;
2124

@@ -65,6 +68,12 @@ public class SubscriptionMethodReturnValueHandler implements HandlerMethodReturn
6568

6669
private @Nullable MessageHeaderInitializer headerInitializer;
6770

71+
/**
72+
* Predicate to determine which header names from the input message should be propagated.
73+
* If null, no headers are propagated (default behavior).
74+
*/
75+
private @Nullable Predicate<String> headerPropagationPredicate;
76+
6877

6978
/**
7079
* Construct a new SubscriptionMethodReturnValueHandler.
@@ -93,6 +102,24 @@ public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitia
93102
return this.headerInitializer;
94103
}
95104

105+
/**
106+
* Set a predicate to filter which headers from the input message should be propagated to the output message.
107+
* <p><b>Warning:</b> The predicate should avoid propagating or overwriting well-known protocol headers
108+
* (e.g. headers starting with "simp", "content-type", etc.) to prevent breaking internal messaging semantics.
109+
* </p>
110+
* <p>If not set, no input headers are propagated (default behavior).</p>
111+
*/
112+
public void setHeaderPropagationPredicate(@Nullable Predicate<String> predicate) {
113+
this.headerPropagationPredicate = predicate;
114+
}
115+
116+
/**
117+
* Return the configured header propagation predicate.
118+
*/
119+
public @Nullable Predicate<String> getHeaderPropagationPredicate() {
120+
return this.headerPropagationPredicate;
121+
}
122+
96123

97124
@Override
98125
public boolean supportsReturnType(MethodParameter returnType) {
@@ -126,11 +153,11 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu
126153
if (logger.isDebugEnabled()) {
127154
logger.debug("Reply to @SubscribeMapping: " + returnValue);
128155
}
129-
MessageHeaders headersToSend = createHeaders(sessionId, subscriptionId, returnType);
156+
MessageHeaders headersToSend = createHeaders(sessionId, subscriptionId, returnType, message);
130157
this.messagingTemplate.convertAndSend(destination, returnValue, headersToSend);
131158
}
132159

133-
private MessageHeaders createHeaders(@Nullable String sessionId, String subscriptionId, MethodParameter returnType) {
160+
private MessageHeaders createHeaders(@Nullable String sessionId, String subscriptionId, MethodParameter returnType, @Nullable Message<?> inputMessage) {
134161
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
135162
if (getHeaderInitializer() != null) {
136163
getHeaderInitializer().initHeaders(accessor);
@@ -140,6 +167,17 @@ private MessageHeaders createHeaders(@Nullable String sessionId, String subscrip
140167
}
141168
accessor.setSubscriptionId(subscriptionId);
142169
accessor.setHeader(AbstractMessageSendingTemplate.CONVERSION_HINT_HEADER, returnType);
170+
171+
if (inputMessage != null && headerPropagationPredicate != null) {
172+
Map<String, Object> inputHeaders = inputMessage.getHeaders();
173+
for (Map.Entry<String, Object> entry : inputHeaders.entrySet()) {
174+
String name = entry.getKey();
175+
if (headerPropagationPredicate.test(name)) {
176+
accessor.setHeader(name, entry.getValue());
177+
}
178+
}
179+
}
180+
143181
accessor.setLeaveMutable(true);
144182
return accessor.getMessageHeaders();
145183
}

spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SendToMethodReturnValueHandlerTests.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,30 @@ public void sendToUserWithSendToOverride() throws Exception {
293293
assertResponse(parameter, sessionId, 1, "/dest4");
294294
}
295295

296+
@Test
297+
void sendToWithHeaderPropagationPredicate() throws Exception {
298+
given(this.messageChannel.send(any(Message.class))).willReturn(true);
299+
300+
String sessionId = "sess1";
301+
String customHeaderName = "x-custom-header";
302+
String customHeaderValue = "custom-value";
303+
Message<?> inputMessage = createMessage(sessionId, "sub1", null, null, null);
304+
inputMessage = MessageBuilder.fromMessage(inputMessage)
305+
.setHeader(customHeaderName, customHeaderValue)
306+
.build();
307+
308+
SendToMethodReturnValueHandler handler = new SendToMethodReturnValueHandler(new SimpMessagingTemplate(this.messageChannel), true);
309+
handler.setHeaderPropagationPredicate(name -> name.equals(customHeaderName));
310+
311+
handler.handleReturnValue(PAYLOAD, this.sendToReturnType, inputMessage);
312+
313+
verify(this.messageChannel, times(2)).send(this.messageCaptor.capture());
314+
for (Message<?> sent : this.messageCaptor.getAllValues()) {
315+
MessageHeaders headers = sent.getHeaders();
316+
assertThat(headers.get(customHeaderName)).isEqualTo(customHeaderValue);
317+
}
318+
}
319+
296320

297321
private void assertResponse(MethodParameter methodParameter, String sessionId,
298322
int index, String destination) {

spring-messaging/src/test/java/org/springframework/messaging/simp/annotation/support/SubscriptionMethodReturnValueHandlerTests.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,34 @@ void testJsonView() throws Exception {
186186
assertThat(new String((byte[]) message.getPayload(), StandardCharsets.UTF_8)).isEqualTo("{\"withView1\":\"with\"}");
187187
}
188188

189+
@Test
190+
void testHeaderPropagationPredicate() throws Exception {
191+
String sessionId = "sess1";
192+
String subscriptionId = "subs1";
193+
String destination = "/dest";
194+
String customHeaderName = "x-custom-header";
195+
String customHeaderValue = "custom-value";
196+
Message<?> inputMessage = MessageBuilder.withPayload(PAYLOAD)
197+
.setHeader(SimpMessageHeaderAccessor.SESSION_ID_HEADER, sessionId)
198+
.setHeader(SimpMessageHeaderAccessor.SUBSCRIPTION_ID_HEADER, subscriptionId)
199+
.setHeader(SimpMessageHeaderAccessor.DESTINATION_HEADER, destination)
200+
.setHeader(customHeaderName, customHeaderValue)
201+
.build();
202+
203+
MessageSendingOperations messagingTemplate = mock();
204+
SubscriptionMethodReturnValueHandler handler = new SubscriptionMethodReturnValueHandler(messagingTemplate);
205+
206+
handler.setHeaderPropagationPredicate(name -> name.equals(customHeaderName));
207+
208+
handler.handleReturnValue(PAYLOAD, this.subscribeEventReturnType, inputMessage);
209+
210+
ArgumentCaptor<MessageHeaders> captor = ArgumentCaptor.forClass(MessageHeaders.class);
211+
verify(messagingTemplate).convertAndSend(eq(destination), eq(PAYLOAD), captor.capture());
212+
213+
MessageHeaders sentHeaders = captor.getValue();
214+
assertThat(sentHeaders.get(customHeaderName)).isEqualTo(customHeaderValue);
215+
}
216+
189217

190218
private Message<?> createInputMessage(String sessId, String subsId, String dest, Principal principal) {
191219
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create();

0 commit comments

Comments
 (0)