Skip to content

Commit 7329b27

Browse files
committed
Add automated test to chat example
1 parent 6c61818 commit 7329b27

File tree

4 files changed

+183
-32
lines changed

4 files changed

+183
-32
lines changed

chat-example/chat.go

+44-17
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,28 @@ import (
1515

1616
// chatServer enables broadcasting to a set of subscribers.
1717
type chatServer struct {
18+
registerOnce sync.Once
19+
m http.ServeMux
20+
1821
subscribersMu sync.RWMutex
19-
subscribers map[chan<- []byte]struct{}
22+
subscribers map[*subscriber]struct{}
23+
}
24+
25+
// subscriber represents a subscriber.
26+
// Messages are sent on the msgs channel and if the client
27+
// cannot keep up with the messages, closeSlow is called.
28+
type subscriber struct {
29+
msgs chan []byte
30+
closeSlow func()
31+
}
32+
33+
func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
34+
cs.registerOnce.Do(func() {
35+
cs.m.Handle("/", http.FileServer(http.Dir(".")))
36+
cs.m.HandleFunc("/subscribe", cs.subscribeHandler)
37+
cs.m.HandleFunc("/publish", cs.publishHandler)
38+
})
39+
cs.m.ServeHTTP(w, r)
2040
}
2141

2242
// subscribeHandler accepts the WebSocket connection and then subscribes
@@ -57,11 +77,13 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
5777
}
5878

5979
cs.publish(msg)
80+
81+
w.WriteHeader(http.StatusAccepted)
6082
}
6183

6284
// subscribe subscribes the given WebSocket to all broadcast messages.
63-
// It creates a msgs chan with a buffer of 16 to give some room to slower
64-
// connections and then registers it. It then listens for all messages
85+
// It creates a subscriber with a buffered msgs chan to give some room to slower
86+
// connections and then registers the subscriber. It then listens for all messages
6587
// and writes them to the WebSocket. If the context is cancelled or
6688
// an error occurs, it returns and deletes the subscription.
6789
//
@@ -70,13 +92,18 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
7092
func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
7193
ctx = c.CloseRead(ctx)
7294

73-
msgs := make(chan []byte, 16)
74-
cs.addSubscriber(msgs)
75-
defer cs.deleteSubscriber(msgs)
95+
s := &subscriber{
96+
msgs: make(chan []byte, 16),
97+
closeSlow: func() {
98+
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
99+
},
100+
}
101+
cs.addSubscriber(s)
102+
defer cs.deleteSubscriber(s)
76103

