Skip to content

Commit bf9e2ba

Browse files
committed
closing handshake for uni_websocket
1 parent a641e49 commit bf9e2ba

File tree

6 files changed

+298
-27
lines changed

6 files changed

+298
-27
lines changed

internal/configtypes/types.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,15 @@ type UniWebSocket struct {
121121
WriteBufferSize int `mapstructure:"write_buffer_size" json:"write_buffer_size" envconfig:"write_buffer_size" yaml:"write_buffer_size" toml:"write_buffer_size"`
122122
WriteTimeout Duration `mapstructure:"write_timeout" json:"write_timeout" envconfig:"write_timeout" default:"1000ms" yaml:"write_timeout" toml:"write_timeout"`
123123
MessageSizeLimit int `mapstructure:"message_size_limit" json:"message_size_limit" envconfig:"message_size_limit" default:"65536" yaml:"message_size_limit" toml:"message_size_limit"`
124-
124+
// DisableClosingHandshake disables WebSocket closing handshake. This restores the behavior prior to
125+
// Centrifugo v6.5.1 where server never sent a close frame on connection close initiated by server.
126+
// Normally closing handshake is recommended to be performed according to WebSocket protocol RFC,
127+
// so this option is useful only in some specific cases when you need to restore the previous behavior.
128+
DisableClosingHandshake bool `mapstructure:"disable_closing_handshake" json:"disable_closing_handshake" envconfig:"disable_closing_handshake" yaml:"disable_closing_handshake" toml:"disable_closing_handshake"`
129+
// DisableDisconnectPush disables sending disconnect push messages to clients. It's sent by default to make
130+
// unidirectional transports similar, but since uni_websocket transport also sends close frame to the client
131+
// with the same code/reason – some users may want to disable disconnect push to avoid ambiguity.
132+
DisableDisconnectPush bool `mapstructure:"disable_disconnect_push" json:"disable_disconnect_push" envconfig:"disable_disconnect_push" yaml:"disable_disconnect_push" toml:"disable_disconnect_push"`
125133
// JoinPushMessages when enabled allow uni_websocket transport to join messages together into
126134
// one frame using Centrifugal client protocol delimiters: new line for JSON protocol and
127135
// length-prefixed format for Protobuf protocol. This can be useful to reduce system call

internal/timers/pool.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package timers
2+
3+
import (
4+
"sync"
5+
"time"
6+
)
7+
8+
var timerPool sync.Pool
9+
10+
// AcquireTimer from pool.
11+
func AcquireTimer(d time.Duration) *time.Timer {
12+
v := timerPool.Get()
13+
if v == nil {
14+
return time.NewTimer(d)
15+
}
16+
tm := v.(*time.Timer)
17+
if tm.Reset(d) {
18+
panic("Received an active timer from the pool!")
19+
}
20+
return tm
21+
}
22+
23+
// ReleaseTimer to pool.
24+
func ReleaseTimer(tm *time.Timer) {
25+
if !tm.Stop() {
26+
// Collect possibly added time from the channel
27+
// If timer has been stopped and nobody collected its value.
28+
select {
29+
case <-tm.C:
30+
default:
31+
}
32+
}
33+
timerPool.Put(tm)
34+
}

internal/timers/pool_test.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package timers
2+
3+
import (
4+
"sync"
5+
"testing"
6+
"time"
7+
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
// waitTimer waits for a timer to fire or a timeout, returning true if fired.
12+
func waitTimer(tm *time.Timer, timeout time.Duration) bool {
13+
select {
14+
case <-tm.C:
15+
return true
16+
case <-time.After(timeout):
17+
return false
18+
}
19+
}
20+
21+
func TestTimersReliable(t *testing.T) {
22+
t.Run("acquire timer does not fire immediately", func(t *testing.T) {
23+
tm := AcquireTimer(50 * time.Millisecond)
24+
require.NotNil(t, tm)
25+
26+
fired := waitTimer(tm, 10*time.Millisecond)
27+
require.False(t, fired, "timer should not have fired immediately")
28+
29+
ReleaseTimer(tm)
30+
})
31+
32+
t.Run("timer fires after duration", func(t *testing.T) {
33+
tm := AcquireTimer(20 * time.Millisecond)
34+
require.NotNil(t, tm)
35+
36+
fired := waitTimer(tm, 100*time.Millisecond)
37+
require.True(t, fired, "timer did not fire as expected")
38+
39+
ReleaseTimer(tm)
40+
})
41+
42+
t.Run("release stops timer", func(t *testing.T) {
43+
tm := AcquireTimer(50 * time.Millisecond)
44+
require.NotNil(t, tm)
45+
46+
ReleaseTimer(tm)
47+
48+
fired := waitTimer(tm, 60*time.Millisecond)
49+
require.False(t, fired, "timer should have been stopped")
50+
})
51+
52+
t.Run("reuse timer from pool", func(t *testing.T) {
53+
// Acquire and release a timer to populate the pool
54+
tm1 := AcquireTimer(50 * time.Millisecond)
55+
ReleaseTimer(tm1)
56+
57+
// Acquire again from pool
58+
tm2 := AcquireTimer(20 * time.Millisecond)
59+
require.NotNil(t, tm2)
60+
61+
fired := waitTimer(tm2, 100*time.Millisecond)
62+
require.True(t, fired, "reused timer should have fired")
63+
64+
ReleaseTimer(tm2)
65+
})
66+
67+
t.Run("concurrent acquire and release", func(t *testing.T) {
68+
var wg sync.WaitGroup
69+
numGoroutines := 50
70+
71+
wg.Add(numGoroutines)
72+
for i := 0; i < numGoroutines; i++ {
73+
go func() {
74+
defer wg.Done()
75+
tm := AcquireTimer(10 * time.Millisecond)
76+
// deterministically wait for timer to fire or grace period
77+
waitTimer(tm, 20*time.Millisecond)
78+
ReleaseTimer(tm)
79+
}()
80+
}
81+
82+
wg.Wait()
83+
})
84+
85+
t.Run("release timer that already fired", func(t *testing.T) {
86+
tm := AcquireTimer(10 * time.Millisecond)
87+
require.NotNil(t, tm)
88+
89+
fired := waitTimer(tm, 50*time.Millisecond)
90+
require.True(t, fired, "timer should have fired")
91+
92+
ReleaseTimer(tm)
93+
94+
// Acquire again to ensure pool is working
95+
tm2 := AcquireTimer(20 * time.Millisecond)
96+
require.NotNil(t, tm2)
97+
ReleaseTimer(tm2)
98+
})
99+
}

internal/uniws/handler.go

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,15 @@ func (s *Handler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
112112
// Separate goroutine for better GC of caller's data.
113113
go func() {
114114
opts := websocketTransportOptions{
115-
framePingInterval: framePingInterval,
116-
framePongTimeout: framePongTimeout,
117-
writeTimeout: writeTimeout,
118-
compressionMinSize: compressionMinSize,
119-
pingPongConfig: s.pingPong,
120-
joinMessages: s.config.JoinPushMessages,
121-
protoMajor: r.ProtoMajor,
115+
framePingInterval: framePingInterval,
116+
framePongTimeout: framePongTimeout,
117+
writeTimeout: writeTimeout,
118+
compressionMinSize: compressionMinSize,
119+
pingPongConfig: s.pingPong,
120+
joinMessages: s.config.JoinPushMessages,
121+
protoMajor: r.ProtoMajor,
122+
disableClosingHandshake: s.config.DisableClosingHandshake,
123+
disableDisconnectPush: s.config.DisableDisconnectPush,
122124
}
123125

124126
graceCh := make(chan struct{})
@@ -156,14 +158,45 @@ func (s *Handler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
156158
return nil
157159
})
158160

161+
waitClose := func() {
162+
// https://github.com/gorilla/websocket/issues/448
163+
conn.SetPingHandler(nil)
164+
conn.SetPongHandler(nil)
165+
if s.config.DisableClosingHandshake {
166+
close(graceCh)
167+
return
168+
}
169+
_ = conn.SetReadDeadline(time.Now().Add(closeFrameWait))
170+
for {
171+
if _, _, err := conn.NextReader(); err != nil {
172+
close(graceCh)
173+
return
174+
}
175+
}
176+
}
177+
159178
if req == nil {
160179
_, data, err := conn.ReadMessage()
161180
if err != nil {
181+
waitClose()
162182
return
163183
}
164184
_, err = json.Parse(data, &req, json.ZeroCopy)
165185
if err != nil {
166186
log.Info().Err(err).Str("transport", transportName).Msg("error unmarshalling connect request")
187+
if !s.config.DisableClosingHandshake {
188+
err = conn.WriteControl(
189+
websocket.CloseMessage,
190+
websocket.FormatCloseMessage(
191+
int(centrifuge.DisconnectBadRequest.Code),
192+
centrifuge.DisconnectBadRequest.Reason,
193+
),
194+
time.Now().Add(writeTimeout))
195+
if err != nil {
196+
return
197+
}
198+
}
199+
waitClose()
167200
return
168201
}
169202
}
@@ -176,16 +209,6 @@ func (s *Handler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
176209
break
177210
}
178211
}
179-
180-
// https://github.com/gorilla/websocket/issues/448
181-
conn.SetPingHandler(nil)
182-
conn.SetPongHandler(nil)
183-
_ = conn.SetReadDeadline(time.Now().Add(closeFrameWait))
184-
for {
185-
if _, _, err := conn.NextReader(); err != nil {
186-
close(graceCh)
187-
break
188-
}
189-
}
212+
waitClose()
190213
}()
191214
}

