@@ -10,6 +10,7 @@ import (
10
10
"reflect"
11
11
"runtime"
12
12
"sync"
13
+ "sync/atomic"
13
14
"syscall/js"
14
15
15
16
"nhooyr.io/websocket/internal/wsjs"
@@ -19,9 +20,10 @@ import (
19
20
type Conn struct {
20
21
ws wsjs.WebSocket
21
22
22
- closeOnce sync.Once
23
- closed chan struct {}
24
- closeErr error
23
+ readClosed int64
24
+ closeOnce sync.Once
25
+ closed chan struct {}
26
+ closeErr error
25
27
26
28
releaseOnClose func ()
27
29
releaseOnMessage func ()
@@ -67,6 +69,10 @@ func (c *Conn) init() {
67
69
// Read attempts to read a message from the connection.
68
70
// The maximum time spent waiting is bounded by the context.
69
71
func (c * Conn ) Read (ctx context.Context ) (MessageType , []byte , error ) {
72
+ if atomic .LoadInt64 (& c .readClosed ) == 1 {
73
+ return 0 , nil , fmt .Errorf ("websocket connection read closed" )
74
+ }
75
+
70
76
typ , p , err := c .read (ctx )
71
77
if err != nil {
72
78
return 0 , nil , fmt .Errorf ("failed to read: %w" , err )
@@ -78,6 +84,7 @@ func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
78
84
var me wsjs.MessageEvent
79
85
select {
80
86
case <- ctx .Done ():
87
+ c .Close (StatusPolicyViolation , "read timed out" )
81
88
return 0 , nil , ctx .Err ()
82
89
case me = <- c .readch :
83
90
case <- c .closed :
@@ -198,6 +205,7 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp
198
205
199
206
select {
200
207
case <- ctx .Done ():
208
+ c .Close (StatusPolicyViolation , "dial timed out" )
201
209
return nil , nil , ctx .Err ()
202
210
case <- opench :
203
211
case <- c .closed :
@@ -215,3 +223,8 @@ func (c *netConn) netConnReader(ctx context.Context) (MessageType, io.Reader, er
215
223
}
216
224
return typ , bytes .NewReader (p ), nil
217
225
}
226
+
227
+ // Only implemented for use by *Conn.CloseRead in netconn.go
228
+ func (c * Conn ) reader (ctx context.Context ) {
229
+ c .read (ctx )
230
+ }
0 commit comments