Skip to content

GH-2967: Fix ScatterGatherH for headers copy #2968

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 19, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ public class ScatterGatherHandler extends AbstractReplyProducingMessageHandler i

private static final String GATHER_RESULT_CHANNEL = "gatherResultChannel";

private static final String ORIGINAL_REPLY_CHANNEL = "originalReplyChannel";

private static final String ORIGINAL_ERROR_CHANNEL = "originalErrorChannel";

private final MessageChannel scatterChannel;

private final MessageHandler gatherer;
Expand Down Expand Up @@ -107,9 +111,24 @@ public void setErrorChannelName(String errorChannelName) {
protected void doInit() {
BeanFactory beanFactory = getBeanFactory();
if (this.gatherChannel == null) {
this.gatherChannel = new FixedSubscriberChannel(this.gatherer);
this.gatherChannel =
new FixedSubscriberChannel((message) ->
this.gatherer.handleMessage(enhanceScatterReplyMessage(message)));
}
else {
Assert.isInstanceOf(InterceptableChannel.class, this.gatherChannel,
() -> "An injected 'gatherChannel' '" + this.gatherChannel +
"' must be an 'InterceptableChannel' instance.");
((InterceptableChannel) this.gatherChannel)
.addInterceptor(0,
new ChannelInterceptor() {

@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
return enhanceScatterReplyMessage(message);
}

});
if (this.gatherChannel instanceof SubscribableChannel) {
this.gatherEndpoint = new EventDrivenConsumer((SubscribableChannel) this.gatherChannel, this.gatherer);
}
Expand All @@ -121,7 +140,7 @@ else if (this.gatherChannel instanceof ReactiveStreamsSubscribableChannel) {
this.gatherEndpoint = new ReactiveStreamsConsumer(this.gatherChannel, this.gatherer);
}
else {
throw new BeanInitializationException("Unsupported 'replyChannel' type '" +
throw new BeanInitializationException("Unsupported 'gatherChannel' type '" +
this.gatherChannel.getClass() + "'. " +
"'SubscribableChannel', 'PollableChannel' or 'ReactiveStreamsSubscribableChannel' " +
"types are supported.");
Expand All @@ -131,7 +150,7 @@ else if (this.gatherChannel instanceof ReactiveStreamsSubscribableChannel) {
}

((MessageProducer) this.gatherer)
.setOutputChannel(new FixedSubscriberChannel(message -> {
.setOutputChannel(new FixedSubscriberChannel((message) -> {
MessageHeaders headers = message.getHeaders();
MessageChannel gatherResultChannel = headers.get(GATHER_RESULT_CHANNEL, MessageChannel.class);
if (gatherResultChannel != null) {
Expand All @@ -144,35 +163,28 @@ else if (this.gatherChannel instanceof ReactiveStreamsSubscribableChannel) {
}));
}

private Message<?> enhanceScatterReplyMessage(Message<?> message) {
MessageHeaders headers = message.getHeaders();
return getMessageBuilderFactory()
.fromMessage(message)
.setHeader(MessageHeaders.REPLY_CHANNEL, headers.get(ORIGINAL_REPLY_CHANNEL))
.setHeader(MessageHeaders.ERROR_CHANNEL, headers.get(ORIGINAL_ERROR_CHANNEL))
.removeHeaders(ORIGINAL_REPLY_CHANNEL, ORIGINAL_ERROR_CHANNEL)
.build();
}

@Override
protected Object handleRequestMessage(Message<?> requestMessage) {
MessageHeaders requestMessageHeaders = requestMessage.getHeaders();
PollableChannel gatherResultChannel = new QueueChannel();

MessageChannel replyChannel = this.gatherChannel;

if (replyChannel instanceof InterceptableChannel) {
((InterceptableChannel) replyChannel)
.addInterceptor(0,
new ChannelInterceptor() {

@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
return enhanceScatterReplyMessage(message, gatherResultChannel, requestMessage);
}

});
}
else {
replyChannel =
new FixedSubscriberChannel(message ->
this.messagingTemplate.send(this.gatherChannel,
enhanceScatterReplyMessage(message, gatherResultChannel, requestMessage)));
}

Message<?> scatterMessage =
getMessageBuilderFactory()
.fromMessage(requestMessage)
.setReplyChannel(replyChannel)
.setHeader(GATHER_RESULT_CHANNEL, gatherResultChannel)
.setHeader(ORIGINAL_REPLY_CHANNEL, requestMessageHeaders.getReplyChannel())
.setHeader(ORIGINAL_ERROR_CHANNEL, requestMessageHeaders.getErrorChannel())
.setReplyChannel(this.gatherChannel)
.setErrorChannelName(this.errorChannelName)
.build();

Expand All @@ -181,18 +193,6 @@ public Message<?> preSend(Message<?> message, MessageChannel channel) {
return gatherResultChannel.receive(this.gatherTimeout);
}

private Message<?> enhanceScatterReplyMessage(Message<?> message, PollableChannel gatherResultChannel,
Message<?> requestMessage) {

MessageHeaders requestMessageHeaders = requestMessage.getHeaders();
return getMessageBuilderFactory()
.fromMessage(message)
.setHeader(GATHER_RESULT_CHANNEL, gatherResultChannel)
.setHeader(MessageHeaders.REPLY_CHANNEL, requestMessageHeaders.getReplyChannel())
.setHeader(MessageHeaders.ERROR_CHANNEL, requestMessageHeaders.getErrorChannel())
.build();
}

@Override
public void start() {
if (this.gatherEndpoint != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
xmlns="http://www.springframework.org/schema/integration"
xmlns:task="http://www.springframework.org/schema/task"
xsi:schemaLocation="http://www.springframework.org/schema/beans https://www.springframework.org/schema/beans/spring-beans.xsd
http://www.springframework.org/schema/integration https://www.springframework.org/schema/integration/spring-integration.xsd http://www.springframework.org/schema/task https://www.springframework.org/schema/task/spring-task.xsd">
http://www.springframework.org/schema/integration https://www.springframework.org/schema/integration/spring-integration.xsd
http://www.springframework.org/schema/task https://www.springframework.org/schema/task/spring-task.xsd">

<channel id="output">
<queue/>
Expand Down Expand Up @@ -52,9 +53,12 @@

<!--Sync scenario-->

<gateway id="gateway" default-request-channel="gatewayAuction" default-reply-timeout="10000" />
<gateway id="gateway" default-request-channel="gatewayAuction" default-reply-timeout="10000"/>

<scatter-gather input-channel="gatewayAuction" output-channel="bridgeChannel" scatter-channel="auctionChannel">
<channel id="gatherChannel2"/>

<scatter-gather input-channel="gatewayAuction" output-channel="bridgeChannel" scatter-channel="auctionChannel"
gather-channel="gatherChannel2">
<gatherer release-strategy-expression="messages.^[payload gt 5] != null or size() == 3"/>
</scatter-gather>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,19 @@
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.PollableChannel;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;

/**
* @author Artem Bilan
* @author Gary Russell
*
* @since 4.1
*/
@ContextConfiguration
@RunWith(SpringJUnit4ClassRunner.class)
@DirtiesContext
public class ScatterGatherTests {

@Autowired
Expand All @@ -58,36 +61,50 @@ public class ScatterGatherTests {

@Test
public void testAuction() {
this.inputAuction.send(new GenericMessage<String>("foo"));
this.inputAuction.send(new GenericMessage<>("foo"));
Message<?> bestQuoteMessage = this.output.receive(10000);
assertThat(bestQuoteMessage).isNotNull();
Object payload = bestQuoteMessage.getPayload();
assertThat(payload).isInstanceOf(List.class);
assertThat(((List<?>) payload).size()).isGreaterThanOrEqualTo(1);
assertThat(bestQuoteMessage)
.isNotNull()
.extracting(Message::getPayload)
.isInstanceOf(List.class)
.asList()
.hasSizeGreaterThanOrEqualTo(1);
}

@Test
public void testDistribution() {
this.inputDistribution.send(new GenericMessage<String>("foo"));
this.inputDistribution.send(new GenericMessage<>("foo"));
Message<?> bestQuoteMessage = this.output.receive(10000);
assertThat(bestQuoteMessage).isNotNull();
Object payload = bestQuoteMessage.getPayload();
assertThat(payload).isInstanceOf(List.class);
assertThat(((List<?>) payload).size()).isGreaterThanOrEqualTo(1);
assertThat(bestQuoteMessage)
.isNotNull()
.extracting(Message::getPayload)
.isInstanceOf(List.class)
.asList()
.hasSizeGreaterThanOrEqualTo(1);
}

@Test
public void testGatewayScatterGather() {
Message<?> bestQuoteMessage = this.gateway.exchange(new GenericMessage<String>("foo"));
assertThat(bestQuoteMessage).isNotNull();
Object payload = bestQuoteMessage.getPayload();
assertThat(payload).isInstanceOf(List.class);
assertThat(((List<?>) payload).size()).isGreaterThanOrEqualTo(1);
Message<?> bestQuoteMessage = this.gateway.exchange(new GenericMessage<>("foo"));
assertThat(bestQuoteMessage)
.isNotNull()
.extracting(Message::getPayload)
.isInstanceOf(List.class)
.asList()
.hasSizeGreaterThanOrEqualTo(1);

bestQuoteMessage = this.gateway.exchange(new GenericMessage<>("bar"));
assertThat(bestQuoteMessage)
.isNotNull()
.extracting(Message::getPayload)
.isInstanceOf(List.class)
.asList()
.hasSizeGreaterThanOrEqualTo(1);
}

@Test
public void testWithinChain() {
this.scatterGatherWithinChain.send(new GenericMessage<String>("foo"));
this.scatterGatherWithinChain.send(new GenericMessage<>("foo"));
for (int i = 0; i < 3; i++) {
Message<?> result = this.output.receive(10000);
assertThat(result).isNotNull();
Expand Down
2 changes: 2 additions & 0 deletions src/reference/asciidoc/scatter-gather.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -209,5 +209,7 @@ Such an exception `payload` can be filtered out in the `MessageGroupProcessor` o

NOTE: Before sending scattering results to the gatherer, `ScatterGatherHandler` reinstates the request message headers, including reply and error channels if any.
This way errors from the `AggregatingMessageHandler` are going to be propagated to the caller, even if an async hand off is applied in scatter recipient subflows.
To make it working properly, a `gatherResultChannel`, `originalReplyChannel` and `originalErrorChannel` headers must be transferred back to replies from scatter recipient subflows.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/To make it working properly,/For successful operation,/

In this case a reasonable, finite `gatherTimeout` must be configured for the `ScatterGatherHandler`.
Otherwise it is going to be blocked waiting for a reply from the gatherer forever, by default.