internal/uniws/handler_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package uniws
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"net/http"
78
"net/http/httptest"
89
"net/url"
@@ -192,3 +193,79 @@ func TestUnidirectionalWebSocket(t *testing.T) {
192193
require.True(t, pingReceived, "Expected to receive ping message")
193194
})
194195
}
196+
197+
func TestUnidirectionalWebSocket_CloseFrameSent(t *testing.T) {
198+
t.Parallel()
199+
200+
// Create a custom node to trigger server-side disconnect.
201+
testNode, err := centrifuge.New(centrifuge.Config{})
202+
require.NoError(t, err)
203+
t.Cleanup(func() { _ = testNode.Shutdown(context.Background()) })
204+
205+
testNode.OnConnecting(func(ctx context.Context, event centrifuge.ConnectEvent) (centrifuge.ConnectReply, error) {
206+
return centrifuge.ConnectReply{
207+
Credentials: &centrifuge.Credentials{},
208+
}, nil
209+
})
210+
211+
testNode.OnConnect(func(client *centrifuge.Client) {
212+
client.Disconnect(centrifuge.DisconnectConnectionLimit)
213+
})
214+
215+
testHandler := NewHandler(testNode, Config{}, func(r *http.Request) bool {
216+
return true
217+
}, centrifuge.PingPongConfig{})
218+
219+
testServer := httptest.NewServer(middleware.LogRequest(testHandler))
220+
t.Cleanup(func() { testServer.Close() })
221+
222+
testWsURL := "ws" + strings.TrimPrefix(testServer.URL, "http")
223+
224+
// Wait for test server to start.
225+
for {
226+
resp, err := http.Get(testServer.URL)
227+
if err != nil {
228+
time.Sleep(100 * time.Millisecond)
229+
continue
230+
}
231+
_ = resp.Body.Close()
232+
break
233+
}
234+
235+
dialer := websocket.Dialer{}
236+
conn, _, _, err := dialer.Dial(testWsURL, nil)
237+
require.NoError(t, err)
238+
defer func() { _ = conn.Close() }()
239+
240+
// Send connect request.
241+
err = conn.WriteMessage(websocket.TextMessage, []byte(`{}`))
242+
require.NoError(t, err)
243+
244+
// Read connect reply.
245+
_, data, err := conn.ReadMessage()
246+
require.NoError(t, err)
247+
ensureMessageHasClient(t, data)
248+
249+
require.NoError(t, conn.SetReadDeadline(time.Now().Add(10*time.Second)))
250+
251+
// Read until connection is closed - server should send close frame.
252+
var closeErr *websocket.CloseError
253+
for {
254+
_, _, err := conn.ReadMessage()
255+
if err != nil {
256+
t.Logf("ReadMessage error: %v", err)
257+
closeErr = &websocket.CloseError{}
258+
if ok := errors.As(err, &closeErr); ok {
259+
t.Logf("Close frame received: code=%d, text=%s", closeErr.Code, closeErr.Text)
260+
break
261+
}
262+
// Some other error occurred.
263+
break
264+
}
265+
}
266+
267+
// Verify close frame was received with correct code and reason.
268+
require.NotNil(t, closeErr, "Expected to receive close frame")
269+
require.Equal(t, centrifuge.DisconnectConnectionLimit.Code, uint32(closeErr.Code))
270+
require.Equal(t, centrifuge.DisconnectConnectionLimit.Reason, closeErr.Text)
271+
}

0 commit comments

Comments
 (0)