diff --git a/spring-integration-core/src/main/java/org/springframework/integration/aggregator/FluxAggregatorMessageHandler.java b/spring-integration-core/src/main/java/org/springframework/integration/aggregator/FluxAggregatorMessageHandler.java new file mode 100644 index 00000000000..844dc0f73df --- /dev/null +++ b/spring-integration-core/src/main/java/org/springframework/integration/aggregator/FluxAggregatorMessageHandler.java @@ -0,0 +1,271 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.aggregator; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import java.util.function.Predicate; + +import org.springframework.context.Lifecycle; +import org.springframework.integration.IntegrationMessageHeaderAccessor; +import org.springframework.integration.channel.ReactiveStreamsSubscribableChannel; +import org.springframework.integration.handler.AbstractMessageProducingHandler; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.util.Assert; + +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; + +/** + * The {@link AbstractMessageProducingHandler} implementation for aggregation logic based on the + * Reactor's {@link Flux#groupBy} and {@link Flux#window} operators. + *

+ * The incoming messages are emitted into a {@link FluxSink} provided by the {@link Flux#create} + * initialized in the constructor. + *

+ * The resulting windows for groups are wrapped into the {@link Message}s for downstream consumption. + *

+ * If the {@link #getOutputChannel()} is not a {@link ReactiveStreamsSubscribableChannel} instance, + * a subscription for the whole aggregating {@link Flux} is happened in the {@link #start()} method. + * + * @author Artem Bilan + * + * @since 5.2 + */ +public class FluxAggregatorMessageHandler extends AbstractMessageProducingHandler implements Lifecycle { + + private final AtomicBoolean subscribed = new AtomicBoolean(); + + private final Flux> aggregatorFlux; + + private CorrelationStrategy correlationStrategy = + new HeaderAttributeCorrelationStrategy(IntegrationMessageHeaderAccessor.CORRELATION_ID); + + private Predicate> boundaryTrigger; + + private Function, Integer> windowSizeFunction = FluxAggregatorMessageHandler::sequenceSizeHeader; + + private Function>, Flux>>> windowConfigurer; + + private Duration windowTimespan; + + private Function>, Mono>> combineFunction = this::messageForWindowFlux; + + private FluxSink> sink; + + private volatile Disposable subscription; + + /** + * Create an instance with a {@link Flux#create} and apply {@link Flux#groupBy} and {@link Flux#window} + * transformation into it. + */ + public FluxAggregatorMessageHandler() { + this.aggregatorFlux = + Flux.>create(emitter -> this.sink = emitter, FluxSink.OverflowStrategy.BUFFER) + .groupBy(this::groupBy) + .flatMap((group) -> group.transform(this::releaseBy)) + .publish() + .autoConnect(); + } + + private Object groupBy(Message message) { + return this.correlationStrategy.getCorrelationKey(message); + } + + private Flux> releaseBy(Flux> groupFlux) { + return groupFlux + .transform(this.windowConfigurer != null ? this.windowConfigurer : this::applyWindowOptions) + .flatMap((windowFlux) -> windowFlux.transform(this.combineFunction)); + } + + private Flux>> applyWindowOptions(Flux> groupFlux) { + if (this.boundaryTrigger != null) { + return groupFlux.windowUntil(this.boundaryTrigger); + } + return groupFlux + .switchOnFirst((signal, group) -> { + if (signal.hasValue()) { + Integer maxSize = this.windowSizeFunction.apply(signal.get()); + if (maxSize != null) { + if (this.windowTimespan != null) { + return group.windowTimeout(maxSize, this.windowTimespan); + } + else { + return group.window(maxSize); + } + } + else { + if (this.windowTimespan != null) { + return group.window(this.windowTimespan); + } + else { + return Flux.error( + new IllegalStateException( + "One of the 'boundaryTrigger', 'windowSizeFunction' or " + + "'windowTimespan' options must be configured or " + + "'sequenceSize' header must be supplied in the messages " + + "to aggregate.")); + } + } + } + else { + return Flux.just(group); + } + }); + } + + /** + * Configure a {@link CorrelationStrategy} to determine a group key from the incoming messages. + * By default a {@link HeaderAttributeCorrelationStrategy} is used against a + * {@link IntegrationMessageHeaderAccessor#CORRELATION_ID} header value. + * @param correlationStrategy the {@link CorrelationStrategy} to use. + */ + public void setCorrelationStrategy(CorrelationStrategy correlationStrategy) { + Assert.notNull(correlationStrategy, "'correlationStrategy' must not be null"); + this.correlationStrategy = correlationStrategy; + } + + /** + * Configure a transformation {@link Function} to apply for a {@link Flux} window to emit. + * Requires a {@link Mono} result with a {@link Message} as value as a combination result + * of the incoming {@link Flux} for window. + * By default a {@link Flux} for window is fully wrapped into a message with headers copied + * from the first message in window. Such a {@link Flux} in the payload has to be subscribed + * and consumed downstream. + * @param combineFunction the {@link Function} to use for result windows transformation. + */ + public void setCombineFunction(Function>, Mono>> combineFunction) { + Assert.notNull(combineFunction, "'combineFunction' must not be null"); + this.combineFunction = combineFunction; + } + + /** + * Configure a {@link Predicate} for messages to determine a window boundary in the + * {@link Flux#windowUntil} operator. + * Has a precedence over any other window configuration options. + * @param boundaryTrigger the {@link Predicate} to use for window boundary. + * @see Flux#windowUntil(Predicate) + */ + public void setBoundaryTrigger(Predicate> boundaryTrigger) { + this.boundaryTrigger = boundaryTrigger; + } + + /** + * Specify a size for windows to close. + * Can be combined with the {@link #setWindowTimespan(Duration)}. + * @param windowSize the size for window to use. + * @see Flux#window(int) + * @see Flux#windowTimeout(int, Duration) + */ + public void setWindowSize(int windowSize) { + setWindowSizeFunction((message) -> windowSize); + } + + /** + * Specify a {@link Function} to determine a size for windows to close against the first message in group. + * Tne result of the function can be combined with the {@link #setWindowTimespan(Duration)}. + * By default an {@link IntegrationMessageHeaderAccessor#SEQUENCE_SIZE} header is consulted. + * @param windowSizeFunction the {@link Function} to use to determine a window size + * against a first message in the group. + * @see Flux#window(int) + * @see Flux#windowTimeout(int, Duration) + */ + public void setWindowSizeFunction(Function, Integer> windowSizeFunction) { + Assert.notNull(windowSizeFunction, "'windowSizeFunction' must not be null"); + this.windowSizeFunction = windowSizeFunction; + } + + /** + * Configure a {@link Duration} for closing windows periodically. + * Can be combined with the {@link #setWindowSize(int)} or {@link #setWindowSizeFunction(Function)}. + * @param windowTimespan the {@link Duration} to use for windows to close periodically. + * @see Flux#window(Duration) + * @see Flux#windowTimeout(int, Duration) + */ + public void setWindowTimespan(Duration windowTimespan) { + this.windowTimespan = windowTimespan; + } + + /** + * Configure a {@link Function} to apply a transformation into the grouping {@link Flux} + * for any arbitrary {@link Flux#window} options not covered by the simple options. + * Has a precedence over any other window configuration options. + * @param windowConfigurer the {@link Function} to apply any custom window transformation. + */ + public void setWindowConfigurer(Function>, Flux>>> windowConfigurer) { + this.windowConfigurer = windowConfigurer; + } + + @Override + public void start() { + if (this.subscribed.compareAndSet(false, true)) { + MessageChannel outputChannel = getOutputChannel(); + if (outputChannel instanceof ReactiveStreamsSubscribableChannel) { + ((ReactiveStreamsSubscribableChannel) outputChannel).subscribeTo(this.aggregatorFlux); + } + else { + this.subscription = + this.aggregatorFlux.subscribe((messageToSend) -> produceOutput(messageToSend, messageToSend)); + } + } + } + + @Override + public void stop() { + if (this.subscribed.compareAndSet(true, false) && this.subscription != null) { + this.subscription.dispose(); + } + } + + @Override + public boolean isRunning() { + return this.subscribed.get(); + } + + @Override + protected void handleMessageInternal(Message message) { + Assert.state(isRunning(), + "The 'FluxAggregatorMessageHandler' has not been started to accept incoming messages"); + + this.sink.next(message); + } + + @Override + protected boolean shouldCopyRequestHeaders() { + return false; + } + + private Mono> messageForWindowFlux(Flux> messageFlux) { + Flux> window = messageFlux.publish().autoConnect(); + return window + .next() + .map((first) -> + getMessageBuilderFactory() + .withPayload(Flux.concat(Mono.just(first), window)) + .copyHeaders(first.getHeaders()) + .build()); + } + + private static Integer sequenceSizeHeader(Message message) { + return message.getHeaders().get(IntegrationMessageHeaderAccessor.SEQUENCE_SIZE, Integer.class); + } + +} diff --git a/spring-integration-core/src/main/java/org/springframework/integration/splitter/AbstractMessageSplitter.java b/spring-integration-core/src/main/java/org/springframework/integration/splitter/AbstractMessageSplitter.java index 06b2a6e341d..740fbc219fe 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/splitter/AbstractMessageSplitter.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/splitter/AbstractMessageSplitter.java @@ -37,6 +37,7 @@ import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; import com.fasterxml.jackson.core.TreeNode; import reactor.core.publisher.Flux; @@ -140,7 +141,7 @@ protected final Object handleRequestMessage(Message message) { } } else if (result.getClass().isArray()) { - Object[] items = (Object[]) result; + Object[] items = ObjectUtils.toObjectArray(result); sequenceSize = items.length; if (reactive) { flux = Flux.fromArray(items); diff --git a/spring-integration-core/src/test/java/org/springframework/integration/aggregator/FluxAggregatorMessageHandlerTests.java b/spring-integration-core/src/test/java/org/springframework/integration/aggregator/FluxAggregatorMessageHandlerTests.java new file mode 100644 index 00000000000..1f93f0d4fa0 --- /dev/null +++ b/spring-integration-core/src/test/java/org/springframework/integration/aggregator/FluxAggregatorMessageHandlerTests.java @@ -0,0 +1,295 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.aggregator; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.Executors; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.junit.jupiter.api.Test; + +import org.springframework.integration.IntegrationMessageHeaderAccessor; +import org.springframework.integration.channel.QueueChannel; +import org.springframework.integration.support.MessageBuilder; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.support.GenericMessage; + +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; + +/** + * @author Artem Bilan + * + * @since 5.2 + */ +@SuppressWarnings("unchecked") +class FluxAggregatorMessageHandlerTests { + + @Test + void testDefaultAggregation() { + QueueChannel resultChannel = new QueueChannel(); + FluxAggregatorMessageHandler fluxAggregatorMessageHandler = new FluxAggregatorMessageHandler(); + fluxAggregatorMessageHandler.setOutputChannel(resultChannel); + fluxAggregatorMessageHandler.start(); + + for (int i = 0; i < 20; i++) { + Message messageToAggregate = + MessageBuilder.withPayload("" + i) + .setCorrelationId(i % 2) + .setSequenceSize(10) + .build(); + fluxAggregatorMessageHandler.handleMessage(messageToAggregate); + } + + Message result = resultChannel.receive(10_000); + assertThat(result).isNotNull() + .extracting(Message::getHeaders) + .satisfies((headers) -> + assertThat((MessageHeaders) headers) + .containsEntry(IntegrationMessageHeaderAccessor.CORRELATION_ID, 0)); + + Object payload = result.getPayload(); + assertThat(payload).isInstanceOf(Flux.class); + + Flux> window = (Flux>) payload; + + StepVerifier.create( + window.map(Message::getPayload) + .cast(String.class)) + .expectNextSequence( + IntStream.iterate(0, i -> i + 2) + .limit(10) + .mapToObj(Objects::toString) + .collect(Collectors.toList())) + .verifyComplete(); + + result = resultChannel.receive(10_000); + assertThat(result).isNotNull() + .extracting(Message::getHeaders) + .satisfies((headers) -> + assertThat((MessageHeaders) headers) + .containsEntry(IntegrationMessageHeaderAccessor.CORRELATION_ID, 1)); + + payload = result.getPayload(); + window = (Flux>) payload; + + StepVerifier.create( + window.map(Message::getPayload) + .cast(String.class)) + .expectNextSequence( + IntStream.iterate(1, i -> i + 2) + .limit(10) + .mapToObj(Objects::toString) + .collect(Collectors.toList())) + .verifyComplete(); + + fluxAggregatorMessageHandler.stop(); + } + + @Test + void testCustomCombineFunction() { + QueueChannel resultChannel = new QueueChannel(); + FluxAggregatorMessageHandler fluxAggregatorMessageHandler = new FluxAggregatorMessageHandler(); + fluxAggregatorMessageHandler.setOutputChannel(resultChannel); + fluxAggregatorMessageHandler.setWindowSize(10); + fluxAggregatorMessageHandler.setCombineFunction( + (messageFlux) -> + messageFlux + .map(Message::getPayload) + .collectList() + .map(GenericMessage::new)); + fluxAggregatorMessageHandler.start(); + + for (int i = 0; i < 20; i++) { + Message messageToAggregate = + MessageBuilder.withPayload(i) + .setCorrelationId(i % 2) + .build(); + fluxAggregatorMessageHandler.handleMessage(messageToAggregate); + } + + Message result = resultChannel.receive(10_000); + assertThat(result).isNotNull(); + + Object payload = result.getPayload(); + assertThat(payload) + .isInstanceOf(List.class) + .asList() + .containsExactly( + IntStream.iterate(0, i -> i + 2) + .limit(10) + .boxed() + .toArray()); + + result = resultChannel.receive(10_000); + assertThat(result).isNotNull(); + + payload = result.getPayload(); + assertThat(payload) + .isInstanceOf(List.class) + .asList() + .containsExactly( + IntStream.iterate(1, i -> i + 2) + .limit(10) + .boxed() + .toArray()); + + fluxAggregatorMessageHandler.stop(); + } + + @Test + void testWindowTimespan() { + QueueChannel resultChannel = new QueueChannel(); + FluxAggregatorMessageHandler fluxAggregatorMessageHandler = new FluxAggregatorMessageHandler(); + fluxAggregatorMessageHandler.setOutputChannel(resultChannel); + fluxAggregatorMessageHandler.setWindowTimespan(Duration.ofMillis(100)); + fluxAggregatorMessageHandler.start(); + + Executors.newSingleThreadExecutor() + .submit(() -> { + for (int i = 0; i < 10; i++) { + Message messageToAggregate = + MessageBuilder.withPayload(i) + .setCorrelationId("1") + .build(); + fluxAggregatorMessageHandler.handleMessage(messageToAggregate); + Thread.sleep(20); + } + return null; + }); + + Message result = resultChannel.receive(10_000); + assertThat(result).isNotNull(); + + Flux> window = (Flux>) result.getPayload(); + + List messageList = + window.map(Message::getPayload) + .cast(Integer.class) + .collectList() + .block(Duration.ofSeconds(10)); + + assertThat(messageList) + .isNotEmpty() + .hasSizeLessThan(10) + .contains(0, 1); + + result = resultChannel.receive(10_000); + assertThat(result).isNotNull(); + + window = (Flux>) result.getPayload(); + + messageList = + window.map(Message::getPayload) + .cast(Integer.class) + .collectList() + .block(Duration.ofSeconds(10)); + + assertThat(messageList) + .isNotEmpty() + .hasSizeLessThan(10) + .doesNotContain(0, 1); + + fluxAggregatorMessageHandler.stop(); + } + + @Test + void testBoundaryTrigger() { + QueueChannel resultChannel = new QueueChannel(); + FluxAggregatorMessageHandler fluxAggregatorMessageHandler = new FluxAggregatorMessageHandler(); + fluxAggregatorMessageHandler.setOutputChannel(resultChannel); + fluxAggregatorMessageHandler.setBoundaryTrigger((message) -> "terminate".equals(message.getPayload())); + fluxAggregatorMessageHandler.start(); + + for (int i = 0; i < 3; i++) { + Message messageToAggregate = + MessageBuilder.withPayload("" + i) + .setCorrelationId("1") + .build(); + fluxAggregatorMessageHandler.handleMessage(messageToAggregate); + } + + fluxAggregatorMessageHandler.handleMessage( + MessageBuilder.withPayload("terminate") + .setCorrelationId("1") + .build()); + + fluxAggregatorMessageHandler.handleMessage( + MessageBuilder.withPayload("next") + .setCorrelationId("1") + .build()); + + Message result = resultChannel.receive(10_000); + assertThat(result).isNotNull(); + + Flux> window = (Flux>) result.getPayload(); + + StepVerifier.create( + window.map(Message::getPayload) + .cast(String.class)) + .expectNext("0", "1", "2") + .expectNext("terminate") + .verifyComplete(); + + fluxAggregatorMessageHandler.stop(); + } + + + @Test + void testCustomWindow() { + QueueChannel resultChannel = new QueueChannel(); + FluxAggregatorMessageHandler fluxAggregatorMessageHandler = new FluxAggregatorMessageHandler(); + fluxAggregatorMessageHandler.setOutputChannel(resultChannel); + fluxAggregatorMessageHandler.setWindowConfigurer((group) -> + group.windowWhile((message) -> + message.getPayload() instanceof Integer)); + fluxAggregatorMessageHandler.start(); + + for (int i = 0; i < 3; i++) { + Message messageToAggregate = + MessageBuilder.withPayload(i) + .setCorrelationId("1") + .build(); + fluxAggregatorMessageHandler.handleMessage(messageToAggregate); + } + + fluxAggregatorMessageHandler.handleMessage( + MessageBuilder.withPayload("terminate") + .setCorrelationId("1") + .build()); + + Message result = resultChannel.receive(10_000); + assertThat(result).isNotNull(); + + Flux> window = (Flux>) result.getPayload(); + + StepVerifier.create( + window.map(Message::getPayload) + .cast(Integer.class)) + .expectNext(0, 1, 2) + .verifyComplete(); + + fluxAggregatorMessageHandler.stop(); + } + +} diff --git a/spring-integration-core/src/test/java/org/springframework/integration/dsl/correlation/CorrelationHandlerTests.java b/spring-integration-core/src/test/java/org/springframework/integration/dsl/correlation/CorrelationHandlerTests.java index 327be1b9239..9d2b20b29c6 100644 --- a/spring-integration-core/src/test/java/org/springframework/integration/dsl/correlation/CorrelationHandlerTests.java +++ b/spring-integration-core/src/test/java/org/springframework/integration/dsl/correlation/CorrelationHandlerTests.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.junit.Test; import org.junit.runner.RunWith; @@ -32,6 +33,7 @@ import org.springframework.context.annotation.Configuration; import org.springframework.core.task.TaskExecutor; import org.springframework.integration.IntegrationMessageHeaderAccessor; +import org.springframework.integration.aggregator.FluxAggregatorMessageHandler; import org.springframework.integration.aggregator.HeaderAttributeCorrelationStrategy; import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.config.EnableIntegration; @@ -40,6 +42,7 @@ import org.springframework.integration.dsl.MessageChannelSpec; import org.springframework.integration.dsl.MessageChannels; import org.springframework.integration.dsl.Transformers; +import org.springframework.integration.dsl.context.IntegrationFlowContext; import org.springframework.integration.handler.MessageTriggerAction; import org.springframework.integration.json.ObjectToJsonTransformer; import org.springframework.integration.support.MessageBuilder; @@ -52,6 +55,8 @@ import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.TextNode; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; /** * @author Artem Bilan @@ -167,6 +172,36 @@ public void testSplitterDiscard() { .hasSize(0); } + @Autowired + private IntegrationFlowContext integrationFlowContext; + + @Test + public void testFluxAggregator() { + IntegrationFlow testFlow = (flow) -> + flow.split() + .channel(MessageChannels.flux()) + .handle(new FluxAggregatorMessageHandler()); + + IntegrationFlowContext.IntegrationFlowRegistration registration = + this.integrationFlowContext.registration(testFlow) + .register(); + + @SuppressWarnings("unchecked") + Flux> window = + registration.getMessagingTemplate() + .convertSendAndReceive(new Integer[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }, Flux.class); + + assertThat(window).isNotNull(); + + StepVerifier.create( + window.map(Message::getPayload) + .cast(Integer.class)) + .expectNextSequence(IntStream.range(0, 10).boxed().collect(Collectors.toList())) + .verifyComplete(); + + registration.destroy(); + } + @Configuration @EnableIntegration public static class ContextConfiguration {