Skip to content

Commit a2a2d31

Browse files
committed
Add NetConn adapter
Closes #100
1 parent 1c4fdf2 commit a2a2d31

File tree

2 files changed

+164
-0
lines changed

2 files changed

+164
-0
lines changed

netconn.go

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package websocket
2+
3+
import (
4+
"context"
5+
"golang.org/x/xerrors"
6+
"io"
7+
"math"
8+
"net"
9+
"time"
10+
)
11+
12+
// NetConn converts a *websocket.Conn into a net.Conn.
13+
// Every Write to the net.Conn will correspond to a binary message
14+
// write on *webscoket.Conn.
15+
// Close will close the *websocket.Conn with StatusNormalClosure.
16+
// When a deadline is hit, the connection will be closed. This is
17+
// different from most net.Conn implementations where only the
18+
// reading/writing goroutines are interrupted but the connection is kept alive.
19+
// The Addr methods will return zero value net.TCPAddr.
20+
func NetConn(c *Conn) net.Conn {
21+
nc := &netConn{
22+
c: c,
23+
}
24+
25+
var cancel context.CancelFunc
26+
nc.writeContext, cancel = context.WithCancel(context.Background())
27+
nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel)
28+
nc.writeTimer.Stop()
29+
30+
nc.readContext, cancel = context.WithCancel(context.Background())
31+
nc.readTimer = time.AfterFunc(math.MaxInt64, cancel)
32+
nc.readTimer.Stop()
33+
34+
return nc
35+
}
36+
37+
type netConn struct {
38+
c *Conn
39+
40+
writeTimer *time.Timer
41+
writeContext context.Context
42+
43+
readTimer *time.Timer
44+
readContext context.Context
45+
46+
reader io.Reader
47+
}
48+
49+
var _ net.Conn = &netConn{}
50+
51+
func (c *netConn) Close() error {
52+
return c.c.Close(StatusNormalClosure, "")
53+
}
54+
55+
func (c *netConn) Write(p []byte) (int, error) {
56+
err := c.c.Write(c.writeContext, MessageBinary, p)
57+
if err != nil {
58+
return 0, err
59+
}
60+
return len(p), nil
61+
}
62+
63+
func (c *netConn) Read(p []byte) (int, error) {
64+
if c.reader == nil {
65+
typ, r, err := c.c.Reader(c.readContext)
66+
if err != nil {
67+
return 0, err
68+
}
69+
if typ != MessageBinary {
70+
c.c.Close(StatusUnsupportedData, "can only accept binary messages")
71+
return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", MessageBinary, typ)
72+
}
73+
c.reader = r
74+
}
75+
76+
n, err := c.reader.Read(p)
77+
if err == io.EOF {
78+
c.reader = nil
79+
}
80+
return n, err
81+
}
82+
83+
type unknownAddr struct {
84+
}
85+
86+
func (a unknownAddr) Network() string {
87+
return "unknown"
88+
}
89+
90+
func (a unknownAddr) String() string {
91+
return "unknown"
92+
}
93+
94+
func (c *netConn) RemoteAddr() net.Addr {
95+
return unknownAddr{}
96+
}
97+
98+
func (c *netConn) LocalAddr() net.Addr {
99+
return unknownAddr{}
100+
}
101+
102+
func (c *netConn) SetDeadline(t time.Time) error {
103+
c.SetWriteDeadline(t)
104+
c.SetReadDeadline(t)
105+
return nil
106+
}
107+
108+
func (c *netConn) SetWriteDeadline(t time.Time) error {
109+
c.writeTimer.Reset(t.Sub(time.Now()))
110+
return nil
111+
}
112+
113+
func (c *netConn) SetReadDeadline(t time.Time) error {
114+
c.readTimer.Reset(t.Sub(time.Now()))
115+
return nil
116+
}

websocket_test.go

+48
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,54 @@ func TestHandshake(t *testing.T) {
118118
return nil
119119
},
120120
},
121+
{
122+
name: "netConn",
123+
server: func(w http.ResponseWriter, r *http.Request) error {
124+
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
125+
if err != nil {
126+
return err
127+
}
128+
defer c.Close(websocket.StatusInternalError, "")
129+
130+
nc := websocket.NetConn(c)
131+
defer nc.Close()
132+
133+
nc.SetWriteDeadline(time.Now().Add(time.Second * 10))
134+
135+
_, err = nc.Write([]byte("hello"))
136+
if err != nil {
137+
return err
138+
}
139+
140+
return nil
141+
},
142+
client: func(ctx context.Context, u string) error {
143+
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
144+
Subprotocols: []string{"meow"},
145+
})
146+
if err != nil {
147+
return err
148+
}
149+
defer c.Close(websocket.StatusInternalError, "")
150+
151+
nc := websocket.NetConn(c)
152+
defer nc.Close()
153+
154+
nc.SetReadDeadline(time.Now().Add(time.Second * 10))
155+
156+
p := make([]byte, len("hello"))
157+
_, err = io.ReadFull(nc, p)
158+
if err != nil {
159+
return err
160+
}
161+
162+
if string(p) != "hello" {
163+
return xerrors.Errorf("unexpected payload %q received", string(p))
164+
}
165+
166+
return nil
167+
},
168+
},
121169
{
122170
name: "defaultSubprotocol",
123171
server: func(w http.ResponseWriter, r *http.Request) error {

0 commit comments

Comments
 (0)