Skip to content

Commit 7af1936

Browse files
committed
Implement basic NIO support
Fixes #11
1 parent 85d4681 commit 7af1936

File tree

7 files changed

+560
-2
lines changed

7 files changed

+560
-2
lines changed

src/main/java/com/rabbitmq/client/ConnectionFactory.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ public class ConnectionFactory implements Cloneable {
108108

109109
private MetricsCollector metricsCollector;
110110

111+
private boolean nio = false;
112+
private FrameHandlerFactory frameHandlerFactory;
113+
111114
/** @return the default host to use for connections */
112115
public String getHost() {
113116
return host;
@@ -642,7 +645,15 @@ public MetricsCollector getMetricsCollector() {
642645
}
643646

644647
protected FrameHandlerFactory createFrameHandlerFactory() throws IOException {
645-
return new FrameHandlerFactory(connectionTimeout, factory, socketConf, isSSL(), this.shutdownExecutor);
648+
if(nio) {
649+
if(this.frameHandlerFactory == null) {
650+
this.frameHandlerFactory = new SocketChannelFrameHandlerFactory(connectionTimeout, factory, socketConf, isSSL(), this.shutdownExecutor);
651+
}
652+
return this.frameHandlerFactory;
653+
} else {
654+
return new FrameHandlerFactory(connectionTimeout, factory, socketConf, isSSL(), this.shutdownExecutor);
655+
}
656+
646657
}
647658

648659
/**
@@ -1019,4 +1030,8 @@ public void setNetworkRecoveryInterval(int networkRecoveryInterval) {
10191030
public void setNetworkRecoveryInterval(long networkRecoveryInterval) {
10201031
this.networkRecoveryInterval = networkRecoveryInterval;
10211032
}
1033+
1034+
public void setNio(boolean nio) {
1035+
this.nio = nio;
1036+
}
10221037
}

src/main/java/com/rabbitmq/client/impl/Frame.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@
2121
import java.io.DataOutputStream;
2222
import java.io.IOException;
2323
import java.io.UnsupportedEncodingException;
24+
import java.lang.reflect.Field;
2425
import java.math.BigDecimal;
2526
import java.net.SocketTimeoutException;
27+
import java.nio.ByteBuffer;
28+
import java.nio.channels.SocketChannel;
2629
import java.sql.Timestamp;
2730
import java.util.Date;
2831
import java.util.Map;
@@ -123,6 +126,31 @@ public static Frame readFrom(DataInputStream is) throws IOException {
123126
return new Frame(type, channel, payload);
124127
}
125128

129+
public static Frame readFrom(SocketChannel socketChannel, ByteBuffer buffer) throws IOException {
130+
// FIXME make frame read better
131+
int type;
132+
int channel;
133+
134+
type = buffer.get() & 0xff;
135+
136+
// FIXME check not a version mismatch
137+
138+
int ch1 = buffer.get() & 0xff;
139+
int ch2 = buffer.get() & 0xff;
140+
141+
channel = (ch1 << 8) + (ch2 << 0);
142+
int payloadSize = buffer.getInt();
143+
144+
byte[] payload = new byte[payloadSize];
145+
buffer.get(payload);
146+
int frameEndMarker = buffer.get() & 0xff;
147+
if (frameEndMarker != AMQP.FRAME_END) {
148+
throw new MalformedFrameException("Bad frame end marker: " + frameEndMarker);
149+
}
150+
151+
return new Frame(type, channel, payload);
152+
}
153+
126154
/**
127155
* Private API - A protocol version mismatch is detected by checking the
128156
* three next bytes if a frame type of (int)'A' is read from an input
@@ -197,6 +225,24 @@ public void writeTo(DataOutputStream os) throws IOException {
197225
os.write(AMQP.FRAME_END);
198226
}
199227

228+
public void writeTo(SocketChannel socketChannel, ByteBuffer buffer) throws IOException {
229+
buffer.put((byte) type);
230+
buffer.put((byte) ((channel >>> 8) & 0xFF));
231+
buffer.put((byte) ((channel >>> 0) & 0xFF));
232+
233+
if(accumulator != null) {
234+
buffer.putInt(accumulator.size());
235+
buffer.put(accumulator.toByteArray());
236+
} else {
237+
buffer.putInt(payload.length);
238+
buffer.put(payload);
239+
}
240+
buffer.put((byte) AMQP.FRAME_END);
241+
242+
buffer.flip();
243+
while(buffer.hasRemaining() && socketChannel.write(buffer) != -1);
244+
}
245+
200246
/**
201247
* Public API - retrieves the frame payload
202248
*/

src/main/java/com/rabbitmq/client/impl/FrameHandlerFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
public class FrameHandlerFactory {
2929
private final int connectionTimeout;
3030
private final SocketFactory factory;
31-
private final SocketConfigurator configurator;
31+
protected final SocketConfigurator configurator;
3232
private final ExecutorService shutdownExecutor;
3333
private final boolean ssl;
3434

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Copyright (c) 2007-Present Pivotal Software, Inc. All rights reserved.
2+
//
3+
// This software, the RabbitMQ Java client library, is triple-licensed under the
4+
// Mozilla Public License 1.1 ("MPL"), the GNU General Public License version 2
5+
// ("GPL") and the Apache License version 2 ("ASL"). For the MPL, please see
6+
// LICENSE-MPL-RabbitMQ. For the GPL, please see LICENSE-GPL2. For the ASL,
7+
// please see LICENSE-APACHE2.
8+
//
9+
// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND,
10+
// either express or implied. See the LICENSE file for specific language governing
11+
// rights and limitations of this software.
12+
//
13+
// If you have any questions regarding licensing, please contact us at
14+
15+
16+
package com.rabbitmq.client.impl;
17+
18+
import com.rabbitmq.client.ConnectionFactory;
19+
import org.slf4j.Logger;
20+
import org.slf4j.LoggerFactory;
21+
22+
import java.io.IOException;
23+
import java.net.InetAddress;
24+
import java.net.SocketException;
25+
import java.util.concurrent.TimeUnit;
26+
27+
/**
28+
*
29+
*/
30+
public class SocketChannelFrameHandler implements FrameHandler {
31+
32+
private static final Logger LOGGER = LoggerFactory.getLogger(SocketChannelFrameHandler.class);
33+
34+
private final SocketChannelFrameHandlerState state;
35+
36+
private volatile int readTimeout = ConnectionFactory.DEFAULT_HEARTBEAT * 1000;
37+
38+
public SocketChannelFrameHandler(SocketChannelFrameHandlerState state) {
39+
this.state = state;
40+
}
41+
42+
@Override
43+
public InetAddress getLocalAddress() {
44+
return state.getChannel().socket().getLocalAddress();
45+
}
46+
47+
@Override
48+
public int getLocalPort() {
49+
return state.getChannel().socket().getLocalPort();
50+
}
51+
52+
@Override
53+
public InetAddress getAddress() {
54+
return state.getChannel().socket().getInetAddress();
55+
}
56+
57+
@Override
58+
public int getPort() {
59+
return state.getChannel().socket().getPort();
60+
}
61+
62+
@Override
63+
public void setTimeout(int timeoutMs) throws SocketException {
64+
state.getChannel().socket().setSoTimeout(timeoutMs);
65+
this.readTimeout = timeoutMs;
66+
}
67+
68+
@Override
69+
public int getTimeout() throws SocketException {
70+
return state.getChannel().socket().getSoTimeout();
71+
}
72+
73+
@Override
74+
public void sendHeader() throws IOException {
75+
state.setSendHeader(true);
76+
}
77+
78+
@Override
79+
public Frame readFrame() throws IOException {
80+
try {
81+
return state.getReadQueue().poll(readTimeout, TimeUnit.MILLISECONDS);
82+
} catch (InterruptedException e) {
83+
throw new IOException("Timeout while polling read queue", e);
84+
}
85+
}
86+
87+
@Override
88+
public void writeFrame(Frame frame) throws IOException {
89+
state.write(frame);
90+
}
91+
92+
@Override
93+
public void flush() throws IOException {
94+
95+
}
96+
97+
@Override
98+
public void close() {
99+
state.getReadSelectionKey().cancel();
100+
state.getWriteSelectionKey().cancel();
101+
try {
102+
state.getChannel().close();
103+
} catch (IOException e) {
104+
LOGGER.error("Error while closing SocketChannel", e);
105+
}
106+
}
107+
}

0 commit comments

Comments
 (0)