Skip to content

Commit 53231f0

Browse files
committed
WIP
1 parent 4c8b99e commit 53231f0

18 files changed

+842
-996
lines changed

accept.go

+6-9
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func (opts *AcceptOptions) ensure() *AcceptOptions {
6464
if opts == nil {
6565
return &AcceptOptions{}
6666
}
67-
return nil
67+
return opts
6868
}
6969

7070
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
@@ -119,16 +119,14 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
119119
b, _ := brw.Reader.Peek(brw.Reader.Buffered())
120120
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
121121

122-
c := &Conn{
122+
return newConn(connConfig{
123123
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
124+
rwc: netConn,
125+
client: false,
126+
copts: copts,
124127
br: brw.Reader,
125128
bw: brw.Writer,
126-
closer: netConn,
127-
copts: copts,
128-
}
129-
c.init()
130-
131-
return c, nil
129+
}), nil
132130
}
133131

134132
func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
@@ -278,7 +276,6 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com
278276
return copts, nil
279277
}
280278

281-
282279
func headerContainsToken(h http.Header, key, token string) bool {
283280
token = strings.ToLower(token)
284281

assert_test.go

+14
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ func randBytes(n int) []byte {
2323
}
2424

2525
func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) {
26+
t.Helper()
27+
2628
exp := randString(n)
2729
err := wsjson.Write(ctx, c, exp)
2830
assert.Success(t, err)
@@ -35,6 +37,8 @@ func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int)
3537
}
3638

3739
func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) {
40+
t.Helper()
41+
3842
var act interface{}
3943
err := wsjson.Read(ctx, c, &act)
4044
assert.Success(t, err)
@@ -56,6 +60,8 @@ func randString(n int) string {
5660
}
5761

5862
func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) {
63+
t.Helper()
64+
5965
p := randBytes(n)
6066
err := c.Write(ctx, typ, p)
6167
assert.Success(t, err)
@@ -68,5 +74,13 @@ func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websoc
6874
}
6975

7076
func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) {
77+
t.Helper()
78+
7179
assert.Equalf(t, exp, c.Subprotocol(), "unexpected subprotocol")
7280
}
81+
82+
func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) {
83+
t.Helper()
84+
85+
assert.Equalf(t, exp, websocket.CloseStatus(err), "unexpected status code")
86+
}

close.go

+92-37
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import (
66
"errors"
77
"fmt"
88
"log"
9-
"nhooyr.io/websocket/internal/wsframe"
9+
"nhooyr.io/websocket/internal/bufpool"
10+
"time"
1011
)
1112

1213
// StatusCode represents a WebSocket status code.
@@ -75,21 +76,101 @@ func CloseStatus(err error) StatusCode {
7576
return -1
7677
}
7778

