@@ -2,6 +2,7 @@ package websocket
2
2
3
3
import (
4
4
"context"
5
+ "fmt"
5
6
"io"
6
7
"math"
7
8
"net"
@@ -17,8 +18,11 @@ import (
17
18
// correctly and so provided in the library.
18
19
// See https://github.com/nhooyr/websocket/issues/100.
19
20
//
20
- // Every Write to the net.Conn will correspond to a binary message
21
- // write on *webscoket.Conn.
21
+ // Every Write to the net.Conn will correspond to a message write of
22
+ // the given type on *websocket.Conn.
23
+ //
24
+ // If a message is read that is not of the correct type, an error
25
+ // will be thrown.
22
26
//
23
27
// Close will close the *websocket.Conn with StatusNormalClosure.
24
28
//
@@ -30,9 +34,10 @@ import (
30
34
// and "websocket/unknown-addr" for String.
31
35
//
32
36
// A received StatusNormalClosure close frame will be translated to EOF when reading.
33
- func NetConn (c * Conn ) net.Conn {
37
+ func NetConn (c * Conn , msgType MessageType ) net.Conn {
34
38
nc := & netConn {
35
- c : c ,
39
+ c : c ,
40
+ msgType : msgType ,
36
41
}
37
42
38
43
var cancel context.CancelFunc
@@ -52,7 +57,8 @@ func NetConn(c *Conn) net.Conn {
52
57
}
53
58
54
59
type netConn struct {
55
- c * Conn
60
+ c * Conn
61
+ msgType MessageType
56
62
57
63
writeTimer * time.Timer
58
64
writeContext context.Context
@@ -71,7 +77,7 @@ func (c *netConn) Close() error {
71
77
}
72
78
73
79
func (c * netConn ) Write (p []byte ) (int , error ) {
74
- err := c .c .Write (c .writeContext , MessageBinary , p )
80
+ err := c .c .Write (c .writeContext , c . msgType , p )
75
81
if err != nil {
76
82
return 0 , err
77
83
}
@@ -93,9 +99,9 @@ func (c *netConn) Read(p []byte) (int, error) {
93
99
}
94
100
return 0 , err
95
101
}
96
- if typ != MessageBinary {
97
- c .c .Close (StatusUnsupportedData , "can only accept binary messages" )
98
- return 0 , xerrors .Errorf ("unexpected frame type read for net conn adapter (expected %v): %v" , MessageBinary , typ )
102
+ if typ != c . msgType {
103
+ c .c .Close (StatusUnsupportedData , fmt . Sprintf ( "can only accept %v messages" , c . msgType ) )
104
+ return 0 , xerrors .Errorf ("unexpected frame type read for net conn adapter (expected %v): %v" , c . msgType , typ )
99
105
}
100
106
c .reader = r
101
107
}
0 commit comments