Skip to content

Commit 5eda90b

Browse files
artembilangaryrussell
authored andcommitted
Fix ThreadSPropagationChInterceptor for stacking (#8735)
* Fix ThreadSPropagationChInterceptor for stacking Related SO thread: https://stackoverflow.com/questions/77058188/multiple-threadstatepropagationchannelinterceptors-not-possible The current `ThreadStatePropagationChannelInterceptor` logic is to wrap one message to another (`MessageWithThreadState`), essentially stacking contexts. The `postReceive()` logic is to unwrap a `MessageWithThreadState`, therefore we deal with the latest pushed context which leads to the `ClassCastException` * Rework `ThreadStatePropagationChannelInterceptor` logic to reuse existing `MessageWithThreadState` and add the current context to its `stateQueue`. Therefore, the `postReceive()` will `poll()` the oldest context which is, essentially, the one populated by this interceptor before, according to the interceptors order * Fix `AbstractMessageChannel.setInterceptors()` to not modify provided list of interceptors * The new `ThreadStatePropagationChannelInterceptorTests` demonstrates the problem described in that mentioned SO question and verifies that context are propagated in the order they have been populated **Cherry-pick to `6.1.x` & `6.0.x`** * * Fix `ThreadStatePropagationChannelInterceptor` for publish-subscribe scenario. Essentially, copy the state queue to a new decorated message * Fix `BroadcastingDispatcher` to always decorate message, even if not `applySequence` * * Fix unused import in the `BroadcastingDispatcher` * * Fix unused import in the `ThreadStatePropagationChannelInterceptor`
1 parent 945c842 commit 5eda90b

File tree

4 files changed

+145
-38
lines changed

4 files changed

+145
-38
lines changed

spring-integration-core/src/main/java/org/springframework/integration/channel/AbstractMessageChannel.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.integration.channel;
1818

1919
import java.util.ArrayDeque;
20+
import java.util.ArrayList;
2021
import java.util.Arrays;
2122
import java.util.Collections;
2223
import java.util.Comparator;
@@ -161,8 +162,9 @@ public void setDatatypes(Class<?>... datatypes) {
161162
*/
162163
@Override
163164
public void setInterceptors(List<ChannelInterceptor> interceptors) {
164-
interceptors.sort(this.orderComparator);
165-
this.interceptors.set(interceptors);
165+
List<ChannelInterceptor> interceptorsToUse = new ArrayList<>(interceptors);
166+
interceptorsToUse.sort(this.orderComparator);
167+
this.interceptors.set(interceptorsToUse);
166168
}
167169

168170
/**

spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java

+29-14
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
package org.springframework.integration.channel.interceptor;
1818

19+
import java.util.LinkedList;
20+
import java.util.Queue;
21+
1922
import io.micrometer.common.lang.Nullable;
2023

2124
import org.springframework.integration.support.MessageDecorator;
@@ -58,20 +61,27 @@ public abstract class ThreadStatePropagationChannelInterceptor<S> implements Exe
5861
public final Message<?> preSend(Message<?> message, MessageChannel channel) {
5962
S threadContext = obtainPropagatingContext(message, channel);
6063
if (threadContext != null) {
61-
return new MessageWithThreadState<>(message, threadContext);
62-
}
63-
else {
64-
return message;
64+
if (message instanceof MessageWithThreadState messageWithThreadState) {
65+
messageWithThreadState.stateQueue.add(threadContext);
66+
}
67+
else {
68+
return new MessageWithThreadState(message, threadContext);
69+
}
6570
}
71+
72+
return message;
6673
}
6774

6875
@Override
6976
@SuppressWarnings("unchecked")
7077
public final Message<?> postReceive(Message<?> message, MessageChannel channel) {
71-
if (message instanceof MessageWithThreadState) {
72-
MessageWithThreadState<S> messageWithThreadState = (MessageWithThreadState<S>) message;
73-
Message<?> messageToHandle = messageWithThreadState.message;
74-
populatePropagatedContext(messageWithThreadState.state, messageToHandle, channel);
78+
if (message instanceof MessageWithThreadState messageWithThreadState) {
79+
Object threadContext = messageWithThreadState.stateQueue.poll();
80+
Message<?> messageToHandle = messageWithThreadState;
81+
if (messageWithThreadState.stateQueue.isEmpty()) {
82+
messageToHandle = messageWithThreadState.message;
83+
}
84+
populatePropagatedContext((S) threadContext, messageToHandle, channel);
7585
return messageToHandle;
7686
}
7787
return message;
@@ -88,16 +98,21 @@ public final Message<?> beforeHandle(Message<?> message, MessageChannel channel,
8898
protected abstract void populatePropagatedContext(@Nullable S state, Message<?> message, MessageChannel channel);
8999

90100

91-
private static final class MessageWithThreadState<S> implements Message<Object>, MessageDecorator {
101+
private static final class MessageWithThreadState implements Message<Object>, MessageDecorator {
92102

93103
private final Message<Object> message;
94104

95-
private final S state;
105+
private final Queue<Object> stateQueue;
106+
107+
MessageWithThreadState(Message<?> message, Object state) {
108+
this(message, new LinkedList<>());
109+
this.stateQueue.add(state);
110+
}
96111

97112
@SuppressWarnings("unchecked")
98-
MessageWithThreadState(Message<?> message, S state) {
113+
private MessageWithThreadState(Message<?> message, Queue<Object> stateQueue) {
99114
this.message = (Message<Object>) message;
100-
this.state = state;
115+
this.stateQueue = new LinkedList<>(stateQueue);
101116
}
102117

103118
@Override
@@ -112,14 +127,14 @@ public MessageHeaders getHeaders() {
112127

113128
@Override
114129
public Message<?> decorateMessage(Message<?> message) {
115-
return new MessageWithThreadState<>(message, this.state);
130+
return new MessageWithThreadState(message, this.stateQueue);
116131
}
117132

118133
@Override
119134
public String toString() {
120135
return "MessageWithThreadState{" +
121136
"message=" + this.message +
122-
", state=" + this.state +
137+
", state=" + this.stateQueue +
123138
'}';
124139
}
125140

spring-integration-core/src/main/java/org/springframework/integration/dispatcher/BroadcastingDispatcher.java

+17-22
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
1717
package org.springframework.integration.dispatcher;
1818

1919
import java.util.Collection;
20-
import java.util.UUID;
2120
import java.util.concurrent.Executor;
2221

2322
import org.springframework.beans.BeansException;
@@ -57,13 +56,13 @@ public class BroadcastingDispatcher extends AbstractDispatcher implements BeanFa
5756

5857
private final boolean requireSubscribers;
5958

60-
private volatile boolean ignoreFailures;
59+
private final Executor executor;
6160

62-
private volatile boolean applySequence;
61+
private boolean ignoreFailures;
6362

64-
private final Executor executor;
63+
private boolean applySequence;
6564

66-
private volatile int minSubscribers;
65+
private int minSubscribers;
6766

6867
private MessageHandlingTaskDecorator messageHandlingTaskDecorator = task -> task;
6968

@@ -149,24 +148,20 @@ public boolean dispatch(Message<?> message) {
149148
int dispatched = 0;
150149
int sequenceNumber = 1;
151150
Collection<MessageHandler> handlers = this.getHandlers();
152-
if (this.requireSubscribers && handlers.size() == 0) {
151+
if (this.requireSubscribers && handlers.isEmpty()) {
153152
throw new MessageDispatchingException(message, "Dispatcher has no subscribers");
154153
}
155154
int sequenceSize = handlers.size();
156155
Message<?> messageToSend = message;
157-
UUID sequenceId = null;
158-
if (this.applySequence) {
159-
sequenceId = message.getHeaders().getId();
160-
}
161156
for (MessageHandler handler : handlers) {
162157
if (this.applySequence) {
163158
messageToSend = getMessageBuilderFactory()
164159
.fromMessage(message)
165-
.pushSequenceDetails(sequenceId, sequenceNumber++, sequenceSize)
160+
.pushSequenceDetails(message.getHeaders().getId(), sequenceNumber++, sequenceSize)
166161
.build();
167-
if (message instanceof MessageDecorator) {
168-
messageToSend = ((MessageDecorator) message).decorateMessage(messageToSend);
169-
}
162+
}
163+
if (message instanceof MessageDecorator messageDecorator) {
164+
messageToSend = messageDecorator.decorateMessage(messageToSend);
170165
}
171166

172167
if (this.executor != null) {
@@ -175,7 +170,7 @@ public boolean dispatch(Message<?> message) {
175170
dispatched++;
176171
}
177172
else {
178-
if (this.invokeHandler(handler, messageToSend)) {
173+
if (invokeHandler(handler, messageToSend)) {
179174
dispatched++;
180175
}
181176
}
@@ -222,15 +217,15 @@ private boolean invokeHandler(MessageHandler handler, Message<?> message) {
222217
handler.handleMessage(message);
223218
return true;
224219
}
225-
catch (RuntimeException e) {
220+
catch (RuntimeException ex) {
226221
if (!this.ignoreFailures) {
227-
if (e instanceof MessagingException && ((MessagingException) e).getFailedMessage() == null) { // NOSONAR
228-
throw new MessagingException(message, "Failed to handle Message", e);
222+
if (ex instanceof MessagingException exception && exception.getFailedMessage() == null) { // NOSONAR
223+
throw new MessagingException(message, "Failed to handle Message", ex);
229224
}
230-
throw e;
225+
throw ex;
231226
}
232-
else if (this.logger.isWarnEnabled()) {
233-
logger.warn("Suppressing Exception since 'ignoreFailures' is set to TRUE.", e);
227+
else {
228+
logger.warn("Suppressing Exception since 'ignoreFailures' is set to TRUE.", ex);
234229
}
235230
return false;
236231
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright 2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.integration.channel.interceptor;
18+
19+
import java.util.ArrayList;
20+
import java.util.List;
21+
22+
import org.junit.jupiter.api.Test;
23+
24+
import org.springframework.core.task.SyncTaskExecutor;
25+
import org.springframework.integration.channel.ExecutorChannel;
26+
import org.springframework.integration.util.ErrorHandlingTaskExecutor;
27+
import org.springframework.messaging.Message;
28+
import org.springframework.messaging.MessageChannel;
29+
import org.springframework.messaging.support.GenericMessage;
30+
import org.springframework.util.ReflectionUtils;
31+
32+
import static org.assertj.core.api.Assertions.assertThat;
33+
import static org.mockito.Mockito.mock;
34+
35+
/**
36+
* @author Artem Bilan
37+
*
38+
* @since 6.2
39+
*/
40+
public class ThreadStatePropagationChannelInterceptorTests {
41+
42+
@Test
43+
void ThreadStatePropagationChannelInterceptorsCanBeStacked() {
44+
TestContext1 ctx1 = new TestContext1();
45+
TestContext2 ctx2 = new TestContext2();
46+
47+
List<Object> propagatedContexts = new ArrayList<>();
48+
49+
var interceptor1 = new ThreadStatePropagationChannelInterceptor<TestContext1>() {
50+
@Override
51+
protected TestContext1 obtainPropagatingContext(Message<?> message, MessageChannel channel) {
52+
return ctx1;
53+
}
54+
55+
@Override
56+
protected void populatePropagatedContext(TestContext1 state, Message<?> message, MessageChannel channel) {
57+
propagatedContexts.add(state);
58+
}
59+
60+
};
61+
62+
var interceptor2 = new ThreadStatePropagationChannelInterceptor<TestContext2>() {
63+
@Override
64+
protected TestContext2 obtainPropagatingContext(Message<?> message, MessageChannel channel) {
65+
return ctx2;
66+
}
67+
68+
@Override
69+
protected void populatePropagatedContext(TestContext2 state, Message<?> message, MessageChannel channel) {
70+
propagatedContexts.add(state);
71+
}
72+
73+
};
74+
75+
ExecutorChannel testChannel = new ExecutorChannel(
76+
new ErrorHandlingTaskExecutor(new SyncTaskExecutor(), ReflectionUtils::rethrowRuntimeException));
77+
testChannel.setInterceptors(List.of(interceptor1, interceptor2));
78+
testChannel.setBeanFactory(mock());
79+
testChannel.afterPropertiesSet();
80+
testChannel.subscribe(m -> {
81+
});
82+
83+
testChannel.send(new GenericMessage<>("test data"));
84+
85+
assertThat(propagatedContexts.get(0)).isEqualTo(ctx1);
86+
assertThat(propagatedContexts.get(1)).isEqualTo(ctx2);
87+
}
88+
89+
private record TestContext1() {
90+
}
91+
92+
private record TestContext2() {
93+
}
94+
95+
}

0 commit comments

Comments
 (0)