@@ -3,25 +3,57 @@ package main
3
3
import (
4
4
"context"
5
5
"errors"
6
- "io"
7
6
"io/ioutil"
8
7
"log"
9
8
"net/http"
10
9
"sync"
11
10
"time"
12
11
12
+ "golang.org/x/time/rate"
13
+
13
14
"nhooyr.io/websocket"
14
15
)
15
16
16
17
// chatServer enables broadcasting to a set of subscribers.
17
18
type chatServer struct {
18
- registerOnce sync.Once
19
- m http.ServeMux
20
-
21
- subscribersMu sync.RWMutex
19
+ // subscriberMessageBuffer controls the max number
20
+ // of messages that can be queued for a subscriber
21
+ // before it is kicked.
22
+ //
23
+ // Defaults to 16.
24
+ subscriberMessageBuffer int
25
+
26
+ // publishLimiter controls the rate limit applied to the publish endpoint.
27
+ //
28
+ // Defaults to one publish every 100ms with a burst of 8.
29
+ publishLimiter * rate.Limiter
30
+
31
+ // logf controls where logs are sent.
32
+ // Defaults to log.Printf.
33
+ logf func (f string , v ... interface {})
34
+
35
+ // serveMux routes the various endpoints to the appropriate handler.
36
+ serveMux http.ServeMux
37
+
38
+ subscribersMu sync.Mutex
22
39
subscribers map [* subscriber ]struct {}
23
40
}
24
41
42
+ // newChatServer constructs a chatServer with the defaults.
43
+ func newChatServer () * chatServer {
44
+ cs := & chatServer {
45
+ subscriberMessageBuffer : 16 ,
46
+ logf : log .Printf ,
47
+ subscribers : make (map [* subscriber ]struct {}),
48
+ publishLimiter : rate .NewLimiter (rate .Every (time .Millisecond * 100 ), 8 ),
49
+ }
50
+ cs .serveMux .Handle ("/" , http .FileServer (http .Dir ("." )))
51
+ cs .serveMux .HandleFunc ("/subscribe" , cs .subscribeHandler )
52
+ cs .serveMux .HandleFunc ("/publish" , cs .publishHandler )
53
+
54
+ return cs
55
+ }
56
+
25
57
// subscriber represents a subscriber.
26
58
// Messages are sent on the msgs channel and if the client
27
59
// cannot keep up with the messages, closeSlow is called.
@@ -31,20 +63,15 @@ type subscriber struct {
31
63
}
32
64
33
65
func (cs * chatServer ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
34
- cs .registerOnce .Do (func () {
35
- cs .m .Handle ("/" , http .FileServer (http .Dir ("." )))
36
- cs .m .HandleFunc ("/subscribe" , cs .subscribeHandler )
37
- cs .m .HandleFunc ("/publish" , cs .publishHandler )
38
- })
39
- cs .m .ServeHTTP (w , r )
66
+ cs .serveMux .ServeHTTP (w , r )
40
67
}
41
68
42
69
// subscribeHandler accepts the WebSocket connection and then subscribes
43
70
// it to all future messages.
44
71
func (cs * chatServer ) subscribeHandler (w http.ResponseWriter , r * http.Request ) {
45
72
c , err := websocket .Accept (w , r , nil )
46
73
if err != nil {
47
- log . Print ( err )
74
+ cs . logf ( "%v" , err )
48
75
return
49
76
}
50
77
defer c .Close (websocket .StatusInternalError , "" )
@@ -58,7 +85,8 @@ func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
58
85
return
59
86
}
60
87
if err != nil {
61
- log .Print (err )
88
+ cs .logf ("%v" , err )
89
+ return
62
90
}
63
91
}
64
92
@@ -69,7 +97,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
69
97
http .Error (w , http .StatusText (http .StatusMethodNotAllowed ), http .StatusMethodNotAllowed )
70
98
return
71
99
}
72
- body := io . LimitReader ( r .Body , 8192 )
100
+ body := http . MaxBytesReader ( w , r .Body , 8192 )
73
101
msg , err := ioutil .ReadAll (body )
74
102
if err != nil {
75
103
http .Error (w , http .StatusText (http .StatusRequestEntityTooLarge ), http .StatusRequestEntityTooLarge )
@@ -93,7 +121,7 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
93
121
ctx = c .CloseRead (ctx )
94
122
95
123
s := & subscriber {
96
- msgs : make (chan []byte , 16 ),
124
+ msgs : make (chan []byte , cs . subscriberMessageBuffer ),
97
125
closeSlow : func () {
98
126
c .Close (websocket .StatusPolicyViolation , "connection too slow to keep up with messages" )
99
127
},
@@ -118,8 +146,10 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
118
146
// It never blocks and so messages to slow subscribers
119
147
// are dropped.
120
148
func (cs * chatServer ) publish (msg []byte ) {
121
- cs .subscribersMu .RLock ()
122
- defer cs .subscribersMu .RUnlock ()
149
+ cs .subscribersMu .Lock ()
150
+ defer cs .subscribersMu .Unlock ()
151
+
152
+ cs .publishLimiter .Wait (context .Background ())
123
153
124
154
for s := range cs .subscribers {
125
155
select {
@@ -133,9 +163,6 @@ func (cs *chatServer) publish(msg []byte) {
133
163
// addSubscriber registers a subscriber.
134
164
func (cs * chatServer ) addSubscriber (s * subscriber ) {
135
165
cs .subscribersMu .Lock ()
136
- if cs .subscribers == nil {
137
- cs .subscribers = make (map [* subscriber ]struct {})
138
- }
139
166
cs .subscribers [s ] = struct {}{}
140
167
cs .subscribersMu .Unlock ()
141
168
}
0 commit comments