Skip to content

Commit 008b616

Browse files
committed
Remove Grace partially
1 parent 07343c2 commit 008b616

File tree

4 files changed

+54
-75
lines changed

4 files changed

+54
-75
lines changed

accept.go

-14
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,6 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
7575
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
7676
defer errd.Wrap(&err, "failed to accept WebSocket connection")
7777

78-
g := graceFromRequest(r)
79-
if g != nil && g.isShuttingdown() {
80-
err := errors.New("server shutting down")
81-
http.Error(w, err.Error(), http.StatusServiceUnavailable)
82-
return nil, err
83-
}
84-
8578
if opts == nil {
8679
opts = &AcceptOptions{}
8780
}
@@ -152,13 +145,6 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
152145
bw: brw.Writer,
153146
})
154147

155-
if g != nil {
156-
err = g.addConn(c)
157-
if err != nil {
158-
return nil, err
159-
}
160-
}
161-
162148
return c, nil
163149
}
164150

conn_notjs.go

-5
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ type Conn struct {
3333
flateThreshold int
3434
br *bufio.Reader
3535
bw *bufio.Writer
36-
g *Grace
3736

3837
readTimeout chan context.Context
3938
writeTimeout chan context.Context
@@ -139,10 +138,6 @@ func (c *Conn) close(err error) {
139138
// closeErr.
140139
c.rwc.Close()
141140

142-
if c.g != nil {
143-
c.g.delConn(c)
144-
}
145-
146141
go func() {
147142
c.msgWriterState.close()
148143

example_echo_test.go

+5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ import (
1818
"nhooyr.io/websocket/wsjson"
1919
)
2020

21+
// TODO IMPROVE CANCELLATION AND SHUTDOWN
22+
// TODO on context cancel send websocket going away and fix the read timeout error to be dependant on context deadline reached.
23+
// TODO this way you cancel your context and the right message automatically gets sent. Furthrmore, then u can just use a simple waitgroup to wait for connections.
24+
// TODO grace is wrong as it doesn't wait for the individual goroutines.
25+
2126
// This example starts a WebSocket echo server,
2227
// dials the server and then sends 5 different messages
2328
// and prints out the server's responses.

grace.go

+49-56
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package websocket
22

33
import (
44
"context"
5-
"errors"
65
"fmt"
76
"net/http"
87
"sync"
@@ -17,98 +16,92 @@ import (
1716
// Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods.
1817
// It's required as net/http's Shutdown and Close methods do not keep track of WebSocket
1918
// connections.
19+
//
20+
// Make sure to Close or Shutdown the *http.Server first as you don't want to accept
21+
// any new connections while the existing websockets are being shut down.
2022
type Grace struct {
21-
mu sync.Mutex
22-
closed bool
23-
shuttingDown bool
24-
conns map[*Conn]struct{}
23+
handlersMu sync.Mutex
24+
closing bool
25+
handlers map[context.Context]context.CancelFunc
2526
}
2627

2728
// Handler returns a handler that wraps around h to record
2829
// all WebSocket connections accepted.
2930
//
3031
// Use Close or Shutdown to gracefully close recorded connections.
32+
// Make sure to Close or Shutdown the *http.Server first.
3133
func (g *Grace) Handler(h http.Handler) http.Handler {
3234
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
33-
ctx := context.WithValue(r.Context(), gracefulContextKey{}, g)
35+
ctx, cancel := context.WithCancel(r.Context())
36+
defer cancel()
37+
3438
r = r.WithContext(ctx)
39+
40+
ok := g.add(w, ctx, cancel)
41+
if !ok {
42+
return
43+
}
44+
defer g.del(ctx)
45+
3546
h.ServeHTTP(w, r)
3647
})
3748
}
3849

39-
func (g *Grace) isShuttingdown() bool {
40-
g.mu.Lock()
41-
defer g.mu.Unlock()
42-
return g.shuttingDown
43-
}
44-
45-
func graceFromRequest(r *http.Request) *Grace {
46-
g, _ := r.Context().Value(gracefulContextKey{}).(*Grace)
47-
return g
48-
}
50+
func (g *Grace) add(w http.ResponseWriter, ctx context.Context, cancel context.CancelFunc) bool {
51+
g.handlersMu.Lock()
52+
defer g.handlersMu.Unlock()
4953

50-
func (g *Grace) addConn(c *Conn) error {
51-
g.mu.Lock()
52-
defer g.mu.Unlock()
53-
if g.closed {
54-
c.Close(StatusGoingAway, "server shutting down")
55-
return errors.New("server shutting down")
54+
if g.closing {
55+
http.Error(w, "shutting down", http.StatusServiceUnavailable)
56+
return false
5657
}
57-
if g.conns == nil {
58-
g.conns = make(map[*Conn]struct{})
58+
59+
if g.handlers == nil {
60+
g.handlers = make(map[context.Context]context.CancelFunc)
5961
}
60-
g.conns[c] = struct{}{}
61-
c.g = g
62-
return nil
63-
}
62+
g.handlers[ctx] = cancel
6463

65-
func (g *Grace) delConn(c *Conn) {
66-
g.mu.Lock()
67-
defer g.mu.Unlock()
68-
delete(g.conns, c)
64+
return true
6965
}
7066

71-
type gracefulContextKey struct{}
67+
func (g *Grace) del(ctx context.Context) {
68+
g.handlersMu.Lock()
69+
defer g.handlersMu.Unlock()
70+
71+
delete(g.handlers, ctx)
72+
}
7273

7374
// Close prevents the acceptance of new connections with
7475
// http.StatusServiceUnavailable and closes all accepted
7576
// connections with StatusGoingAway.
77+
//
78+
// Make sure to Close or Shutdown the *http.Server first.
7679
func (g *Grace) Close() error {
77-
g.mu.Lock()
78-
g.shuttingDown = true
79-
g.closed = true
80-
var wg sync.WaitGroup
81-
for c := range g.conns {
82-
wg.Add(1)
83-
go func(c *Conn) {
84-
defer wg.Done()
85-
c.Close(StatusGoingAway, "server shutting down")
86-
}(c)
87-
88-
delete(g.conns, c)
80+
g.handlersMu.Lock()
81+
for _, cancel := range g.handlers {
82+
cancel()
8983
}
90-
g.mu.Unlock()
84+
g.handlersMu.Unlock()
9185

92-
wg.Wait()
86+
// Wait for all goroutines to exit.
87+
g.Shutdown(context.Background())
9388

9489
return nil
9590
}
9691

9792
// Shutdown prevents the acceptance of new connections and waits until
9893
// all connections close. If the context is cancelled before that, it
9994
// calls Close to close all connections immediately.
95+
//
96+
// Make sure to Close or Shutdown the *http.Server first.
10097
func (g *Grace) Shutdown(ctx context.Context) error {
10198
defer g.Close()
10299

103-
g.mu.Lock()
104-
g.shuttingDown = true
105-
g.mu.Unlock()
106-
107100
// Same poll period used by net/http.
108101
t := time.NewTicker(500 * time.Millisecond)
109102
defer t.Stop()
110103
for {
111-
if g.zeroConns() {
104+
if g.zeroHandlers() {
112105
return nil
113106
}
114107

@@ -120,8 +113,8 @@ func (g *Grace) Shutdown(ctx context.Context) error {
120113
}
121114
}
122115

123-
func (g *Grace) zeroConns() bool {
124-
g.mu.Lock()
125-
defer g.mu.Unlock()
126-
return len(g.conns) == 0
116+
func (g *Grace) zeroHandlers() bool {
117+
g.handlersMu.Lock()
118+
defer g.handlersMu.Unlock()
119+
return len(g.handlers) == 0
127120
}

0 commit comments

Comments
 (0)