4545import org .springframework .util .PropertyPlaceholderHelper ;
4646import org .springframework .util .PropertyPlaceholderHelper .PlaceholderResolver ;
4747import org .springframework .util .StringUtils ;
48+ import java .util .function .Predicate ;
4849
4950/**
5051 * A {@link HandlerMethodReturnValueHandler} for sending to destinations specified in a
@@ -74,6 +75,13 @@ public class SendToMethodReturnValueHandler implements HandlerMethodReturnValueH
7475 private @ Nullable MessageHeaderInitializer headerInitializer ;
7576
7677
78+ /**
79+ * Predicate to determine which header names from the input message should be propagated.
80+ * If null, no headers are propagated (default behavior).
81+ */
82+ private @ Nullable Predicate <String > headerPropagationPredicate ;
83+
84+
7785 public SendToMethodReturnValueHandler (SimpMessageSendingOperations messagingTemplate , boolean annotationRequired ) {
7886 Assert .notNull (messagingTemplate , "'messagingTemplate' must not be null" );
7987 this .messagingTemplate = messagingTemplate ;
@@ -133,6 +141,24 @@ public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitia
133141 return this .headerInitializer ;
134142 }
135143
144+ /**
145+ * Set a predicate to filter which headers from the input message should be propagated to the output message.
146+ * <p><b>Warning:</b> The predicate should avoid propagating or overwriting well-known protocol headers
147+ * (e.g. headers starting with "simp", "content-type", etc.) to prevent breaking internal messaging semantics.
148+ * </p>
149+ * <p>If not set, no input headers are propagated (default behavior).</p>
150+ */
151+ public void setHeaderPropagationPredicate (@ Nullable Predicate <String > predicate ) {
152+ this .headerPropagationPredicate = predicate ;
153+ }
154+
155+ /**
156+ * Return the configured header propagation predicate.
157+ */
158+ public @ Nullable Predicate <String > getHeaderPropagationPredicate () {
159+ return this .headerPropagationPredicate ;
160+ }
161+
136162
137163 @ Override
138164 public boolean supportsReturnType (MethodParameter returnType ) {
@@ -171,11 +197,11 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu
171197 destination = destinationHelper .expandTemplateVars (destination );
172198 if (broadcast ) {
173199 this .messagingTemplate .convertAndSendToUser (
174- user , destination , returnValue , createHeaders (null , returnType ));
200+ user , destination , returnValue , createHeaders (null , returnType , message ));
175201 }
176202 else {
177203 this .messagingTemplate .convertAndSendToUser (
178- user , destination , returnValue , createHeaders (sessionId , returnType ));
204+ user , destination , returnValue , createHeaders (sessionId , returnType , message ));
179205 }
180206 }
181207 }
@@ -185,7 +211,7 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu
185211 String [] destinations = getTargetDestinations (sendTo , message , this .defaultDestinationPrefix );
186212 for (String destination : destinations ) {
187213 destination = destinationHelper .expandTemplateVars (destination );
188- this .messagingTemplate .convertAndSend (destination , returnValue , createHeaders (sessionId , returnType ));
214+ this .messagingTemplate .convertAndSend (destination , returnValue , createHeaders (sessionId , returnType , message ));
189215 }
190216 }
191217 }
@@ -234,7 +260,7 @@ protected String[] getTargetDestinations(@Nullable Annotation annotation, Messag
234260 new String [] {defaultPrefix + destination } : new String [] {defaultPrefix + '/' + destination });
235261 }
236262
237- private MessageHeaders createHeaders (@ Nullable String sessionId , MethodParameter returnType ) {
263+ private MessageHeaders createHeaders (@ Nullable String sessionId , MethodParameter returnType , @ Nullable Message <?> inputMessage ) {
238264 SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor .create (SimpMessageType .MESSAGE );
239265 if (getHeaderInitializer () != null ) {
240266 getHeaderInitializer ().initHeaders (headerAccessor );
@@ -243,6 +269,18 @@ private MessageHeaders createHeaders(@Nullable String sessionId, MethodParameter
243269 headerAccessor .setSessionId (sessionId );
244270 }
245271 headerAccessor .setHeader (AbstractMessageSendingTemplate .CONVERSION_HINT_HEADER , returnType );
272+
273+ // Header propagation policy
274+ if (inputMessage != null && headerPropagationPredicate != null ) {
275+ Map <String , Object > inputHeaders = inputMessage .getHeaders ();
276+ for (Map .Entry <String , Object > entry : inputHeaders .entrySet ()) {
277+ String name = entry .getKey ();
278+ if (headerPropagationPredicate .test (name )) {
279+ headerAccessor .setHeader (name , entry .getValue ());
280+ }
281+ }
282+ }
283+
246284 headerAccessor .setLeaveMutable (true );
247285 return headerAccessor .getMessageHeaders ();
248286 }
0 commit comments