diff --git a/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/AbstractRSocketConnector.java b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/AbstractRSocketConnector.java new file mode 100644 index 00000000000..7caadf04eef --- /dev/null +++ b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/AbstractRSocketConnector.java @@ -0,0 +1,141 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.rsocket; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.SmartInitializingSingleton; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.context.SmartLifecycle; +import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.codec.StringDecoder; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.messaging.rsocket.RSocketStrategies; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; + +/** + * A base connector container for common RSocket client and server functionality. + *

+ * It accepts {@link IntegrationRSocketEndpoint} instances for mapping registration via an internal + * {@link IntegrationRSocketAcceptor} or performs an auto-detection otherwise, when all bean are ready + * in the application context. + * + * @author Artem Bilan + * + * @since 5.2 + * + * @see IntegrationRSocketAcceptor + */ +public abstract class AbstractRSocketConnector + implements ApplicationContextAware, InitializingBean, DisposableBean, SmartInitializingSingleton, + SmartLifecycle { + + protected final IntegrationRSocketAcceptor rsocketAcceptor; // NOSONAR - final + + private MimeType dataMimeType = MimeTypeUtils.TEXT_PLAIN; + + private RSocketStrategies rsocketStrategies = + RSocketStrategies.builder() + .decoder(StringDecoder.allMimeTypes()) + .encoder(CharSequenceEncoder.allMimeTypes()) + .dataBufferFactory(new DefaultDataBufferFactory()) + .build(); + + private volatile boolean running; + + private ApplicationContext applicationContext; + + protected AbstractRSocketConnector(IntegrationRSocketAcceptor rsocketAcceptor) { + this.rsocketAcceptor = rsocketAcceptor; + } + + public void setDataMimeType(MimeType dataMimeType) { + Assert.notNull(dataMimeType, "'dataMimeType' must not be null"); + this.dataMimeType = dataMimeType; + } + + protected MimeType getDataMimeType() { + return this.dataMimeType; + } + + public void setRSocketStrategies(RSocketStrategies rsocketStrategies) { + Assert.notNull(rsocketStrategies, "'rsocketStrategies' must not be null"); + this.rsocketStrategies = rsocketStrategies; + } + + public RSocketStrategies getRSocketStrategies() { + return this.rsocketStrategies; + } + + public void setEndpoints(IntegrationRSocketEndpoint... endpoints) { + Assert.notNull(endpoints, "'endpoints' must not be null"); + for (IntegrationRSocketEndpoint endpoint : endpoints) { + addEndpoint(endpoint); + } + } + + public void addEndpoint(IntegrationRSocketEndpoint endpoint) { + this.rsocketAcceptor.addEndpoint(endpoint); + } + + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.applicationContext = applicationContext; + this.rsocketAcceptor.setApplicationContext(applicationContext); + } + + protected ApplicationContext getApplicationContext() { + return this.applicationContext; + } + + @Override + public void afterPropertiesSet() { + this.rsocketAcceptor.setDefaultDataMimeType(this.dataMimeType); + this.rsocketAcceptor.setRSocketStrategies(this.rsocketStrategies); + this.rsocketAcceptor.afterPropertiesSet(); + } + + @Override + public void afterSingletonsInstantiated() { + this.rsocketAcceptor.detectEndpoints(); + } + + @Override + public void start() { + if (!this.running) { + this.running = true; + doStart(); + } + } + + protected abstract void doStart(); + + @Override + public void stop() { + this.running = false; + } + + @Override + public boolean isRunning() { + return this.running; + } + +} diff --git a/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/ClientRSocketConnector.java b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/ClientRSocketConnector.java index 1189d877b5f..a30af81d5f4 100644 --- a/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/ClientRSocketConnector.java +++ b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/ClientRSocketConnector.java @@ -18,14 +18,10 @@ import java.net.URI; import java.util.function.Consumer; +import java.util.function.Function; -import org.springframework.beans.factory.DisposableBean; -import org.springframework.beans.factory.InitializingBean; import org.springframework.messaging.rsocket.RSocketRequester; -import org.springframework.messaging.rsocket.RSocketStrategies; import org.springframework.util.Assert; -import org.springframework.util.MimeType; -import org.springframework.util.MimeTypeUtils; import io.rsocket.Payload; import io.rsocket.RSocket; @@ -39,7 +35,11 @@ import reactor.core.publisher.Mono; /** - * A client connector to the RSocket server. + * A client {@link AbstractRSocketConnector} extension to the RSocket server. + *

+ * Note: the {@link RSocketFactory.ClientRSocketFactory#acceptor(Function)} + * in the provided {@link #factoryConfigurer} is overridden with an internal {@link IntegrationRSocketAcceptor} + * for the proper Spring Integration channel adapter mappings. * * @author Artem Bilan * @@ -48,17 +48,17 @@ * @see RSocketFactory.ClientRSocketFactory * @see RSocketRequester */ -public class ClientRSocketConnector implements InitializingBean, DisposableBean { +public class ClientRSocketConnector extends AbstractRSocketConnector { private final ClientTransport clientTransport; - private MimeType dataMimeType = MimeTypeUtils.TEXT_PLAIN; + private Consumer factoryConfigurer = (clientRSocketFactory) -> { }; - private Payload connectPayload = EmptyPayload.INSTANCE; + private String connectRoute; - private RSocketStrategies rsocketStrategies = RSocketStrategies.builder().build(); + private String connectData = ""; - private Consumer factoryConfigurer = (clientRSocketFactory) -> { }; + private boolean autoConnect; private Mono rsocketMono; @@ -71,47 +71,51 @@ public ClientRSocketConnector(URI uri) { } public ClientRSocketConnector(ClientTransport clientTransport) { + super(new IntegrationRSocketAcceptor()); Assert.notNull(clientTransport, "'clientTransport' must not be null"); this.clientTransport = clientTransport; } - public void setDataMimeType(MimeType dataMimeType) { - Assert.notNull(dataMimeType, "'dataMimeType' must not be null"); - this.dataMimeType = dataMimeType; - } - public void setFactoryConfigurer(Consumer factoryConfigurer) { Assert.notNull(factoryConfigurer, "'factoryConfigurer' must not be null"); this.factoryConfigurer = factoryConfigurer; } - public void setRSocketStrategies(RSocketStrategies rsocketStrategies) { - Assert.notNull(rsocketStrategies, "'rsocketStrategies' must not be null"); - this.rsocketStrategies = rsocketStrategies; + public void setConnectRoute(String connectRoute) { + this.connectRoute = connectRoute; } - public void setConnectRoute(String connectRoute) { - this.connectPayload = DefaultPayload.create("", connectRoute); + public void setConnectData(String connectData) { + Assert.notNull(connectData, "'connectData' must not be null"); + this.connectData = connectData; } @Override public void afterPropertiesSet() { + super.afterPropertiesSet(); RSocketFactory.ClientRSocketFactory clientFactory = RSocketFactory.connect() - .dataMimeType(this.dataMimeType.toString()); + .dataMimeType(getDataMimeType().toString()); this.factoryConfigurer.accept(clientFactory); - clientFactory.setupPayload(this.connectPayload); + clientFactory.acceptor(this.rsocketAcceptor); + Payload connectPayload = EmptyPayload.INSTANCE; + if (this.connectRoute != null) { + connectPayload = DefaultPayload.create(this.connectData, this.connectRoute); + } + clientFactory.setupPayload(connectPayload); this.rsocketMono = clientFactory.transport(this.clientTransport).start().cache(); } - public void connect() { - this.rsocketMono.subscribe(); + @Override + public void afterSingletonsInstantiated() { + this.autoConnect = this.rsocketAcceptor.detectEndpoints(); } - public Mono getRSocketRequester() { - return this.rsocketMono - .map(rsocket -> RSocketRequester.wrap(rsocket, this.dataMimeType, this.rsocketStrategies)) - .cache(); + @Override + protected void doStart() { + if (this.autoConnect) { + connect(); + } } @Override @@ -121,4 +125,17 @@ public void destroy() { .subscribe(); } + /** + * Perform subscription into the RSocket server for incoming requests. + */ + public void connect() { + this.rsocketMono.subscribe(); + } + + public Mono getRSocketRequester() { + return this.rsocketMono + .map(rsocket -> RSocketRequester.wrap(rsocket, getDataMimeType(), getRSocketStrategies())) + .cache(); + } + } diff --git a/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/IntegrationRSocket.java b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/IntegrationRSocket.java new file mode 100644 index 00000000000..bf218d5a630 --- /dev/null +++ b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/IntegrationRSocket.java @@ -0,0 +1,207 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.rsocket; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import org.reactivestreams.Publisher; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.NettyDataBuffer; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.lang.Nullable; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.handler.DestinationPatternsMessageCondition; +import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler; +import org.springframework.messaging.rsocket.RSocketPayloadReturnValueHandler; +import org.springframework.messaging.rsocket.RSocketRequester; +import org.springframework.messaging.rsocket.RSocketRequesterMethodArgumentResolver; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.messaging.support.MessageHeaderAccessor; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; + +import io.netty.buffer.ByteBuf; +import io.rsocket.AbstractRSocket; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; + +/** + * Implementation of {@link RSocket} that wraps incoming requests with a + * {@link Message}, delegates to a {@link Function} for handling, and then + * obtains the response from a "reply" header. + *

+ * Essentially, this is an adapted for Spring Integration copy + * of the {@link org.springframework.messaging.rsocket.MessagingRSocket} because + * that one is not public. + * + * @author Artem Bilan + * + * @since 5.2 + * + * @see org.springframework.messaging.rsocket.MessagingRSocket + */ +class IntegrationRSocket extends AbstractRSocket { + + private final Function, Mono> handler; + + private final RSocketRequester requester; + + private final DataBufferFactory bufferFactory; + + @Nullable + private MimeType dataMimeType; + + IntegrationRSocket(Function, Mono> handler, RSocketRequester requester, + @Nullable MimeType defaultDataMimeType, DataBufferFactory bufferFactory) { + + Assert.notNull(handler, "'handler' is required"); + Assert.notNull(requester, "'requester' is required"); + this.handler = handler; + this.requester = requester; + this.dataMimeType = defaultDataMimeType; + this.bufferFactory = bufferFactory; + } + + public void setDataMimeType(MimeType dataMimeType) { + this.dataMimeType = dataMimeType; + } + + public RSocketRequester getRequester() { + return this.requester; + } + + @Override + public Mono fireAndForget(Payload payload) { + return handle(payload); + } + + @Override + public Mono requestResponse(Payload payload) { + return handleAndReply(payload, Flux.just(payload)).next(); + } + + @Override + public Flux requestStream(Payload payload) { + return handleAndReply(payload, Flux.just(payload)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .switchOnFirst((signal, innerFlux) -> { + Payload firstPayload = signal.get(); + return firstPayload == null ? innerFlux : handleAndReply(firstPayload, innerFlux); + }); + } + + @Override + public Mono metadataPush(Payload payload) { + // Not very useful until createHeaders does more with metadata + return handle(payload); + } + + + private Mono handle(Payload payload) { + String destination = getDestination(payload); + MessageHeaders headers = createHeaders(destination, null); + DataBuffer dataBuffer = retainDataAndReleasePayload(payload); + int refCount = refCount(dataBuffer); + Message message = MessageBuilder.createMessage(dataBuffer, headers); + return Mono.defer(() -> this.handler.apply(message)) + .doFinally(s -> { + if (refCount(dataBuffer) == refCount) { + DataBufferUtils.release(dataBuffer); + } + }); + } + + static int refCount(DataBuffer dataBuffer) { + return dataBuffer instanceof NettyDataBuffer ? + ((NettyDataBuffer) dataBuffer).getNativeBuffer().refCnt() : 1; + } + + private Flux handleAndReply(Payload firstPayload, Flux payloads) { + MonoProcessor> replyMono = MonoProcessor.create(); + String destination = getDestination(firstPayload); + MessageHeaders headers = createHeaders(destination, replyMono); + + AtomicBoolean read = new AtomicBoolean(); + Flux buffers = payloads.map(this::retainDataAndReleasePayload).doOnSubscribe(s -> read.set(true)); + Message> message = MessageBuilder.createMessage(buffers, headers); + + return Mono.defer(() -> this.handler.apply(message)) + .doFinally(s -> { + // Subscription should have happened by now due to ChannelSendOperator + if (!read.get()) { + buffers.subscribe(DataBufferUtils::release); + } + }) + .thenMany(Flux.defer(() -> replyMono.isTerminated() ? + replyMono.flatMapMany(Function.identity()) : + Mono.error(new IllegalStateException("Something went wrong: reply Mono not set")))); + } + + private DataBuffer retainDataAndReleasePayload(Payload payload) { + return payloadToDataBuffer(payload, this.bufferFactory); + } + + private MessageHeaders createHeaders(String destination, @Nullable MonoProcessor replyMono) { + MessageHeaderAccessor headers = new MessageHeaderAccessor(); + headers.setLeaveMutable(true); + headers.setHeader(DestinationPatternsMessageCondition.LOOKUP_DESTINATION_HEADER, destination); + if (this.dataMimeType != null) { + headers.setContentType(this.dataMimeType); + } + headers.setHeader(RSocketRequesterMethodArgumentResolver.RSOCKET_REQUESTER_HEADER, this.requester); + if (replyMono != null) { + headers.setHeader(RSocketPayloadReturnValueHandler.RESPONSE_HEADER, replyMono); + } + headers.setHeader(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, this.bufferFactory); + return headers.getMessageHeaders(); + } + + static String getDestination(Payload payload) { + return payload.getMetadataUtf8(); + } + + static DataBuffer payloadToDataBuffer(Payload payload, DataBufferFactory bufferFactory) { + payload.retain(); + try { + if (bufferFactory instanceof NettyDataBufferFactory) { + ByteBuf byteBuf = payload.sliceData().retain(); + return ((NettyDataBufferFactory) bufferFactory).wrap(byteBuf); + } + else { + return bufferFactory.wrap(payload.getData()); + } + } + finally { + if (payload.refCnt() > 0) { + payload.release(); + } + } + } + +} diff --git a/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/IntegrationRSocketAcceptor.java b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/IntegrationRSocketAcceptor.java new file mode 100644 index 00000000000..92e0e8832ba --- /dev/null +++ b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/IntegrationRSocketAcceptor.java @@ -0,0 +1,135 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.rsocket; + +import java.lang.reflect.Method; +import java.util.Collections; +import java.util.List; +import java.util.function.Function; +import java.util.function.Predicate; + +import org.springframework.context.ApplicationContext; +import org.springframework.core.MethodParameter; +import org.springframework.lang.Nullable; +import org.springframework.messaging.Message; +import org.springframework.messaging.ReactiveMessageHandler; +import org.springframework.messaging.handler.CompositeMessageCondition; +import org.springframework.messaging.handler.DestinationPatternsMessageCondition; +import org.springframework.messaging.handler.invocation.reactive.HandlerMethodArgumentResolver; +import org.springframework.messaging.handler.invocation.reactive.SyncHandlerMethodArgumentResolver; +import org.springframework.messaging.rsocket.RSocketMessageHandler; +import org.springframework.messaging.rsocket.RSocketRequester; +import org.springframework.messaging.rsocket.RSocketStrategies; +import org.springframework.util.MimeType; +import org.springframework.util.ReflectionUtils; + +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.RSocket; + +/** + * The {@link RSocketMessageHandler} extension for Spring Integration needs. + *

+ * The most of logic is copied from {@link org.springframework.messaging.rsocket.MessageHandlerAcceptor}. + * That cannot be extended because it is {@link final}. + *

+ * This class adds an {@link IntegrationRSocketEndpoint} beans detection and registration functionality, + * as well as serves as a container over an internal {@link IntegrationRSocket} implementation. + * + * @author Artem Bilan + * + * @since 5.2 + * + * @see org.springframework.messaging.rsocket.MessageHandlerAcceptor + */ +class IntegrationRSocketAcceptor extends RSocketMessageHandler implements Function { + + private static final Method HANDLE_MESSAGE_METHOD = + ReflectionUtils.findMethod(ReactiveMessageHandler.class, "handleMessage", Message.class); + + @Nullable + private MimeType defaultDataMimeType; + + /** + * Configure the default content type to use for data payloads. + *

By default this is not set. However a server acceptor will use the + * content type from the {@link ConnectionSetupPayload}, so this is typically + * required for clients but can also be used on servers as a fallback. + * @param defaultDataMimeType the MimeType to use + */ + public void setDefaultDataMimeType(@Nullable MimeType defaultDataMimeType) { + this.defaultDataMimeType = defaultDataMimeType; + } + + public boolean detectEndpoints() { + ApplicationContext applicationContext = getApplicationContext(); + if (applicationContext != null && getHandlerMethods().isEmpty()) { + return applicationContext + .getBeansOfType(IntegrationRSocketEndpoint.class) + .values() + .stream() + .peek(this::addEndpoint) + .count() > 0; + } + else { + return false; + } + } + + public void addEndpoint(IntegrationRSocketEndpoint endpoint) { + registerHandlerMethod(endpoint, HANDLE_MESSAGE_METHOD, + new CompositeMessageCondition( + new DestinationPatternsMessageCondition(endpoint.getPath(), getPathMatcher()))); + } + + @Override + protected List initArgumentResolvers() { + return Collections.singletonList(new MessageHandlerMethodArgumentResolver()); + } + + @Override + protected Predicate> initHandlerPredicate() { + return (clazz) -> false; + } + + @Override + public RSocket apply(RSocket sendingRSocket) { + return createRSocket(sendingRSocket); + } + + protected IntegrationRSocket createRSocket(RSocket rsocket) { + RSocketStrategies rsocketStrategies = getRSocketStrategies(); + return new IntegrationRSocket(this::handleMessage, + RSocketRequester.wrap(rsocket, this.defaultDataMimeType, rsocketStrategies), + this.defaultDataMimeType, + rsocketStrategies.dataBufferFactory()); + } + + private static final class MessageHandlerMethodArgumentResolver implements SyncHandlerMethodArgumentResolver { + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return true; + } + + @Override + public Object resolveArgumentValue(MethodParameter parameter, Message message) { + return message; + } + + } + +} diff --git a/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/IntegrationRSocketEndpoint.java b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/IntegrationRSocketEndpoint.java new file mode 100644 index 00000000000..4cf164e6f5d --- /dev/null +++ b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/IntegrationRSocketEndpoint.java @@ -0,0 +1,38 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.rsocket; + +import org.springframework.messaging.ReactiveMessageHandler; + +/** + * A marker {@link ReactiveMessageHandler} extension interface for Spring Integration + * inbound endpoints. + * It is used as mapping predicate in the internal RSocket acceptor of the + * {@link AbstractRSocketConnector}. + * + * @author Artem Bilan + * + * @since 5.2 + * + * @see AbstractRSocketConnector + * @see org.springframework.integration.rsocket.inbound.RSocketInboundGateway + */ +public interface IntegrationRSocketEndpoint extends ReactiveMessageHandler { + + String[] getPath(); + +} diff --git a/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/RSocketConnectedEvent.java b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/RSocketConnectedEvent.java new file mode 100644 index 00000000000..e65323af0be --- /dev/null +++ b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/RSocketConnectedEvent.java @@ -0,0 +1,72 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.rsocket; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.integration.events.IntegrationEvent; +import org.springframework.messaging.rsocket.RSocketRequester; + +/** + * An {@link IntegrationEvent} to indicate that {@code RSocket} from the client is connected + * to the server. + *

+ * This event can be used for mapping {@link RSocketRequester} to the client by the + * {@code destination} meta-data or connect payload {@code data}. + * + * @author Artem Bilan + * + * @since 5.2 + * + * @see IntegrationRSocketAcceptor + */ +@SuppressWarnings("serial") +public class RSocketConnectedEvent extends IntegrationEvent { + + private final String destination; + + private final DataBuffer data; + + private final RSocketRequester requester; + + public RSocketConnectedEvent(Object source, String destination, DataBuffer data, RSocketRequester requester) { + super(source); + this.destination = destination; + this.data = data; + this.requester = requester; + } + + public String getDestination() { + return this.destination; + } + + public DataBuffer getData() { + return this.data; + } + + public RSocketRequester getRequester() { + return this.requester; + } + + @Override + public String toString() { + return "RSocketConnectedEvent{" + + "destination='" + this.destination + '\'' + + ", requester=" + this.requester + + '}'; + } + +} diff --git a/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/ServerRSocketConnector.java b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/ServerRSocketConnector.java new file mode 100644 index 00000000000..4788e670c65 --- /dev/null +++ b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/ServerRSocketConnector.java @@ -0,0 +1,184 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.rsocket; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEventPublisherAware; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.lang.Nullable; +import org.springframework.messaging.rsocket.RSocketRequester; +import org.springframework.util.Assert; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.StringUtils; + +import io.rsocket.Closeable; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.RSocket; +import io.rsocket.RSocketFactory; +import io.rsocket.SocketAcceptor; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.netty.http.server.HttpServer; + +/** + * A server {@link AbstractRSocketConnector} extension to accept and manage client RSocket connections. + *

+ * Note: the {@link RSocketFactory.ServerRSocketFactory#acceptor(SocketAcceptor)} + * in the provided {@link #factoryConfigurer} is overridden with an internal {@link IntegrationRSocketAcceptor} + * for the proper Spring Integration channel adapter mappings. + * + * @author Artem Bilan + * + * @since 5.2 + * + * @see RSocketFactory.ServerRSocketFactory + */ +public class ServerRSocketConnector extends AbstractRSocketConnector + implements ApplicationEventPublisherAware { + + private final ServerTransport serverTransport; + + private Consumer factoryConfigurer = (serverRSocketFactory) -> { }; + + private Mono serverMono; + + public ServerRSocketConnector(String bindAddress, int port) { + this(TcpServerTransport.create(bindAddress, port)); + } + + public ServerRSocketConnector(HttpServer server) { + this(WebsocketServerTransport.create(server)); + } + + public ServerRSocketConnector(ServerTransport serverTransport) { + super(new ServerRSocketAcceptor()); + Assert.notNull(serverTransport, "'serverTransport' must not be null"); + this.serverTransport = serverTransport; + } + + public void setFactoryConfigurer(Consumer factoryConfigurer) { + Assert.notNull(factoryConfigurer, "'factoryConfigurer' must not be null"); + this.factoryConfigurer = factoryConfigurer; + } + + public void setClientRSocketKeyStrategy(BiFunction clientRSocketKeyStrategy) { + Assert.notNull(clientRSocketKeyStrategy, "'clientRSocketKeyStrategy' must not be null"); + serverRSocketAcceptor().clientRSocketKeyStrategy = clientRSocketKeyStrategy; + } + + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { + serverRSocketAcceptor().applicationEventPublisher = applicationEventPublisher; + } + + @Override + public void afterPropertiesSet() { + super.afterPropertiesSet(); + RSocketFactory.ServerRSocketFactory serverFactory = RSocketFactory.receive(); + this.factoryConfigurer.accept(serverFactory); + this.serverMono = + serverFactory + .acceptor(serverRSocketAcceptor()) + .transport(this.serverTransport) + .start() + .cache(); + } + + public Map getClientRSocketRequesters() { + return Collections.unmodifiableMap(serverRSocketAcceptor().clientRSocketRequesters); + } + + @Nullable + public RSocketRequester getClientRSocketRequester(Object key) { + return serverRSocketAcceptor().clientRSocketRequesters.get(key); + } + + private ServerRSocketAcceptor serverRSocketAcceptor() { + return (ServerRSocketAcceptor) this.rsocketAcceptor; + } + + @Override + protected void doStart() { + this.serverMono.subscribe(); + } + + @Override + public void destroy() { + this.serverMono + .doOnNext(Disposable::dispose) + .subscribe(); + } + + private static class ServerRSocketAcceptor extends IntegrationRSocketAcceptor implements SocketAcceptor { + + private static final Log LOGGER = LogFactory.getLog(IntegrationRSocket.class); + + private final Map clientRSocketRequesters = new HashMap<>(); + + private BiFunction clientRSocketKeyStrategy = (destination, data) -> destination; + + private ApplicationEventPublisher applicationEventPublisher; + + @Override + public Mono accept(ConnectionSetupPayload setupPayload, RSocket sendingRSocket) { + String destination = IntegrationRSocket.getDestination(setupPayload); + DataBuffer dataBuffer = + IntegrationRSocket.payloadToDataBuffer(setupPayload, getRSocketStrategies().dataBufferFactory()); + int refCount = IntegrationRSocket.refCount(dataBuffer); + return Mono.just(sendingRSocket) + .map(this::createRSocket) + .doOnNext((rsocket) -> { + if (StringUtils.hasText(setupPayload.dataMimeType())) { + rsocket.setDataMimeType(MimeTypeUtils.parseMimeType(setupPayload.dataMimeType())); + } + Object rsocketRequesterKey = this.clientRSocketKeyStrategy.apply(destination, dataBuffer); + this.clientRSocketRequesters.put(rsocketRequesterKey, rsocket.getRequester()); + RSocketConnectedEvent rSocketConnectedEvent = + new RSocketConnectedEvent(rsocket, destination, dataBuffer, rsocket.getRequester()); + if (this.applicationEventPublisher != null) { + this.applicationEventPublisher.publishEvent(rSocketConnectedEvent); + } + else { + if (LOGGER.isInfoEnabled()) { + LOGGER.info("The RSocket has been connected: " + rSocketConnectedEvent); + } + } + }) + .cast(RSocket.class) + .doFinally((signal) -> { + if (IntegrationRSocket.refCount(dataBuffer) == refCount) { + DataBufferUtils.release(dataBuffer); + } + }); + } + + } + +} diff --git a/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/inbound/RSocketInboundGateway.java b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/inbound/RSocketInboundGateway.java new file mode 100644 index 00000000000..27be78dca57 --- /dev/null +++ b/spring-integration-rsocket/src/main/java/org/springframework/integration/rsocket/inbound/RSocketInboundGateway.java @@ -0,0 +1,282 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.rsocket.inbound; + +import java.util.Arrays; + +import org.reactivestreams.Publisher; + +import org.springframework.core.ReactiveAdapter; +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.codec.Decoder; +import org.springframework.core.codec.Encoder; +import org.springframework.core.codec.StringDecoder; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DefaultDataBuffer; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.core.io.buffer.NettyDataBuffer; +import org.springframework.integration.gateway.MessagingGatewaySupport; +import org.springframework.integration.rsocket.AbstractRSocketConnector; +import org.springframework.integration.rsocket.ClientRSocketConnector; +import org.springframework.integration.rsocket.IntegrationRSocketEndpoint; +import org.springframework.integration.support.MessageBuilder; +import org.springframework.lang.Nullable; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageDeliveryException; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.handler.invocation.reactive.HandlerMethodReturnValueHandler; +import org.springframework.messaging.rsocket.RSocketPayloadReturnValueHandler; +import org.springframework.messaging.rsocket.RSocketStrategies; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; + +import io.rsocket.Payload; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; + +/** + * The {@link MessagingGatewaySupport} implementation for the {@link IntegrationRSocketEndpoint}. + *

+ * May be configured with the {@link AbstractRSocketConnector} for mapping registration. + * Or existing {@link AbstractRSocketConnector} bean(s) will perform detection automatically. + *

+ * An inbound {@link DataBuffer} (either single or as a {@link Publisher} element) is + * converted to the target expected type which can be configured by the + * {@link #setRequestElementClass} or {@link #setRequestElementType(ResolvableType)}. + * If it is not configured, then target type is determined by the {@code contentType} header: + * If it is a {@code text}, then target type is {@link String}, otherwise - {@code byte[]}. + *

+ * An inbound {@link Publisher} is used as is in the message to send payload. + * It is a target application responsibility to process that payload any possible way. + *

+ * A reply payload is encoded to the {@link Flux} according a type of the payload or a + * {@link Publisher} element type. + * + * @author Artem Bilan + * + * @since 5.2 + */ +public class RSocketInboundGateway extends MessagingGatewaySupport implements IntegrationRSocketEndpoint { + + private final String[] path; + + private RSocketStrategies rsocketStrategies = + RSocketStrategies.builder() + .decoder(StringDecoder.allMimeTypes()) + .encoder(CharSequenceEncoder.allMimeTypes()) + .dataBufferFactory(new DefaultDataBufferFactory()) + .build(); + + @Nullable + private AbstractRSocketConnector rsocketConnector; + + @Nullable + private ResolvableType requestElementType; + + public RSocketInboundGateway(String... path) { + Assert.notNull(path, "'path' must not be null"); + this.path = path; + } + + /** + * Configure {@link RSocketStrategies} instead of a default one. + * Note: if {@link AbstractRSocketConnector} ias provided, then its + * {@link RSocketStrategies} have a precedence. + * @param rsocketStrategies the {@link RSocketStrategies} to use. + * @see RSocketStrategies#builder + */ + public void setRSocketStrategies(RSocketStrategies rsocketStrategies) { + Assert.notNull(rsocketStrategies, "'rsocketStrategies' must not be null"); + this.rsocketStrategies = rsocketStrategies; + } + + /** + * Provide an {@link AbstractRSocketConnector} reference for an explicit endpoint mapping. + * @param rsocketConnector the {@link AbstractRSocketConnector} to use. + */ + public void setRSocketConnector(AbstractRSocketConnector rsocketConnector) { + Assert.notNull(rsocketConnector, "'rsocketConnector' must not be null"); + this.rsocketConnector = rsocketConnector; + } + + /** + * Get an array of the path patterns this endpoint is mapped onto. + * @return the mapping path + */ + public String[] getPath() { + return this.path; + } + + /** + * Specify the type of payload to be generated when the inbound RSocket request + * content is read by the encoders. + * By default this value is null which means at runtime any "text" Content-Type will + * result in String while all others default to byte[].class. + * @param requestElementClass The payload type. + */ + public void setRequestElementClass(Class requestElementClass) { + setRequestElementType(ResolvableType.forClass(requestElementClass)); + } + + /** + * Specify the type of payload to be generated when the inbound RSocket request + * content is read by the converters/encoders. + * By default this value is null which means at runtime any "text" Content-Type will + * result in String while all others default to byte[].class. + * @param requestElementType The payload type. + */ + public void setRequestElementType(ResolvableType requestElementType) { + this.requestElementType = requestElementType; + } + + @Override + protected void onInit() { + super.onInit(); + if (this.rsocketConnector != null) { + this.rsocketConnector.addEndpoint(this); + this.rsocketStrategies = this.rsocketConnector.getRSocketStrategies(); + } + } + + @Override + protected void doStart() { + super.doStart(); + if (this.rsocketConnector instanceof ClientRSocketConnector) { + ((ClientRSocketConnector) this.rsocketConnector).connect(); + } + } + + @Override + public Mono handleMessage(Message requestMessage) { + if (!isRunning()) { + return Mono.error(new MessageDeliveryException(requestMessage, + "The RSocket Inbound Gateway '" + getComponentName() + "' is stopped; " + + "service for path(s) " + Arrays.toString(this.path) + " is not available at the moment.")); + } + + Mono> requestMono = decodeRequestMessage(requestMessage); + MonoProcessor> replyMono = getReplyMono(requestMessage); + if (replyMono != null) { + return requestMono + .flatMap(this::sendAndReceiveMessageReactive) + .doOnNext(replyMessage -> { + replyMono.onNext(createReply(replyMessage.getPayload(), requestMessage)); + replyMono.onComplete(); + }) + .then(); + } + else { + return requestMono + .doOnNext(this::send) + .then(); + } + } + + private Mono> decodeRequestMessage(Message requestMessage) { + return Mono.just(decodePayload(requestMessage)) + .map((payload) -> + MessageBuilder.withPayload(payload) + .copyHeaders(requestMessage.getHeaders()) + .build()); + } + + @SuppressWarnings("unchecked") + private Object decodePayload(Message requestMessage) { + ResolvableType elementType = this.requestElementType; + MimeType mimeType = requestMessage.getHeaders().get(MessageHeaders.CONTENT_TYPE, MimeType.class); + if (elementType == null) { + elementType = + mimeType != null && "text".equals(mimeType.getType()) + ? ResolvableType.forClass(String.class) + : ResolvableType.forClass(byte[].class); + } + + Object payload = requestMessage.getPayload(); + + // The IntegrationRSocket logic ensures that we can have only a single DataBuffer payload or Flux. + Decoder decoder = this.rsocketStrategies.decoder(elementType, mimeType); + if (payload instanceof DataBuffer) { + return decoder.decode((DataBuffer) payload, elementType, mimeType, null); + } + else { + return decoder.decode((Publisher) payload, elementType, mimeType, null); + } + } + + private Flux createReply(Object reply, Message requestMessage) { + MessageHeaders requestMessageHeaders = requestMessage.getHeaders(); + DataBufferFactory bufferFactory = + requestMessageHeaders.get(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, + DataBufferFactory.class); + + MimeType mimeType = requestMessageHeaders.get(MessageHeaders.CONTENT_TYPE, MimeType.class); + + return encodeContent(reply, ResolvableType.forInstance(reply), bufferFactory, mimeType) + .map(RSocketInboundGateway::createPayload); + } + + private Flux encodeContent(Object content, ResolvableType returnValueType, + DataBufferFactory bufferFactory, @Nullable MimeType mimeType) { + + ReactiveAdapter adapter = + this.rsocketStrategies.reactiveAdapterRegistry() + .getAdapter(returnValueType.resolve(), content); + + Publisher publisher; + if (adapter != null) { + publisher = adapter.toPublisher(content); + } + else { + publisher = Flux.just(content); + } + + return Flux.from((Publisher) publisher) + .map((value) -> encodeValue(value, bufferFactory, mimeType)); + } + + private DataBuffer encodeValue(Object element, DataBufferFactory bufferFactory, @Nullable MimeType mimeType) { + ResolvableType elementType = ResolvableType.forInstance(element); + Encoder encoder = this.rsocketStrategies.encoder(elementType, mimeType); + return encoder.encodeValue(element, bufferFactory, elementType, mimeType, null); + } + + @Nullable + @SuppressWarnings("unchecked") + private static MonoProcessor> getReplyMono(Message message) { + Object headerValue = message.getHeaders().get(RSocketPayloadReturnValueHandler.RESPONSE_HEADER); + Assert.state(headerValue == null || headerValue instanceof MonoProcessor, "Expected MonoProcessor"); + return (MonoProcessor>) headerValue; + } + + private static Payload createPayload(DataBuffer data) { + if (data instanceof NettyDataBuffer) { + return ByteBufPayload.create(((NettyDataBuffer) data).getNativeBuffer()); + } + else if (data instanceof DefaultDataBuffer) { + return DefaultPayload.create(((DefaultDataBuffer) data).getNativeBuffer()); + } + else { + return DefaultPayload.create(data.asByteBuffer()); + } + } + +} diff --git a/spring-integration-rsocket/src/test/java/org/springframework/integration/rsocket/inbound/RSocketInboundGatewayIntegrationTests.java b/spring-integration-rsocket/src/test/java/org/springframework/integration/rsocket/inbound/RSocketInboundGatewayIntegrationTests.java new file mode 100644 index 00000000000..78798476732 --- /dev/null +++ b/spring-integration-rsocket/src/test/java/org/springframework/integration/rsocket/inbound/RSocketInboundGatewayIntegrationTests.java @@ -0,0 +1,242 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.rsocket.inbound; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationListener; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.codec.StringDecoder; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.integration.annotation.Transformer; +import org.springframework.integration.channel.FluxMessageChannel; +import org.springframework.integration.channel.QueueChannel; +import org.springframework.integration.config.EnableIntegration; +import org.springframework.integration.rsocket.ClientRSocketConnector; +import org.springframework.integration.rsocket.RSocketConnectedEvent; +import org.springframework.integration.rsocket.ServerRSocketConnector; +import org.springframework.messaging.Message; +import org.springframework.messaging.PollableChannel; +import org.springframework.messaging.rsocket.RSocketRequester; +import org.springframework.messaging.rsocket.RSocketStrategies; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; + +import io.netty.buffer.PooledByteBufAllocator; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.netty.server.TcpServerTransport; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.netty.tcp.TcpServer; +import reactor.test.StepVerifier; + +/** + * @author Artem Bilan + * + * @since 5.2 + */ +@SpringJUnitConfig(RSocketInboundGatewayIntegrationTests.ClientConfig.class) +@DirtiesContext +public class RSocketInboundGatewayIntegrationTests { + + private static AnnotationConfigApplicationContext serverContext; + + private static int port; + + private static ServerConfig serverConfig; + + private static PollableChannel serverFireAndForgetChannelChannel; + + @Autowired + private ClientRSocketConnector clientRSocketConnector; + + @Autowired + private PollableChannel fireAndForgetChannelChannel; + + private RSocketRequester serverRsocketRequester; + + private RSocketRequester clientRsocketRequester; + + @BeforeAll + static void setup() { + serverContext = new AnnotationConfigApplicationContext(ServerConfig.class); + serverConfig = serverContext.getBean(ServerConfig.class); + serverFireAndForgetChannelChannel = serverContext.getBean("fireAndForgetChannelChannel", PollableChannel.class); + } + + @AfterAll + static void tearDown() { + serverContext.close(); + } + + @BeforeEach + void setupTest(TestInfo testInfo) { + if (testInfo.getDisplayName().startsWith("server")) { + this.serverRsocketRequester = serverConfig.clientRequester.block(Duration.ofSeconds(10)); + } + else { + this.clientRsocketRequester = + this.clientRSocketConnector.getRSocketRequester().block(Duration.ofSeconds(10)); + } + } + + @Test + void clientFireAndForget() { + fireAndForget(serverFireAndForgetChannelChannel, this.clientRsocketRequester); + } + + @Test + void serverFireAndForget() { + fireAndForget(this.fireAndForgetChannelChannel, this.serverRsocketRequester); + } + + private void fireAndForget(PollableChannel inputChannel, RSocketRequester rsocketRequester) { + rsocketRequester.route("receive") + .data("Hello") + .send() + .subscribe(); + + Message receive = inputChannel.receive(10_000); + assertThat(receive) + .isNotNull() + .extracting(Message::getPayload) + .isEqualTo("Hello"); + } + + @Test + void clientEcho() { + echo(this.clientRsocketRequester); + } + + @Test + void serverEcho() { + echo(this.serverRsocketRequester); + } + + private void echo(RSocketRequester rsocketRequester) { + Flux result = + Flux.range(1, 3) + .concatMap(i -> + rsocketRequester.route("echo") + .data("hello " + i) + .retrieveMono(String.class)); + + StepVerifier.create(result) + .expectNext("HELLO 1", "HELLO 2", "HELLO 3") + .expectComplete() + .verify(Duration.ofSeconds(10)); + } + + + private abstract static class CommonConfig { + + @Bean + public RSocketStrategies rsocketStrategies() { + return RSocketStrategies.builder() + .decoder(StringDecoder.allMimeTypes()) + .encoder(CharSequenceEncoder.allMimeTypes()) + .dataBufferFactory(new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT)) + .build(); + } + + @Bean + public PollableChannel fireAndForgetChannelChannel() { + return new QueueChannel(); + } + + @Bean + public RSocketInboundGateway rsocketInboundGatewayFireAndForget() { + RSocketInboundGateway rsocketInboundGateway = new RSocketInboundGateway("receive"); + rsocketInboundGateway.setRSocketStrategies(rsocketStrategies()); + rsocketInboundGateway.setRequestChannel(fireAndForgetChannelChannel()); + return rsocketInboundGateway; + } + + @Bean + public RSocketInboundGateway rsocketInboundGatewayRequestReply() { + RSocketInboundGateway rsocketInboundGateway = new RSocketInboundGateway("echo"); + rsocketInboundGateway.setRSocketStrategies(rsocketStrategies()); + rsocketInboundGateway.setRequestChannel(requestReplyChannel()); + return rsocketInboundGateway; + } + + @Bean + public FluxMessageChannel requestReplyChannel() { + return new FluxMessageChannel(); + } + + @Transformer(inputChannel = "requestReplyChannel") + public Mono echoTransformation(Flux payload) { + return payload.next().map(String::toUpperCase); + } + + } + + @Configuration + @EnableIntegration + static class ServerConfig extends CommonConfig implements ApplicationListener { + + final MonoProcessor clientRequester = MonoProcessor.create(); + + @Override + public void onApplicationEvent(RSocketConnectedEvent event) { + this.clientRequester.onNext(event.getRequester()); + } + + @Bean + public ServerRSocketConnector serverRSocketConnector() { + TcpServer tcpServer = + TcpServer.create().port(0) + .doOnBound(server -> port = server.port()); + ServerRSocketConnector serverRSocketConnector = + new ServerRSocketConnector(TcpServerTransport.create(tcpServer)); + serverRSocketConnector.setRSocketStrategies(rsocketStrategies()); + serverRSocketConnector.setFactoryConfigurer((factory) -> factory.frameDecoder(PayloadDecoder.ZERO_COPY)); + return serverRSocketConnector; + } + + } + + @Configuration + @EnableIntegration + public static class ClientConfig extends CommonConfig { + + @Bean + public ClientRSocketConnector clientRSocketConnector() { + ClientRSocketConnector clientRSocketConnector = new ClientRSocketConnector("localhost", port); + clientRSocketConnector.setFactoryConfigurer((factory) -> factory.frameDecoder(PayloadDecoder.ZERO_COPY)); + clientRSocketConnector.setRSocketStrategies(rsocketStrategies()); + clientRSocketConnector.setConnectRoute("clientConnect"); + return clientRSocketConnector; + } + + } + +} diff --git a/spring-integration-rsocket/src/test/java/org/springframework/integration/rsocket/outbound/RSocketOutboundGatewayIntegrationTests.java b/spring-integration-rsocket/src/test/java/org/springframework/integration/rsocket/outbound/RSocketOutboundGatewayIntegrationTests.java index 53e6720a1e8..3044c199dca 100644 --- a/spring-integration-rsocket/src/test/java/org/springframework/integration/rsocket/outbound/RSocketOutboundGatewayIntegrationTests.java +++ b/spring-integration-rsocket/src/test/java/org/springframework/integration/rsocket/outbound/RSocketOutboundGatewayIntegrationTests.java @@ -59,10 +59,13 @@ import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; import io.netty.buffer.PooledByteBufAllocator; +import io.rsocket.RSocket; import io.rsocket.RSocketFactory; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -76,7 +79,7 @@ * * @since 5.2 */ -@SpringJUnitConfig +@SpringJUnitConfig(RSocketOutboundGatewayIntegrationTests.ClientConfig.class) @DirtiesContext public class RSocketOutboundGatewayIntegrationTests { @@ -110,9 +113,6 @@ public class RSocketOutboundGatewayIntegrationTests { @Autowired private TestController clientController; - @Autowired - private ClientRSocketConnector clientRSocketConnector; - private RSocketRequester serverRsocketRequester; @BeforeAll @@ -143,7 +143,6 @@ static void tearDown() { @BeforeEach void setupTest(TestInfo testInfo) { if (testInfo.getDisplayName().startsWith("server")) { - this.clientRSocketConnector.connect(); this.serverRsocketRequester = serverController.clientRequester.block(Duration.ofSeconds(10)); } } @@ -524,19 +523,27 @@ public static class ClientConfig extends CommonConfig { public MessageHandlerAcceptor clientAcceptor() { MessageHandlerAcceptor acceptor = new MessageHandlerAcceptor(); acceptor.setHandlers(Collections.singletonList(controller())); - acceptor.setAutoDetectDisabled(); acceptor.setRSocketStrategies(rsocketStrategies()); return acceptor; } + @Bean(destroyMethod = "dispose") + public RSocket rsocketForServerRequests() { + return RSocketFactory.connect() + .setupPayload(DefaultPayload.create("", "clientConnect")) + .dataMimeType("text/plain") + .frameDecoder(PayloadDecoder.ZERO_COPY) + .acceptor(clientAcceptor()) + .transport(TcpClientTransport.create("localhost", port)) + .start() + .block(); + } + @Bean public ClientRSocketConnector clientRSocketConnector() { ClientRSocketConnector clientRSocketConnector = new ClientRSocketConnector("localhost", port); - clientRSocketConnector.setFactoryConfigurer((factory) -> factory - .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(clientAcceptor())); + clientRSocketConnector.setFactoryConfigurer((factory) -> factory.frameDecoder(PayloadDecoder.ZERO_COPY)); clientRSocketConnector.setRSocketStrategies(rsocketStrategies()); - clientRSocketConnector.setConnectRoute("clientConnect"); return clientRSocketConnector; }