Skip to content

Commit 176b144

Browse files
authored
Merge pull request #99 from nhooyr/closeread
Add CloseRead and closeError test
2 parents 3149225 + 6eda9c5 commit 176b144

File tree

3 files changed

+63
-10
lines changed

3 files changed

+63
-10
lines changed

example_test.go

+1-5
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,7 @@ func Example_writeOnly() {
7474
ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10)
7575
defer cancel()
7676

77-
go func() {
78-
defer cancel()
79-
c.Reader(ctx)
80-
c.Close(websocket.StatusPolicyViolation, "server doesn't accept data messages")
81-
}()
77+
ctx = c.CloseRead(ctx)
8278

8379
t := time.NewTicker(time.Second * 30)
8480
defer t.Stop()

websocket.go

+18-5
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import (
2222
// and SetReadLimit.
2323
//
2424
// You must always read from the connection. Otherwise control
25-
// frames will not be handled. See the docs on Reader.
25+
// frames will not be handled. See the docs on Reader and CloseRead.
2626
//
2727
// Please be sure to call Close on the connection when you
2828
// are finished with it to release the associated resources.
@@ -319,10 +319,8 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
319319
// to be closed so you do not need to write your own error message.
320320
// This applies to the Read methods in the wsjson/wspb subpackages as well.
321321
//
322-
// You must read from the connection for close frames to be read.
323-
// If you do not expect any data messages from the peer, just call
324-
// Reader in a separate goroutine and close the connection with StatusPolicyViolation
325-
// when it returns. See the writeOnly example.
322+
// You must read from the connection for control frames to be handled.
323+
// If you do not expect any data messages from the peer, call CloseRead.
326324
//
327325
// Only one Reader may be open at a time.
328326
//
@@ -388,6 +386,21 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
388386
return MessageType(h.opcode), r, nil
389387
}
390388

389+
// CloseRead will close the connection if any data message is received from the peer.
390+
// Call this when you are done reading data messages from the connection but will still write
391+
// to it. Since CloseRead is still reading from the connection, it will respond to ping, pong
392+
// and close frames automatically. It will only close the connection on a data frame. The returned
393+
// context will be cancelled when the connection is closed.
394+
func (c *Conn) CloseRead(ctx context.Context) context.Context {
395+
ctx, cancel := context.WithCancel(ctx)
396+
go func() {
397+
defer cancel()
398+
c.Reader(ctx)
399+
c.Close(StatusPolicyViolation, "unexpected data message")
400+
}()
401+
return ctx
402+
}
403+
391404
// messageReader enables reading a data frame from the WebSocket connection.
392405
type messageReader struct {
393406
c *Conn

websocket_test.go

+44
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,50 @@ func TestHandshake(t *testing.T) {
7474
return nil
7575
},
7676
},
77+
{
78+
name: "closeError",
79+
server: func(w http.ResponseWriter, r *http.Request) error {
80+
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
81+
if err != nil {
82+
return err
83+
}
84+
defer c.Close(websocket.StatusInternalError, "")
85+
86+
err = wsjson.Write(r.Context(), c, "hello")
87+
if err != nil {
88+
return err
89+
}
90+
91+
return nil
92+
},
93+
client: func(ctx context.Context, u string) error {
94+
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
95+
Subprotocols: []string{"meow"},
96+
})
97+
if err != nil {
98+
return err
99+
}
100+
defer c.Close(websocket.StatusInternalError, "")
101+
102+
var m string
103+
err = wsjson.Read(ctx, c, &m)
104+
if err != nil {
105+
return err
106+
}
107+
108+
if m != "hello" {
109+
return xerrors.Errorf("recieved unexpected msg but expected hello: %+v", m)
110+
}
111+
112+
_, _, err = c.Reader(ctx)
113+
var cerr websocket.CloseError
114+
if !xerrors.As(err, &cerr) || cerr.Code != websocket.StatusInternalError {
115+
return xerrors.Errorf("unexpected error: %+v", err)
116+
}
117+
118+
return nil
119+
},
120+
},
77121
{
78122
name: "defaultSubprotocol",
79123
server: func(w http.ResponseWriter, r *http.Request) error {

0 commit comments

Comments
 (0)