Skip to content

Commit d79c06a

Browse files
artembilangaryrussell
authored andcommitted
GH-2967: Fix ScatterGatherH for headers copy (#2968)
* GH-2967: Fix ScatterGatherH for headers copy Fixes #2967 The `ChannelInterceptor` is added into the `this.gatherChannel` on each request message making a subsequent requests for scatter-gather as halting on reply. * Add an interceptor into an injected `this.gatherChannel` only once during `ScatterGatherHandler` initialization * Introduce `ORIGINAL_REPLY_CHANNEL` and `ORIGINAL_ERROR_CHANNEL` headers to carry a request reply and error channels from headers * Populate `REPLY_CHANNEL` and `ERROR_CHANNEL` headers back before sending scattering replies into gatherer * Transfer a `GATHER_RESULT_CHANNEL` header now directly from the scatter message to make it available in the reply from the gatherer * Add note about those headers in the `scatter-gather.adoc` * Modify `ScatterGatherTests` to be sure that `ScatterGatherHandler` works for several requests **Cherry-pick to 5.1.x** * * Fix language in doc
1 parent 36c14ef commit d79c06a

File tree

4 files changed

+79
-56
lines changed

4 files changed

+79
-56
lines changed

spring-integration-core/src/main/java/org/springframework/integration/scattergather/ScatterGatherHandler.java

+37-37
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ public class ScatterGatherHandler extends AbstractReplyProducingMessageHandler i
5555

5656
private static final String GATHER_RESULT_CHANNEL = "gatherResultChannel";
5757

58+
private static final String ORIGINAL_REPLY_CHANNEL = "originalReplyChannel";
59+
60+
private static final String ORIGINAL_ERROR_CHANNEL = "originalErrorChannel";
61+
5862
private final MessageChannel scatterChannel;
5963

6064
private final MessageHandler gatherer;
@@ -107,9 +111,24 @@ public void setErrorChannelName(String errorChannelName) {
107111
protected void doInit() {
108112
BeanFactory beanFactory = getBeanFactory();
109113
if (this.gatherChannel == null) {
110-
this.gatherChannel = new FixedSubscriberChannel(this.gatherer);
114+
this.gatherChannel =
115+
new FixedSubscriberChannel((message) ->
116+
this.gatherer.handleMessage(enhanceScatterReplyMessage(message)));
111117
}
112118
else {
119+
Assert.isInstanceOf(InterceptableChannel.class, this.gatherChannel,
120+
() -> "An injected 'gatherChannel' '" + this.gatherChannel +
121+
"' must be an 'InterceptableChannel' instance.");
122+
((InterceptableChannel) this.gatherChannel)
123+
.addInterceptor(0,
124+
new ChannelInterceptor() {
125+
126+
@Override
127+
public Message<?> preSend(Message<?> message, MessageChannel channel) {
128+
return enhanceScatterReplyMessage(message);
129+
}
130+
131+
});
113132
if (this.gatherChannel instanceof SubscribableChannel) {
114133
this.gatherEndpoint = new EventDrivenConsumer((SubscribableChannel) this.gatherChannel, this.gatherer);
115134
}
@@ -121,7 +140,7 @@ else if (this.gatherChannel instanceof ReactiveStreamsSubscribableChannel) {
121140
this.gatherEndpoint = new ReactiveStreamsConsumer(this.gatherChannel, this.gatherer);
122141
}
123142
else {
124-
throw new BeanInitializationException("Unsupported 'replyChannel' type '" +
143+
throw new BeanInitializationException("Unsupported 'gatherChannel' type '" +
125144
this.gatherChannel.getClass() + "'. " +
126145
"'SubscribableChannel', 'PollableChannel' or 'ReactiveStreamsSubscribableChannel' " +
127146
"types are supported.");
@@ -131,7 +150,7 @@ else if (this.gatherChannel instanceof ReactiveStreamsSubscribableChannel) {
131150
}
132151

133152
((MessageProducer) this.gatherer)
134-
.setOutputChannel(new FixedSubscriberChannel(message -> {
153+
.setOutputChannel(new FixedSubscriberChannel((message) -> {
135154
MessageHeaders headers = message.getHeaders();
136155
MessageChannel gatherResultChannel = headers.get(GATHER_RESULT_CHANNEL, MessageChannel.class);
137156
if (gatherResultChannel != null) {
@@ -144,35 +163,28 @@ else if (this.gatherChannel instanceof ReactiveStreamsSubscribableChannel) {
144163
}));
145164
}
146165

166+
private Message<?> enhanceScatterReplyMessage(Message<?> message) {
167+
MessageHeaders headers = message.getHeaders();
168+
return getMessageBuilderFactory()
169+
.fromMessage(message)
170+
.setHeader(MessageHeaders.REPLY_CHANNEL, headers.get(ORIGINAL_REPLY_CHANNEL))
171+
.setHeader(MessageHeaders.ERROR_CHANNEL, headers.get(ORIGINAL_ERROR_CHANNEL))
172+
.removeHeaders(ORIGINAL_REPLY_CHANNEL, ORIGINAL_ERROR_CHANNEL)
173+
.build();
174+
}
175+
147176
@Override
148177
protected Object handleRequestMessage(Message<?> requestMessage) {
178+
MessageHeaders requestMessageHeaders = requestMessage.getHeaders();
149179
PollableChannel gatherResultChannel = new QueueChannel();
150180

151-
MessageChannel replyChannel = this.gatherChannel;
152-
153-
if (replyChannel instanceof InterceptableChannel) {
154-
((InterceptableChannel) replyChannel)
155-
.addInterceptor(0,
156-
new ChannelInterceptor() {
157-
158-
@Override
159-
public Message<?> preSend(Message<?> message, MessageChannel channel) {
160-
return enhanceScatterReplyMessage(message, gatherResultChannel, requestMessage);
161-
}
162-
163-
});
164-
}
165-
else {
166-
replyChannel =
167-
new FixedSubscriberChannel(message ->
168-
this.messagingTemplate.send(this.gatherChannel,
169-
enhanceScatterReplyMessage(message, gatherResultChannel, requestMessage)));
170-
}
171-
172181
Message<?> scatterMessage =
173182
getMessageBuilderFactory()
174183
.fromMessage(requestMessage)
175-
.setReplyChannel(replyChannel)
184+
.setHeader(GATHER_RESULT_CHANNEL, gatherResultChannel)
185+
.setHeader(ORIGINAL_REPLY_CHANNEL, requestMessageHeaders.getReplyChannel())
186+
.setHeader(ORIGINAL_ERROR_CHANNEL, requestMessageHeaders.getErrorChannel())
187+
.setReplyChannel(this.gatherChannel)
176188
.setErrorChannelName(this.errorChannelName)
177189
.build();
178190

@@ -181,18 +193,6 @@ public Message<?> preSend(Message<?> message, MessageChannel channel) {
181193
return gatherResultChannel.receive(this.gatherTimeout);
182194
}
183195

184-
private Message<?> enhanceScatterReplyMessage(Message<?> message, PollableChannel gatherResultChannel,
185-
Message<?> requestMessage) {
186-
187-
MessageHeaders requestMessageHeaders = requestMessage.getHeaders();
188-
return getMessageBuilderFactory()
189-
.fromMessage(message)
190-
.setHeader(GATHER_RESULT_CHANNEL, gatherResultChannel)
191-
.setHeader(MessageHeaders.REPLY_CHANNEL, requestMessageHeaders.getReplyChannel())
192-
.setHeader(MessageHeaders.ERROR_CHANNEL, requestMessageHeaders.getErrorChannel())
193-
.build();
194-
}
195-
196196
@Override
197197
public void start() {
198198
if (this.gatherEndpoint != null) {

spring-integration-core/src/test/java/org/springframework/integration/scattergather/config/ScatterGatherTests-context.xml

+7-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
xmlns="http://www.springframework.org/schema/integration"
55
xmlns:task="http://www.springframework.org/schema/task"
66
xsi:schemaLocation="http://www.springframework.org/schema/beans https://www.springframework.org/schema/beans/spring-beans.xsd
7-
http://www.springframework.org/schema/integration https://www.springframework.org/schema/integration/spring-integration.xsd http://www.springframework.org/schema/task https://www.springframework.org/schema/task/spring-task.xsd">
7+
http://www.springframework.org/schema/integration https://www.springframework.org/schema/integration/spring-integration.xsd
8+
http://www.springframework.org/schema/task https://www.springframework.org/schema/task/spring-task.xsd">
89

910
<channel id="output">
1011
<queue/>
@@ -52,9 +53,12 @@
5253

5354
<!--Sync scenario-->
5455

55-
<gateway id="gateway" default-request-channel="gatewayAuction" default-reply-timeout="10000" />
56+
<gateway id="gateway" default-request-channel="gatewayAuction" default-reply-timeout="10000"/>
5657

57-
<scatter-gather input-channel="gatewayAuction" output-channel="bridgeChannel" scatter-channel="auctionChannel">
58+
<channel id="gatherChannel2"/>
59+
60+
<scatter-gather input-channel="gatewayAuction" output-channel="bridgeChannel" scatter-channel="auctionChannel"
61+
gather-channel="gatherChannel2">
5862
<gatherer release-strategy-expression="messages.^[payload gt 5] != null or size() == 3"/>
5963
</scatter-gather>
6064

spring-integration-core/src/test/java/org/springframework/integration/scattergather/config/ScatterGatherTests.java

+33-16
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,19 @@
2929
import org.springframework.messaging.MessageChannel;
3030
import org.springframework.messaging.PollableChannel;
3131
import org.springframework.messaging.support.GenericMessage;
32+
import org.springframework.test.annotation.DirtiesContext;
3233
import org.springframework.test.context.ContextConfiguration;
3334
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
3435

3536
/**
3637
* @author Artem Bilan
3738
* @author Gary Russell
39+
*
3840
* @since 4.1
3941
*/
4042
@ContextConfiguration
4143
@RunWith(SpringJUnit4ClassRunner.class)
44+
@DirtiesContext
4245
public class ScatterGatherTests {
4346

4447
@Autowired
@@ -58,36 +61,50 @@ public class ScatterGatherTests {
5861

5962
@Test
6063
public void testAuction() {
61-
this.inputAuction.send(new GenericMessage<String>("foo"));
64+
this.inputAuction.send(new GenericMessage<>("foo"));
6265
Message<?> bestQuoteMessage = this.output.receive(10000);
63-
assertThat(bestQuoteMessage).isNotNull();
64-
Object payload = bestQuoteMessage.getPayload();
65-
assertThat(payload).isInstanceOf(List.class);
66-
assertThat(((List<?>) payload).size()).isGreaterThanOrEqualTo(1);
66+
assertThat(bestQuoteMessage)
67+
.isNotNull()
68+
.extracting(Message::getPayload)
69+
.isInstanceOf(List.class)
70+
.asList()
71+
.hasSizeGreaterThanOrEqualTo(1);
6772
}
6873

6974
@Test
7075
public void testDistribution() {
71-
this.inputDistribution.send(new GenericMessage<String>("foo"));
76+
this.inputDistribution.send(new GenericMessage<>("foo"));
7277
Message<?> bestQuoteMessage = this.output.receive(10000);
73-
assertThat(bestQuoteMessage).isNotNull();
74-
Object payload = bestQuoteMessage.getPayload();
75-
assertThat(payload).isInstanceOf(List.class);
76-
assertThat(((List<?>) payload).size()).isGreaterThanOrEqualTo(1);
78+
assertThat(bestQuoteMessage)
79+
.isNotNull()
80+
.extracting(Message::getPayload)
81+
.isInstanceOf(List.class)
82+
.asList()
83+
.hasSizeGreaterThanOrEqualTo(1);
7784
}
7885

7986
@Test
8087
public void testGatewayScatterGather() {
81-
Message<?> bestQuoteMessage = this.gateway.exchange(new GenericMessage<String>("foo"));
82-
assertThat(bestQuoteMessage).isNotNull();
83-
Object payload = bestQuoteMessage.getPayload();
84-
assertThat(payload).isInstanceOf(List.class);
85-
assertThat(((List<?>) payload).size()).isGreaterThanOrEqualTo(1);
88+
Message<?> bestQuoteMessage = this.gateway.exchange(new GenericMessage<>("foo"));
89+
assertThat(bestQuoteMessage)
90+
.isNotNull()
91+
.extracting(Message::getPayload)
92+
.isInstanceOf(List.class)
93+
.asList()
94+
.hasSizeGreaterThanOrEqualTo(1);
95+
96+
bestQuoteMessage = this.gateway.exchange(new GenericMessage<>("bar"));
97+
assertThat(bestQuoteMessage)
98+
.isNotNull()
99+
.extracting(Message::getPayload)
100+
.isInstanceOf(List.class)
101+
.asList()
102+
.hasSizeGreaterThanOrEqualTo(1);
86103
}
87104

88105
@Test
89106
public void testWithinChain() {
90-
this.scatterGatherWithinChain.send(new GenericMessage<String>("foo"));
107+
this.scatterGatherWithinChain.send(new GenericMessage<>("foo"));
91108
for (int i = 0; i < 3; i++) {
92109
Message<?> result = this.output.receive(10000);
93110
assertThat(result).isNotNull();

src/reference/asciidoc/scatter-gather.adoc

+2
Original file line numberDiff line numberDiff line change
@@ -209,5 +209,7 @@ Such an exception `payload` can be filtered out in the `MessageGroupProcessor` o
209209

210210
NOTE: Before sending scattering results to the gatherer, `ScatterGatherHandler` reinstates the request message headers, including reply and error channels if any.
211211
This way errors from the `AggregatingMessageHandler` are going to be propagated to the caller, even if an async hand off is applied in scatter recipient subflows.
212+
For successful operation, a `gatherResultChannel`, `originalReplyChannel` and `originalErrorChannel` headers must be transferred back to replies from scatter recipient subflows.
212213
In this case a reasonable, finite `gatherTimeout` must be configured for the `ScatterGatherHandler`.
213214
Otherwise it is going to be blocked waiting for a reply from the gatherer forever, by default.
215+

0 commit comments

Comments
 (0)