Skip to content

GH-2744: ScatterGather: reinstate request headers #2750

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanInitializationException;
import org.springframework.context.Lifecycle;
import org.springframework.integration.channel.ChannelInterceptorAware;
import org.springframework.integration.channel.FixedSubscriberChannel;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.channel.ReactiveStreamsSubscribableChannel;
Expand All @@ -30,14 +31,14 @@
import org.springframework.integration.endpoint.PollingConsumer;
import org.springframework.integration.endpoint.ReactiveStreamsConsumer;
import org.springframework.integration.handler.AbstractReplyProducingMessageHandler;
import org.springframework.integration.support.channel.HeaderChannelRegistry;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.PollableChannel;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;

Expand Down Expand Up @@ -66,8 +67,6 @@ public class ScatterGatherHandler extends AbstractReplyProducingMessageHandler i

private AbstractEndpoint gatherEndpoint;

private HeaderChannelRegistry replyChannelRegistry;


public ScatterGatherHandler(MessageHandler scatterer, MessageHandler gatherer) {
this(new FixedSubscriberChannel(scatterer), gatherer);
Expand Down Expand Up @@ -134,52 +133,64 @@ else if (this.gatherChannel instanceof ReactiveStreamsSubscribableChannel) {
((MessageProducer) this.gatherer)
.setOutputChannel(new FixedSubscriberChannel(message -> {
MessageHeaders headers = message.getHeaders();
if (headers.containsKey(GATHER_RESULT_CHANNEL)) {
Object gatherResultChannel = headers.get(GATHER_RESULT_CHANNEL);
if (gatherResultChannel instanceof MessageChannel) {
messagingTemplate.send((MessageChannel) gatherResultChannel, message);
}
else if (gatherResultChannel instanceof String) {
messagingTemplate.send((String) gatherResultChannel, message);
}
MessageChannel gatherResultChannel = headers.get(GATHER_RESULT_CHANNEL, MessageChannel.class);
if (gatherResultChannel != null) {
this.messagingTemplate.send(gatherResultChannel, message);
}
else {
throw new MessageDeliveryException(message,
"The 'gatherResultChannel' header is required to delivery gather result.");
"The 'gatherResultChannel' header is required to deliver the gather result.");
}
}));

this.replyChannelRegistry =
beanFactory.getBean(IntegrationContextUtils.INTEGRATION_HEADER_CHANNEL_REGISTRY_BEAN_NAME,
HeaderChannelRegistry.class);
}

@Override
protected Object handleRequestMessage(Message<?> requestMessage) {
PollableChannel gatherResultChannel = new QueueChannel();

Object gatherResultChannelName = this.replyChannelRegistry.channelToChannelName(gatherResultChannel);
MessageChannel replyChannel = this.gatherChannel;

if (replyChannel instanceof ChannelInterceptorAware) {
((ChannelInterceptorAware) replyChannel)
.addInterceptor(0,
new ChannelInterceptor() {

@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
return enhanceScatterReplyMessage(message, gatherResultChannel, requestMessage);
}

});
}
else {
replyChannel =
new FixedSubscriberChannel(message ->
this.messagingTemplate.send(this.gatherChannel,
enhanceScatterReplyMessage(message, gatherResultChannel, requestMessage)));
}

Message<?> scatterMessage =
getMessageBuilderFactory()
.fromMessage(requestMessage)
.setHeader(GATHER_RESULT_CHANNEL, gatherResultChannelName)
.setReplyChannel(this.gatherChannel)
.setReplyChannel(replyChannel)
.setErrorChannelName(this.errorChannelName)
.build();

this.messagingTemplate.send(this.scatterChannel, scatterMessage);

Message<?> gatherResult = gatherResultChannel.receive(this.gatherTimeout);
if (gatherResult != null) {
return getMessageBuilderFactory()
.fromMessage(gatherResult)
.removeHeader(GATHER_RESULT_CHANNEL)
.setHeader(MessageHeaders.REPLY_CHANNEL, requestMessage.getHeaders().getReplyChannel())
.setHeader(MessageHeaders.ERROR_CHANNEL, requestMessage.getHeaders().getErrorChannel());
}
return gatherResultChannel.receive(this.gatherTimeout);
}

private Message<?> enhanceScatterReplyMessage(Message<?> message, PollableChannel gatherResultChannel,
Message<?> requestMessage) {

return null;
MessageHeaders requestMessageHeaders = requestMessage.getHeaders();
return getMessageBuilderFactory()
.fromMessage(message)
.setHeader(GATHER_RESULT_CHANNEL, gatherResultChannel)
.setHeader(MessageHeaders.REPLY_CHANNEL, requestMessageHeaders.getReplyChannel())
.setHeader(MessageHeaders.ERROR_CHANNEL, requestMessageHeaders.getErrorChannel())
.build();
}

