Skip to content

Commit c2c44c2

Browse files
sirimamillaartembilan
authored andcommitted
GH-3152: Fix for nested Scatter Gather
Fixes #3152 The upstream `gatherResultChannel` header has been missed when we produced a reply from nested scatter-gather Added Test Case for Nested Scatter Gather test Simplified the the test cases and added author in changed cases Corrected codestyle issue in Travis CI Removed additional OriginalReplyChannel and originalErrorChannel in Headers. Added additional not to be executed line of code in test case. Restored OriginalErrorChannel Header and removed error handling related fixes * Clean up code style and improve readability **Cherry-pick to 5.1.x & master**
1 parent 3850c7e commit c2c44c2

File tree

2 files changed

+47
-10
lines changed

2 files changed

+47
-10
lines changed

spring-integration-core/src/main/java/org/springframework/integration/scattergather/ScatterGatherHandler.java

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2014-2019 the original author or authors.
2+
* Copyright 2014-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -48,15 +48,14 @@
4848
*
4949
* @author Artem Bilan
5050
* @author Abdul Zaheer
51+
* @author Jayadev Sirimamilla
5152
*
5253
* @since 4.1
5354
*/
5455
public class ScatterGatherHandler extends AbstractReplyProducingMessageHandler implements Lifecycle {
5556

5657
private static final String GATHER_RESULT_CHANNEL = "gatherResultChannel";
5758

58-
private static final String ORIGINAL_REPLY_CHANNEL = "originalReplyChannel";
59-
6059
private static final String ORIGINAL_ERROR_CHANNEL = "originalErrorChannel";
6160

6261
private final MessageChannel scatterChannel;
@@ -167,9 +166,7 @@ private Message<?> enhanceScatterReplyMessage(Message<?> message) {
167166
MessageHeaders headers = message.getHeaders();
168167
return getMessageBuilderFactory()
169168
.fromMessage(message)
170-
.setHeader(MessageHeaders.REPLY_CHANNEL, headers.get(ORIGINAL_REPLY_CHANNEL))
171169
.setHeader(MessageHeaders.ERROR_CHANNEL, headers.get(ORIGINAL_ERROR_CHANNEL))
172-
.removeHeaders(ORIGINAL_REPLY_CHANNEL, ORIGINAL_ERROR_CHANNEL)
173170
.build();
174171
}
175172

@@ -182,15 +179,22 @@ protected Object handleRequestMessage(Message<?> requestMessage) {
182179
getMessageBuilderFactory()
183180
.fromMessage(requestMessage)
184181
.setHeader(GATHER_RESULT_CHANNEL, gatherResultChannel)
185-
.setHeader(ORIGINAL_REPLY_CHANNEL, requestMessageHeaders.getReplyChannel())
186182
.setHeader(ORIGINAL_ERROR_CHANNEL, requestMessageHeaders.getErrorChannel())
187183
.setReplyChannel(this.gatherChannel)
188184
.setErrorChannelName(this.errorChannelName)
189185
.build();
190186

191187
this.messagingTemplate.send(this.scatterChannel, scatterMessage);
192188

193-
return gatherResultChannel.receive(this.gatherTimeout);
189+
Message<?> gatherResult = gatherResultChannel.receive(this.gatherTimeout);
190+
if (gatherResult != null) {
191+
return getMessageBuilderFactory()
192+
.fromMessage(gatherResult)
193+
.removeHeaders(GATHER_RESULT_CHANNEL, ORIGINAL_ERROR_CHANNEL,
194+
MessageHeaders.REPLY_CHANNEL, MessageHeaders.ERROR_CHANNEL);
195+
}
196+
197+
return null;
194198
}
195199

196200
@Override

spring-integration-core/src/test/java/org/springframework/integration/dsl/routers/RouterTests.java

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2016-2019 the original author or authors.
2+
* Copyright 2016-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -42,9 +42,11 @@
4242
import org.springframework.integration.config.EnableIntegration;
4343
import org.springframework.integration.config.EnableMessageHistory;
4444
import org.springframework.integration.dsl.IntegrationFlow;
45+
import org.springframework.integration.dsl.IntegrationFlowDefinition;
4546
import org.springframework.integration.dsl.IntegrationFlows;
4647
import org.springframework.integration.dsl.MessageChannels;
4748
import org.springframework.integration.expression.FunctionExpression;
49+
import org.springframework.integration.store.MessageGroup;
4850
import org.springframework.integration.support.MessageBuilder;
4951
import org.springframework.messaging.Message;
5052
import org.springframework.messaging.MessageChannel;
@@ -62,6 +64,7 @@
6264
/**
6365
* @author Artem Bilan
6466
* @author Gary Russell
67+
* @author Jayadev Sirimamilla
6568
*
6669
* @since 5.0
6770
*/
@@ -531,7 +534,6 @@ public void testExceptionTypeRouteFlow() {
531534
private MessageChannel nestedScatterGatherFlowInput;
532535

533536
@Test
534-
@SuppressWarnings("unchecked")
535537
public void testNestedScatterGather() {
536538
QueueChannel replyChannel = new QueueChannel();
537539
Message<String> request = MessageBuilder.withPayload("this is a test")
@@ -578,7 +580,7 @@ public void testScatterGatherWithExecutorChannelSubFlow() {
578580
assertThat(receive).isNotNull();
579581
Object payload = receive.getPayload();
580582
assertThat(payload).isInstanceOf(List.class);
581-
assertThat(((List) payload).get(1)).isInstanceOf(RuntimeException.class);
583+
assertThat(((List<?>) payload).get(1)).isInstanceOf(RuntimeException.class);
582584
}
583585

584586
@Autowired
@@ -592,6 +594,25 @@ public void propagateErrorFromGatherer() {
592594
.withMessage("intentional");
593595
}
594596

597+
@Autowired
598+
@Qualifier("scatterGatherInSubFlow.input")
599+
MessageChannel scatterGatherInSubFlowChannel;
600+
601+
602+
@Test
603+
public void testNestedScatterGatherSuccess() {
604+
PollableChannel replyChannel = new QueueChannel();
605+
this.scatterGatherInSubFlowChannel.send(
606+
org.springframework.integration.support.MessageBuilder.withPayload("baz")
607+
.setReplyChannel(replyChannel)
608+
.build());
609+
610+
Message<?> receive = replyChannel.receive(10000);
611+
assertThat(receive).isNotNull();
612+
assertThat(receive.getPayload()).isEqualTo("baz");
613+
614+
}
615+
595616
@Configuration
596617
@EnableIntegration
597618
@EnableMessageHistory({ "recipientListOrder*", "recipient1*", "recipient2*" })
@@ -896,9 +917,21 @@ public IntegrationFlow propagateErrorFromGatherer(TaskExecutor taskExecutor) {
896917
throw new RuntimeException("intentional");
897918
}),
898919
sg -> sg.gatherTimeout(100))
920+
.transform(m -> "This should not be executed, results must have been propagated to Error Channel")
899921
.get();
900922
}
901923

924+
@Bean
925+
public IntegrationFlow scatterGatherInSubFlow() {
926+
return flow -> flow.scatterGather(s -> s.applySequence(true)
927+
.recipientFlow(inflow -> inflow
928+
.scatterGather(s1 -> s1.applySequence(true)
929+
.recipientFlow(IntegrationFlowDefinition::bridge),
930+
g -> g.outputProcessor(MessageGroup::getOne)
931+
)),
932+
g -> g.outputProcessor(MessageGroup::getOne));
933+
}
934+
902935
}
903936

904937
private static class RoutingTestBean {

0 commit comments

Comments
 (0)