Skip to content

Commit 30d4159

Browse files
committed
GH-10487: Add STOMP CONNECT frame from the client (#10488)
Fixes: #10487 * Fix `WebSocketInboundChannelAdapter` to register own client session in the `StompSubProtocolHandler` for a proper correlation for upcoming messages from the server * Fix `WebSocketOutboundMessageHandlerTests` to produce required STOMP `CONNECT` before publishing data **Auto-cherry-pick to `6.4.x`** # Conflicts: # spring-integration-websocket/src/main/java/org/springframework/integration/websocket/inbound/WebSocketInboundChannelAdapter.java
1 parent 4212dcc commit 30d4159

File tree

4 files changed

+76
-37
lines changed

4 files changed

+76
-37
lines changed

spring-integration-websocket/src/main/java/org/springframework/integration/websocket/inbound/WebSocketInboundChannelAdapter.java

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,13 @@
5050
import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
5151
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
5252
import org.springframework.messaging.simp.stomp.StompCommand;
53+
import org.springframework.messaging.simp.stomp.StompEncoder;
5354
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
5455
import org.springframework.messaging.support.MessageBuilder;
5556
import org.springframework.util.Assert;
5657
import org.springframework.util.CollectionUtils;
5758
import org.springframework.util.MimeTypeUtils;
59+
import org.springframework.web.socket.BinaryMessage;
5860
import org.springframework.web.socket.CloseStatus;
5961
import org.springframework.web.socket.WebSocketMessage;
6062
import org.springframework.web.socket.WebSocketSession;
@@ -137,7 +139,7 @@ public WebSocketInboundChannelAdapter(IntegrationWebSocketContainer webSocketCon
137139
}
138140

139141
/**
140-
* Set the message converters to use. These converters are used to convert the message to send for appropriate
142+
* Set the message converters to use. These converters are used to convert the message to send for the appropriate
141143
* internal subProtocols type.
142144
* @param messageConverters The message converters.
143145
*/
@@ -156,7 +158,7 @@ public void setMergeWithDefaultConverters(boolean mergeWithDefaultConverters) {
156158
}
157159

158160
/**
159-
* Set the type for target message payload to convert the WebSocket message body to.
161+
* Set the type for the target message payload to convert the WebSocket message body to.
160162
* @param payloadType to convert inbound WebSocket message body
161163
* @see CompositeMessageConverter
162164
*/
@@ -170,9 +172,9 @@ public void setPayloadType(Class<?> payloadType) {
170172
* bean for {@code non-MESSAGE} {@link org.springframework.web.socket.WebSocketMessage}s
171173
* and to route messages with broker destinations.
172174
* Since only single {@link AbstractBrokerMessageHandler} bean is allowed in the current
173-
* application context, the algorithm to lookup the former by type, rather than applying
175+
* application context, the algorithm is to look up the former by type, rather than applying
174176
* the bean reference.
175-
* This is used only on server side and is ignored from client side.
177+
* This is used only on the server side and is ignored from the client side.
176178
* @param useBroker the boolean flag.
177179
*/
178180
public void setUseBroker(boolean useBroker) {
@@ -230,13 +232,23 @@ public void afterSessionStarted(WebSocketSession session) {
230232
SubProtocolHandler protocolHandler = this.subProtocolHandlerRegistry.findProtocolHandler(session);
231233
protocolHandler.afterSessionStarted(session, this.subProtocolHandlerChannel);
232234
if (!this.server && protocolHandler instanceof StompSubProtocolHandler) {
235+
// The CONNECT frame is required by the STOMP specification.
233236
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT);
234237
accessor.setSessionId(session.getId());
235238
accessor.setLeaveMutable(true);
236239
accessor.setAcceptVersion("1.1,1.2");
237240

238-
Message<?> connectMessage =
241+
Message<byte[]> connectMessage =
239242
MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
243+
244+
// In the client mode, the client session has to register itself
245+
// into the StompSubProtocolHandler cache
246+
// for proper correlation of the messages from the server side.
247+
StompEncoder stompEncoder = new StompEncoder();
248+
byte[] connectMessageBytes = stompEncoder.encode(connectMessage);
249+
protocolHandler.handleMessageFromClient(session, new BinaryMessage(connectMessageBytes),
250+
this.subProtocolHandlerChannel);
251+
240252
protocolHandler.handleMessageToClient(session, connectMessage);
241253
}
242254
}
@@ -309,7 +321,11 @@ private void handleMessageAndSend(final Message<?> message) {
309321
SimpMessageType messageType = headerAccessor.getMessageType();
310322
if (isProcessingTypeOrCommand(headerAccessor, stompCommand, messageType)) {
311323
if (SimpMessageType.CONNECT.equals(messageType)) {
312-
produceConnectAckMessage(message, headerAccessor);
324+
// Ignore the CONNECT frame in the client mode.
325+
// Essentially, it has been just initiated from the {@link #afterSessionStarted}.
326+
if (this.server) {
327+
produceConnectAckMessage(message, headerAccessor);
328+
}
313329
}
314330
else if (StompCommand.CONNECTED.equals(stompCommand)) {
315331
this.eventPublisher.publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message));
@@ -337,7 +353,7 @@ else if (StompCommand.RECEIPT.equals(stompCommand)) {
337353
private boolean isProcessingTypeOrCommand(SimpMessageHeaderAccessor headerAccessor, StompCommand stompCommand,
338354
SimpMessageType messageType) {
339355

340-
return (messageType == null // NOSONAR pretty simple logic
356+
return (messageType == null
341357
|| SimpMessageType.MESSAGE.equals(messageType)
342358
|| (SimpMessageType.CONNECT.equals(messageType) && !this.useBroker)
343359
|| StompCommand.CONNECTED.equals(stompCommand)

spring-integration-websocket/src/test/java/org/springframework/integration/websocket/client/StompIntegrationTests.java

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
8585
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
8686
import org.springframework.web.socket.messaging.AbstractSubProtocolEvent;
87+
import org.springframework.web.socket.messaging.SessionConnectEvent;
8788
import org.springframework.web.socket.messaging.SessionConnectedEvent;
8889
import org.springframework.web.socket.messaging.SessionSubscribeEvent;
8990
import org.springframework.web.socket.messaging.StompSubProtocolHandler;
@@ -94,6 +95,7 @@
9495
import org.springframework.web.socket.sockjs.client.WebSocketTransport;
9596

9697
import static org.assertj.core.api.Assertions.assertThat;
98+
import static org.assertj.core.api.InstanceOfAssertFactories.type;
9799

98100
/**
99101
* @author Artem Bilan
@@ -124,35 +126,44 @@ public class StompIntegrationTests {
124126

125127
@Test
126128
public void sendMessageToController() throws Exception {
127-
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
128-
this.webSocketOutputChannel.send(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build());
129-
130129
Message<?> receive = this.webSocketEvents.receive(20000);
131-
assertThat(receive).isNotNull();
132-
Object event = receive.getPayload();
133-
assertThat(event).isInstanceOf(SessionConnectedEvent.class);
134-
Message<?> connectedMessage = ((SessionConnectedEvent) event).getMessage();
135-
headers = StompHeaderAccessor.wrap(connectedMessage);
136-
assertThat(headers.getCommand()).isEqualTo(StompCommand.CONNECTED);
130+
assertThat(receive)
131+
.extracting(Message::getPayload)
132+
// We've just registered our own connected client session from the WebSocketInboundChannelAdapter
133+
.isInstanceOf(SessionConnectEvent.class);
137134

138-
headers = StompHeaderAccessor.create(StompCommand.SEND);
135+
receive = this.webSocketEvents.receive(20000);
136+
assertThat(receive)
137+
.extracting(Message::getPayload)
138+
.asInstanceOf(type(SessionConnectedEvent.class))
139+
.extracting(SessionConnectedEvent::getMessage)
140+
.extracting(connectedMessage -> StompHeaderAccessor.wrap(connectedMessage).getCommand())
141+
.isEqualTo(StompCommand.CONNECTED);
142+
143+
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
139144
headers.setSubscriptionId("sub1");
140145
headers.setDestination("/app/simple");
141146
Message<String> message = MessageBuilder.withPayload("foo").setHeaders(headers).build();
142147

143148
this.webSocketOutputChannel.send(message);
144149

145150
SimpleController controller = this.serverContext.getBean(SimpleController.class);
146-
assertThat(controller.latch.await(20, TimeUnit.SECONDS)).isTrue();
151+
assertThat(controller.latch.await(10, TimeUnit.SECONDS)).isTrue();
147152
assertThat(controller.stompCommand).isEqualTo(StompCommand.SEND.name());
148153
}
149154

150155
@Test
151156
public void sendMessageToControllerAndReceiveReplyViaTopic() throws Exception {
152157
Message<?> receive = this.webSocketEvents.receive(20000);
153-
assertThat(receive).isNotNull();
154-
Object event = receive.getPayload();
155-
assertThat(event).isInstanceOf(SessionConnectedEvent.class);
158+
assertThat(receive)
159+
.extracting(Message::getPayload)
160+
// We've just registered our own connected client session from the WebSocketInboundChannelAdapter
161+
.isInstanceOf(SessionConnectEvent.class);
162+
163+
receive = this.webSocketEvents.receive(20000);
164+
assertThat(receive)
165+
.extracting(Message::getPayload)
166+
.isInstanceOf(SessionConnectedEvent.class);
156167

157168
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
158169
headers.setSubscriptionId("subs1");
@@ -165,13 +176,14 @@ public void sendMessageToControllerAndReceiveReplyViaTopic() throws Exception {
165176
this.webSocketOutputChannel.send(message);
166177

167178
receive = this.webSocketEvents.receive(20000);
168-
assertThat(receive).isNotNull();
169-
event = receive.getPayload();
170-
assertThat(event).isInstanceOf(ReceiptEvent.class);
171-
Message<?> receiptMessage = ((ReceiptEvent) event).getMessage();
172-
headers = StompHeaderAccessor.wrap(receiptMessage);
173-
assertThat(headers.getCommand()).isEqualTo(StompCommand.RECEIPT);
174-
assertThat(headers.getReceiptId()).isEqualTo("myReceipt");
179+
assertThat(receive)
180+
.extracting(Message::getPayload)
181+
.asInstanceOf(type(ReceiptEvent.class))
182+
.extracting(event -> StompHeaderAccessor.wrap(event.getMessage()))
183+
.satisfies(headerAccessor -> {
184+
assertThat(headerAccessor.getCommand()).isEqualTo(StompCommand.RECEIPT);
185+
assertThat(headerAccessor.getReceiptId()).isEqualTo("myReceipt");
186+
});
175187

176188
waitForSubscribe("/topic/increment");
177189

@@ -492,7 +504,7 @@ public void configureMessageBroker(MessageBrokerRegistry configurer) {
492504
public ApplicationListener<SessionSubscribeEvent> webSocketEventListener(
493505
final AbstractSubscribableChannel clientOutboundChannel) {
494506
// Cannot be lambda because Java can't infer generic type from lambdas,
495-
// therefore we end up with ClassCastException for other event types
507+
// therefore, we end up with ClassCastException for other event types
496508
return new ApplicationListener<SessionSubscribeEvent>() {
497509

498510
@Override

spring-integration-websocket/src/test/java/org/springframework/integration/websocket/inbound/WebSocketInboundChannelAdapterTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ public void testWebSocketInboundChannelAdapter() {
9696
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE);
9797
headers.setLeaveMutable(true);
9898
headers.setSessionId(sessionId);
99+
headers.setSubscriptionId("sub1");
99100
Message<byte[]> message =
100101
MessageBuilder.createMessage(ByteBuffer.allocate(0).array(), headers.getMessageHeaders());
101102

spring-integration-websocket/src/test/java/org/springframework/integration/websocket/outbound/WebSocketOutboundMessageHandlerTests.java

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,22 +66,32 @@ public class WebSocketOutboundMessageHandlerTests {
6666

6767
@Test
6868
public void testWebSocketOutboundMessageHandler() {
69-
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
69+
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
70+
this.messageHandler.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build());
71+
72+
headers = StompHeaderAccessor.create(StompCommand.SEND);
7073
headers.setMessageId("mess0");
7174
headers.setSubscriptionId("sub0");
72-
headers.setDestination("/foo");
75+
headers.setDestination("/dest");
7376
String payload = "Hello World";
7477
Message<String> message = MessageBuilder.withPayload(payload).setHeaders(headers).build();
7578

7679
this.messageHandler.handleMessage(message);
7780

7881
Message<?> received = this.clientInboundChannel.receive(10000);
79-
assertThat(received).isNotNull();
80-
81-
StompHeaderAccessor receivedHeaders = StompHeaderAccessor.wrap(received);
82-
assertThat(receivedHeaders.getMessageId()).isEqualTo("mess0");
83-
assertThat(receivedHeaders.getSubscriptionId()).isEqualTo("sub0");
84-
assertThat(receivedHeaders.getDestination()).isEqualTo("/foo");
82+
assertThat(received)
83+
.extracting(StompHeaderAccessor::wrap)
84+
.extracting(StompHeaderAccessor::getCommand)
85+
.isEqualTo(StompCommand.CONNECT);
86+
87+
received = this.clientInboundChannel.receive(10000);
88+
assertThat(received)
89+
.extracting(StompHeaderAccessor::wrap)
90+
.satisfies(headerAccessor -> {
91+
assertThat(headerAccessor.getMessageId()).isEqualTo("mess0");
92+
assertThat(headerAccessor.getSubscriptionId()).isEqualTo("sub0");
93+
assertThat(headerAccessor.getDestination()).isEqualTo("/dest");
94+
});
8595

8696
Object receivedPayload = received.getPayload();
8797
assertThat(receivedPayload).isInstanceOf(byte[].class);

0 commit comments

Comments
 (0)