77104
for {
78105
select {
79-
case msg := <-msgs:
106+
case msg := <-s.msgs:
80107
err := writeTimeout(ctx, time.Second*5, c, msg)
81108
if err != nil {
82109
return err
@@ -94,29 +121,29 @@ func (cs *chatServer) publish(msg []byte) {
94121
cs.subscribersMu.RLock()
95122
defer cs.subscribersMu.RUnlock()
96123

97-
for c := range cs.subscribers {
124+
for s := range cs.subscribers {
98125
select {
99-
case c <- msg:
126+
case s.msgs <- msg:
100127
default:
128+
go s.closeSlow()
101129
}
102130
}
103131
}
104132

105-
// addSubscriber registers a subscriber with a channel
106-
// on which to send messages.
107-
func (cs *chatServer) addSubscriber(msgs chan<- []byte) {
133+
// addSubscriber registers a subscriber.
134+
func (cs *chatServer) addSubscriber(s *subscriber) {
108135
cs.subscribersMu.Lock()
109136
if cs.subscribers == nil {
110-
cs.subscribers = make(map[chan<- []byte]struct{})
137+
cs.subscribers = make(map[*subscriber]struct{})
111138
}
112-
cs.subscribers[msgs] = struct{}{}
139+
cs.subscribers[s] = struct{}{}
113140
cs.subscribersMu.Unlock()
114141
}
115142

116-
// deleteSubscriber deletes the subscriber with the given msgs channel.
117-
func (cs *chatServer) deleteSubscriber(msgs chan []byte) {
143+
// deleteSubscriber deletes the given subscriber.
144+
func (cs *chatServer) deleteSubscriber(s *subscriber) {
118145
cs.subscribersMu.Lock()
119-
delete(cs.subscribers, msgs)
146+
delete(cs.subscribers, s)
120147
cs.subscribersMu.Unlock()
121148
}
122149

chat-example/chat_test.go

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
// +build !js
2+
3+
package main
4+
5+
import (
6+
"context"
7+
"errors"
8+
"fmt"
9+
"net/http"
10+
"net/http/httptest"
11+
"strings"
12+
"testing"
13+
"time"
14+
15+
"nhooyr.io/websocket"
16+
)
17+
18+
func TestGrace(t *testing.T) {
19+
t.Parallel()
20+
21+
var cs chatServer
22+
var g websocket.Grace
23+
s := httptest.NewServer(g.Handler(&cs))
24+
defer s.Close()
25+
defer g.Close()
26+
27+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
28+
defer cancel()
29+
30+
cl1, err := newClient(ctx, s.URL)
31+
assertSuccess(t, err)
32+
defer cl1.Close()
33+
34+
cl2, err := newClient(ctx, s.URL)
35+
assertSuccess(t, err)
36+
defer cl2.Close()
37+
38+
err = cl1.publish(ctx, "hello")
39+
assertSuccess(t, err)
40+
41+
assertReceivedMessage(ctx, cl1, "hello")
42+
assertReceivedMessage(ctx, cl2, "hello")
43+
}
44+
45+
type client struct {
46+
msgs chan string
47+
url string
48+
c *websocket.Conn
49+
}
50+
51+
func newClient(ctx context.Context, url string) (*client, error) {
52+
wsURL := strings.ReplaceAll(url, "http://", "ws://")
53+
c, _, err := websocket.Dial(ctx, wsURL+"/subscribe", nil)
54+
if err != nil {
55+
return nil, err
56+
}
57+
58+
cl := &client{
59+
msgs: make(chan string, 16),
60+
url: url,
61+
c: c,
62+
}
63+
go cl.readLoop()
64+
65+
return cl, nil
66+
}
67+
68+
func (cl *client) readLoop() {
69+
defer cl.c.Close(websocket.StatusInternalError, "")
70+
defer close(cl.msgs)
71+
72+
for {
73+
typ, b, err := cl.c.Read(context.Background())
74+
if err != nil {
75+
return
76+
}
77+
78+
if typ != websocket.MessageText {
79+
cl.c.Close(websocket.StatusUnsupportedData, "expected text message")
80+
return
81+
}
82+
83+
select {
84+
case cl.msgs <- string(b):
85+
default:
86+
cl.c.Close(websocket.StatusInternalError, "messages coming in too fast to handle")
87+
return
88+
}
89+
}
90+
}
91+
92+
func (cl *client) receive(ctx context.Context) (string, error) {
93+
select {
94+
case msg, ok := <-cl.msgs:
95+
if !ok {
96+
return "", errors.New("client closed")
97+
}
98+
return msg, nil
99+
case <-ctx.Done():
100+
return "", ctx.Err()
101+
}
102+
}
103+
104+
func (cl *client) publish(ctx context.Context, msg string) error {
105+
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg))
106+
resp, err := http.DefaultClient.Do(req)
107+
if err != nil {
108+
return err
109+
}
110+
defer resp.Body.Close()
111+
if resp.StatusCode != http.StatusAccepted {
112+
return fmt.Errorf("publish request failed: %v", resp.StatusCode)
113+
}
114+
return nil
115+
}
116+
117+
func (cl *client) Close() error {
118+
return cl.c.Close(websocket.StatusNormalClosure, "")
119+
}
120+
121+
func assertSuccess(t *testing.T, err error) {
122+
t.Helper()
123+
if err != nil {
124+
t.Fatal(err)
125+
}
126+
}
127+
128+
func assertReceivedMessage(ctx context.Context, cl *client, msg string) error {
129+
msg, err := cl.receive(ctx)
130+
if err != nil {
131+
return err
132+
}
133+
if msg != "hello" {
134+
return fmt.Errorf("expected hello but got %q", msg)
135+
}
136+
return nil
137+
}

chat-example/go.mod

-7
This file was deleted.

chat-example/main.go

+2-8
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,10 @@ func run() error {
3535
}
3636
log.Printf("listening on http://%v", l.Addr())
3737

38-
var ws chatServer
39-
40-
m := http.NewServeMux()
41-
m.Handle("/", http.FileServer(http.Dir(".")))
42-
m.HandleFunc("/subscribe", ws.subscribeHandler)
43-
m.HandleFunc("/publish", ws.publishHandler)
44-
38+
var cs chatServer
4539
var g websocket.Grace
4640
s := http.Server{
47-
Handler: g.Handler(m),
41+
Handler: g.Handler(&cs),
4842
ReadTimeout: time.Second * 10,
4943
WriteTimeout: time.Second * 10,
5044
}

0 commit comments

Comments
 (0)