79+
// Close closes the WebSocket connection with the given status code and reason.
80+
//
81+
// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
82+
// the peer to send a close frame.
83+
// Thus, it implements the full WebSocket close handshake.
84+
// All data messages received from the peer during the close handshake
85+
// will be discarded.
86+
//
87+
// The connection can only be closed once. Additional calls to Close
88+
// are no-ops.
89+
//
90+
// The maximum length of reason must be 125 bytes otherwise an internal
91+
// error will be sent to the peer. For this reason, you should avoid
92+
// sending a dynamic reason.
93+
//
94+
// Close will unblock all goroutines interacting with the connection once
95+
// complete.
96+
func (c *Conn) Close(code StatusCode, reason string) error {
97+
err := c.closeHandshake(code, reason)
98+
if err != nil {
99+
return fmt.Errorf("failed to close websocket: %w", err)
100+
}
101+
return nil
102+
}
103+
104+
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
105+
err := c.cw.sendClose(code, reason)
106+
if err != nil {
107+
return err
108+
}
109+
110+
return c.cr.waitClose()
111+
}
112+
113+
func (cw *connWriter) error(code StatusCode, err error) {
114+
cw.c.setCloseErr(err)
115+
cw.sendClose(code, err.Error())
116+
cw.c.close(nil)
117+
}
118+
119+
func (cw *connWriter) sendClose(code StatusCode, reason string) error {
120+
ce := CloseError{
121+
Code: code,
122+
Reason: reason,
123+
}
124+
125+
cw.c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
126+
127+
var p []byte
128+
if ce.Code != StatusNoStatusRcvd {
129+
p = ce.bytes()
130+
}
131+
132+
return cw.control(context.Background(), opClose, p)
133+
}
134+
135+
func (cr *connReader) waitClose() error {
136+
defer cr.c.close(nil)
137+
138+
return nil
139+
140+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
141+
defer cancel()
142+
143+
err := cr.mu.Lock(ctx)
144+
if err != nil {
145+
return err
146+
}
147+
defer cr.mu.Unlock()
148+
149+
b := bufpool.Get()
150+
buf := b.Bytes()
151+
buf = buf[:cap(buf)]
152+
defer bufpool.Put(b)
153+
154+
for {
155+
// TODO
156+
return nil
157+
}
158+
}
159+
78160
func parseClosePayload(p []byte) (CloseError, error) {
79161
if len(p) == 0 {
80162
return CloseError{
81163
Code: StatusNoStatusRcvd,
82164
}, nil
83165
}
84166

85-
code, reason, err := wsframe.ParseClosePayload(p)
86-
if err != nil {
87-
return CloseError{}, err
167+
if len(p) < 2 {
168+
return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
88169
}
89170

90171
ce := CloseError{
91-
Code: StatusCode(code),
92-
Reason: reason,
172+
Code: StatusCode(binary.BigEndian.Uint16(p)),
173+
Reason: string(p[2:]),
93174
}
94175

95176
if !validWireCloseCode(ce.Code) {
@@ -129,11 +210,13 @@ func (ce CloseError) bytes() []byte {
129210
return p
130211
}
131212

213+
const maxCloseReason = maxControlPayload - 2
214+
132215
func (ce CloseError) bytesErr() ([]byte, error) {
133-
const maxReason = maxControlPayload-2
134-
if len(ce.Reason) > maxReason {
135-
return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxReason, ce.Reason, len(ce.Reason))
216+
if len(ce.Reason) > maxCloseReason {
217+
return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
136218
}
219+
137220
if !validWireCloseCode(ce.Code) {
138221
return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
139222
}
@@ -144,34 +227,6 @@ func (ce CloseError) bytesErr() ([]byte, error) {
144227
return buf, nil
145228
}
146229

147-
// CloseRead will start a goroutine to read from the connection until it is closed or a data message
148-
// is received. If a data message is received, the connection will be closed with StatusPolicyViolation.
149-
// Since CloseRead reads from the connection, it will respond to ping, pong and close frames.
150-
// After calling this method, you cannot read any data messages from the connection.
151-
// The returned context will be cancelled when the connection is closed.
152-
//
153-
// Use this when you do not want to read data messages from the connection anymore but will
154-
// want to write messages to it.
155-
func (c *Conn) CloseRead(ctx context.Context) context.Context {
156-
ctx, cancel := context.WithCancel(ctx)
157-
go func() {
158-
defer cancel()
159-
c.Reader(ctx)
160-
c.Close(StatusPolicyViolation, "unexpected data message")
161-
}()
162-
return ctx
163-
}
164-
165-
// SetReadLimit sets the max number of bytes to read for a single message.
166-
// It applies to the Reader and Read methods.
167-
//
168-
// By default, the connection has a message read limit of 32768 bytes.
169-
//
170-
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
171-
func (c *Conn) SetReadLimit(n int64) {
172-
c.r.lr.limit.Store(n)
173-
}
174-
175230
func (c *Conn) setCloseErr(err error) {
176231
c.closeMu.Lock()
177232
c.setCloseErrNoLock(err)

close_test.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"io"
66
"math"
77
"nhooyr.io/websocket/internal/assert"
8-
"nhooyr.io/websocket/internal/wsframe"
98
"strings"
109
"testing"
1110
)
@@ -22,23 +21,23 @@ func TestCloseError(t *testing.T) {
2221
name: "normal",
2322
ce: CloseError{
2423
Code: StatusNormalClosure,
25-
Reason: strings.Repeat("x", wsframe.maxControlPayload-2),
24+
Reason: strings.Repeat("x", maxCloseReason),
2625
},
2726
success: true,
2827
},
2928
{
3029
name: "bigReason",
3130
ce: CloseError{
3231
Code: StatusNormalClosure,
33-
Reason: strings.Repeat("x", wsframe.maxControlPayload-1),
32+
Reason: strings.Repeat("x", maxCloseReason+1),
3433
},
3534
success: false,
3635
},
3736
{
3837
name: "bigCode",
3938
ce: CloseError{
4039
Code: math.MaxUint16,
41-
Reason: strings.Repeat("x", wsframe.maxControlPayload-2),
40+
Reason: strings.Repeat("x", maxCloseReason),
4241
},
4342
success: false,
4443
},

compress.go

+14-12
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ func (c *Conn) writeNoContextTakeOver() bool {
8484
return c.client && c.copts.clientNoContextTakeover || !c.client && c.copts.serverNoContextTakeover
8585
}
8686

87-
8887
func (c *Conn) readNoContextTakeOver() bool {
8988
return !c.client && c.copts.clientNoContextTakeover || c.client && c.copts.serverNoContextTakeover
9089
}
@@ -94,42 +93,45 @@ type trimLastFourBytesWriter struct {
9493
tail []byte
9594
}
9695

97-
func (w *trimLastFourBytesWriter) Write(p []byte) (int, error) {
98-
extra := len(w.tail) + len(p) - 4
96+
func (tw *trimLastFourBytesWriter) reset() {
97+
tw.tail = tw.tail[:0]
98+
}
99+
100+
func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
101+
extra := len(tw.tail) + len(p) - 4
99102

100103
if extra <= 0 {
101-
w.tail = append(w.tail, p...)
104+
tw.tail = append(tw.tail, p...)
102105
return len(p), nil
103106
}
104107

105108
// Now we need to write as many extra bytes as we can from the previous tail.
106-
if extra > len(w.tail) {
107-
extra = len(w.tail)
109+
if extra > len(tw.tail) {
110+
extra = len(tw.tail)
108111
}
109112
if extra > 0 {
110-
_, err := w.Write(w.tail[:extra])
113+
_, err := tw.w.Write(tw.tail[:extra])
111114
if err != nil {
112115
return 0, err
113116
}
114-
w.tail = w.tail[extra:]
117+
tw.tail = tw.tail[extra:]
115118
}
116119

117120
// If p is less than or equal to 4 bytes,
118121
// all of it is is part of the tail.
119122
if len(p) <= 4 {
120-
w.tail = append(w.tail, p...)
123+
tw.tail = append(tw.tail, p...)
121124
return len(p), nil
122125
}
123126

124127
// Otherwise, only the last 4 bytes are.
125-
w.tail = append(w.tail, p[len(p)-4:]...)
128+
tw.tail = append(tw.tail, p[len(p)-4:]...)
126129

127130
p = p[:len(p)-4]
128-
n, err := w.w.Write(p)
131+
n, err := tw.w.Write(p)
129132
return n + 4, err
130133
}
131134

132-
133135
var flateReaderPool sync.Pool
134136

135137
func getFlateReader(r io.Reader) io.Reader {

0 commit comments

Comments
 (0)