Skip to content

Commit da3aa8c

Browse files
committed
Improve chat example test
1 parent 190981d commit da3aa8c

File tree

6 files changed

+284
-98
lines changed

6 files changed

+284
-98
lines changed

chat-example/README.md

+6
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,9 @@ assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoin
2525

2626
The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by
2727
`index.html` and then `index.js`.
28+
29+
There are two automated tests for the server included in `chat_test.go`. The first is a simple one
30+
client echo test. It publishes a single message and ensures it's received.
31+
32+
The second is a complex concurrency test where 10 clients send 128 unique messages
33+
of max 128 bytes concurrently. The test ensures all messages are seen by every client.

chat-example/chat.go

+47-20
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,57 @@ package main
33
import (
44
"context"
55
"errors"
6-
"io"
76
"io/ioutil"
87
"log"
98
"net/http"
109
"sync"
1110
"time"
1211

12+
"golang.org/x/time/rate"
13+
1314
"nhooyr.io/websocket"
1415
)
1516

1617
// chatServer enables broadcasting to a set of subscribers.
1718
type chatServer struct {
18-
registerOnce sync.Once
19-
m http.ServeMux
20-
21-
subscribersMu sync.RWMutex
19+
// subscriberMessageBuffer controls the max number
20+
// of messages that can be queued for a subscriber
21+
// before it is kicked.
22+
//
23+
// Defaults to 16.
24+
subscriberMessageBuffer int
25+
26+
// publishLimiter controls the rate limit applied to the publish endpoint.
27+
//
28+
// Defaults to one publish every 100ms with a burst of 8.
29+
publishLimiter *rate.Limiter
30+
31+
// logf controls where logs are sent.
32+
// Defaults to log.Printf.
33+
logf func(f string, v ...interface{})
34+
35+
// serveMux routes the various endpoints to the appropriate handler.
36+
serveMux http.ServeMux
37+
38+
subscribersMu sync.Mutex
2239
subscribers map[*subscriber]struct{}
2340
}
2441

42+
// newChatServer constructs a chatServer with the defaults.
43+
func newChatServer() *chatServer {
44+
cs := &chatServer{
45+
subscriberMessageBuffer: 16,
46+
logf: log.Printf,
47+
subscribers: make(map[*subscriber]struct{}),
48+
publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8),
49+
}
50+
cs.serveMux.Handle("/", http.FileServer(http.Dir(".")))
51+
cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler)
52+
cs.serveMux.HandleFunc("/publish", cs.publishHandler)
53+
54+
return cs
55+
}
56+
2557
// subscriber represents a subscriber.
2658
// Messages are sent on the msgs channel and if the client
2759
// cannot keep up with the messages, closeSlow is called.
@@ -31,20 +63,15 @@ type subscriber struct {
3163
}
3264

3365
func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
34-
cs.registerOnce.Do(func() {
35-
cs.m.Handle("/", http.FileServer(http.Dir(".")))
36-
cs.m.HandleFunc("/subscribe", cs.subscribeHandler)
37-
cs.m.HandleFunc("/publish", cs.publishHandler)
38-
})
39-
cs.m.ServeHTTP(w, r)
66+
cs.serveMux.ServeHTTP(w, r)
4067
}
4168

4269
// subscribeHandler accepts the WebSocket connection and then subscribes
4370
// it to all future messages.
4471
func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
4572
c, err := websocket.Accept(w, r, nil)
4673
if err != nil {
47-
log.Print(err)
74+
cs.logf("%v", err)
4875
return
4976
}
5077
defer c.Close(websocket.StatusInternalError, "")
@@ -58,7 +85,8 @@ func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
5885
return
5986
}
6087
if err != nil {
61-
log.Print(err)
88+
cs.logf("%v", err)
89+
return
6290
}
6391
}
6492

@@ -69,7 +97,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
6997
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
7098
return
7199
}
72-
body := io.LimitReader(r.Body, 8192)
100+
body := http.MaxBytesReader(w, r.Body, 8192)
73101
msg, err := ioutil.ReadAll(body)
74102
if err != nil {
75103
http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge)
@@ -93,7 +121,7 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
93121
ctx = c.CloseRead(ctx)
94122

95123
s := &subscriber{
96-
msgs: make(chan []byte, 16),
124+
msgs: make(chan []byte, cs.subscriberMessageBuffer),
97125
closeSlow: func() {
98126
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
99127
},
@@ -118,8 +146,10 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
118146
// It never blocks and so messages to slow subscribers
119147
// are dropped.
120148
func (cs *chatServer) publish(msg []byte) {
121-
cs.subscribersMu.RLock()
122-
defer cs.subscribersMu.RUnlock()
149+
cs.subscribersMu.Lock()
150+
defer cs.subscribersMu.Unlock()
151+
152+
cs.publishLimiter.Wait(context.Background())
123153

124154
for s := range cs.subscribers {
125155
select {
@@ -133,9 +163,6 @@ func (cs *chatServer) publish(msg []byte) {
133163
// addSubscriber registers a subscriber.
134164
func (cs *chatServer) addSubscriber(s *subscriber) {
135165
cs.subscribersMu.Lock()
136-
if cs.subscribers == nil {
137-
cs.subscribers = make(map[*subscriber]struct{})
138-
}
139166
cs.subscribers[s] = struct{}{}
140167
cs.subscribersMu.Unlock()
141168
}

0 commit comments

Comments
 (0)