@Override
Expand All @@ -201,11 +212,11 @@ public boolean isRunning() {
return this.gatherEndpoint == null || this.gatherEndpoint.isRunning();
}

private void checkClass(Class<?> gathererClass, String className, String type) throws LinkageError {
private static void checkClass(Class<?> gathererClass, String className, String type) throws LinkageError {
try {
Class<?> clazz = ClassUtils.forName(className, ClassUtils.getDefaultClassLoader());
Assert.isAssignable(clazz, gathererClass, () -> "the '" + type + "' must be an " + className + " " +
"instance");
Assert.isAssignable(clazz, gathererClass,
() -> "the '" + type + "' must be an " + className + " " + "instance");
}
catch (ClassNotFoundException e) {
throw new IllegalStateException("The class for '" + className + "' cannot be loaded", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.integration.dsl.routers;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.instanceOf;
Expand All @@ -29,6 +30,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.junit.Test;
Expand Down Expand Up @@ -590,6 +592,16 @@ public void testScatterGatherWithExecutorChannelSubFlow() {
assertThat(((List) payload).get(1), instanceOf(RuntimeException.class));
}

@Autowired
@Qualifier("propagateErrorFromGatherer.gateway")
private Function<Object, ?> propagateErrorFromGathererGateway;

@Test
public void propagateErrorFromGatherer() {
assertThatThrownBy(() -> propagateErrorFromGathererGateway.apply("bar"))
.hasMessage("intentional");
}

@Configuration
@EnableIntegration
@EnableMessageHistory({ "recipientListOrder*", "recipient1*", "recipient2*" })
Expand Down Expand Up @@ -881,6 +893,22 @@ public Message<?> processAsyncScatterError(MessagingException payload) {
.build();
}

@Bean
public IntegrationFlow propagateErrorFromGatherer(TaskExecutor taskExecutor) {
return IntegrationFlows.from(Function.class)
.scatterGather(s -> s
.applySequence(true)
.recipientFlow(subFlow -> subFlow
.channel(c -> c.executor(taskExecutor))
.transform(p -> "foo")),
g -> g
.outputProcessor(group -> {
throw new RuntimeException("intentional");
}),
sg -> sg.gatherTimeout(100))
.get();
}

}

private static class RoutingTestBean {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2014-2015 the original author or authors.
* Copyright 2014-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,7 +24,7 @@

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.Executor;

import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -278,8 +278,8 @@ public MessageChannel gatherChannel() {
}

@Bean
public SubscribableChannel scatterAuctionWithGatherChannel() {
PublishSubscribeChannel channel = new PublishSubscribeChannel(Executors.newCachedThreadPool());
public SubscribableChannel scatterAuctionWithGatherChannel(Executor executor) {
PublishSubscribeChannel channel = new PublishSubscribeChannel(executor);
channel.setApplySequence(true);
return channel;
}
Expand All @@ -296,7 +296,8 @@ public MessageHandler gatherer2() {
@Bean
@ServiceActivator(inputChannel = "inputAuctionWithGatherChannel")
public MessageHandler scatterGatherAuctionWithGatherChannel() {
ScatterGatherHandler handler = new ScatterGatherHandler(scatterAuctionWithGatherChannel(), gatherer2());
ScatterGatherHandler handler =
new ScatterGatherHandler(scatterAuctionWithGatherChannel(null), gatherer2());
handler.setGatherChannel(gatherChannel());
handler.setOutputChannel(output());
return handler;
Expand Down
5 changes: 5 additions & 0 deletions src/reference/asciidoc/scatter-gather.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,8 @@ public Message<?> processAsyncScatterError(MessagingException payload) {
To produce a proper reply, we have to copy headers (including `replyChannel` and `errorChannel`) from the `failedMessage` of the `MessagingException` that has been sent to the `scatterGatherErrorChannel` by the `MessagePublishingErrorHandler`.
This way the target exception is returned to the gatherer of the `ScatterGatherHandler` for reply messages group completion.
Such an exception `payload` can be filtered out in the `MessageGroupProcessor` of the gatherer or processed other way downstream, after the scatter-gather endpoint.

NOTE: Before sending scattering results to the gatherer, `ScatterGatherHandler` reinstates the request message headers, including reply and error channels if any.
This way errors from the `AggregatingMessageHandler` are going to be propagated to the caller, even if async an hand off is applied in scatter recipient subflows.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

an async - will fix on merge

In this case a reasonable, finite `gatherTimeout` must be configured for the `ScatterGatherHandler`.
Otherwise it is going to be blocked waiting for a reply from the gatherer forever, by default.