Skip to content

Commit c4b8870

Browse files
authored
fix(finality-grandpa): eliminate data races in voter, timer, and test infrastructure (#4849)
1 parent 04cb96b commit c4b8870

5 files changed

Lines changed: 212 additions & 55 deletions

File tree

pkg/finality-grandpa/environment_test.go

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -187,63 +187,108 @@ func (*environment) PrecommitEquivocation(
187187

188188
// p2p network data for a round.
189189
type BroadcastNetwork[M, N any] struct {
190-
receiver chan M
191-
senders []chan M
192-
history []M
193-
routing bool
194-
wg sync.WaitGroup
190+
receiver chan M
191+
stop chan struct{}
192+
mu sync.Mutex
193+
senders []chan M
194+
history []M
195+
routing bool
196+
stopped bool
197+
routeWG sync.WaitGroup
198+
forwarderWG sync.WaitGroup
195199
}
196200

197201
func NewBroadcastNetwork[M, N any]() *BroadcastNetwork[M, N] {
198202
bn := BroadcastNetwork[M, N]{
199203
receiver: make(chan M, 10000),
204+
stop: make(chan struct{}),
200205
}
201206
return &bn
202207
}
203208

204209
func (bm *BroadcastNetwork[M, N]) SendMessage(message M) {
205-
bm.receiver <- message
210+
select {
211+
case bm.receiver <- message:
212+
case <-bm.stop:
213+
}
206214
}
207215

208216
func (bm *BroadcastNetwork[M, N]) AddNode(f func(N) M, out chan N) (in chan M) {
209217
// buffer up to 10000 messages for now
210218
in = make(chan M, 10000)
211219

220+
bm.mu.Lock()
212221
// get history to the node.
213222
for _, priorMessage := range bm.history {
214223
in <- priorMessage
215224
}
216-
217225
bm.senders = append(bm.senders, in)
218-
219-
if !bm.routing {
226+
startRoute := !bm.routing
227+
if startRoute {
220228
bm.routing = true
221-
bm.wg.Add(1)
229+
bm.routeWG.Add(1)
230+
}
231+
bm.mu.Unlock()
232+
233+
if startRoute {
222234
go bm.route()
223235
}
224236

237+
bm.forwarderWG.Add(1)
225238
go func() {
226-
for n := range out {
227-
bm.receiver <- f(n)
239+
defer bm.forwarderWG.Done()
240+
for {
241+
select {
242+
case n, ok := <-out:
243+
if !ok {
244+
return
245+
}
246+
select {
247+
case bm.receiver <- f(n):
248+
case <-bm.stop:
249+
return
250+
}
251+
case <-bm.stop:
252+
return
253+
}
228254
}
229255
}()
230256
return in
231257
}
232258

233259
func (bm *BroadcastNetwork[M, N]) route() {
234-
defer bm.wg.Done()
260+
defer bm.routeWG.Done()
235261
for msg := range bm.receiver {
262+
bm.mu.Lock()
236263
bm.history = append(bm.history, msg)
237-
for _, sender := range bm.senders {
264+
senders := append([]chan M(nil), bm.senders...)
265+
bm.mu.Unlock()
266+
for _, sender := range senders {
238267
sender <- msg
239268
}
240269
}
241270
}
242271

243272
func (bm *BroadcastNetwork[M, N]) Stop() {
273+
bm.mu.Lock()
274+
if bm.stopped {
275+
bm.mu.Unlock()
276+
return
277+
}
278+
bm.stopped = true
279+
close(bm.stop)
280+
bm.mu.Unlock()
281+
282+
// Order matters: drain forwarders first so they stop sending into receiver,
283+
// then close receiver so route can exit, then close per-node senders.
284+
bm.forwarderWG.Wait()
244285
close(bm.receiver)
245-
bm.wg.Wait()
246-
for _, sender := range bm.senders {
286+
bm.routeWG.Wait()
287+
bm.mu.Lock()
288+
senders := bm.senders
289+
bm.senders = nil
290+
bm.mu.Unlock()
291+
for _, sender := range senders {
247292
close(sender)
248293
}
249294
}

pkg/finality-grandpa/timer.go

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ package grandpa
55

66
import (
77
"sync"
8+
"sync/atomic"
89
"time"
910
)
1011

1112
type timer struct {
1213
wakerChan *wakerChan[error]
13-
mtx sync.Mutex
14-
expired bool
14+
closeOnce sync.Once
15+
expired atomic.Bool
1516
}
1617

1718
func newTimer(in <-chan time.Time) *timer {
@@ -24,29 +25,23 @@ func newTimer(in <-chan time.Time) *timer {
2425

2526
func (t *timer) poll(in <-chan time.Time) {
2627
<-in
27-
t.mtx.Lock()
28-
defer t.mtx.Unlock()
29-
if t.wakerChan.in != nil {
28+
t.closeOnce.Do(func() {
3029
t.wakerChan.in <- nil
3130
close(t.wakerChan.in)
32-
t.wakerChan.in = nil
33-
}
34-
t.expired = true
31+
})
32+
t.expired.Store(true)
3533
}
3634

3735
func (t *timer) SetWaker(waker *waker) {
3836
t.wakerChan.setWaker(waker)
3937
}
4038

4139
func (t *timer) Elapsed() (bool, error) {
42-
return t.expired, nil
40+
return t.expired.Load(), nil
4341
}
4442

4543
func (t *timer) Close() {
46-
t.mtx.Lock()
47-
defer t.mtx.Unlock()
48-
if t.wakerChan.in != nil {
44+
t.closeOnce.Do(func() {
4945
close(t.wakerChan.in)
50-
t.wakerChan.in = nil
51-
}
46+
})
5247
}

pkg/finality-grandpa/timer_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// Copyright 2025 ChainSafe Systems (ON)
2+
// SPDX-License-Identifier: LGPL-3.0-only
3+
4+
package grandpa
5+
6+
import (
7+
"sync"
8+
"testing"
9+
"time"
10+
)
11+
12+
// TestTimer_ElapsedConcurrentWithFiring exercises the read of `expired` from
13+
// one goroutine while the timer's `poll` goroutine is writing it. With the
14+
// pre-fix code (unsynchronized read in Elapsed) this test trips the race
15+
// detector under `go test -race`.
16+
func TestTimer_ElapsedConcurrentWithFiring(t *testing.T) {
17+
t.Parallel()
18+
tick := make(chan time.Time, 1)
19+
timer := newTimer(tick)
20+
21+
stop := make(chan struct{})
22+
var wg sync.WaitGroup
23+
wg.Add(1)
24+
go func() {
25+
defer wg.Done()
26+
for {
27+
select {
28+
case <-stop:
29+
return
30+
default:
31+
_, _ = timer.Elapsed()
32+
}
33+
}
34+
}()
35+
36+
tick <- time.Now()
37+
close(tick)
38+
39+
deadline := time.Now().Add(2 * time.Second)
40+
for {
41+
elapsed, _ := timer.Elapsed()
42+
if elapsed {
43+
break
44+
}
45+
if time.Now().After(deadline) {
46+
close(stop)
47+
wg.Wait()
48+
t.Fatal("timer never reported elapsed after the tick was consumed")
49+
}
50+
time.Sleep(time.Millisecond)
51+
}
52+
close(stop)
53+
wg.Wait()
54+
}
55+
56+
// TestTimer_CloseIsIdempotent ensures Close() can be called more than once
57+
// (and after poll has already drained the channel) without panicking — the
58+
// `closeOnce` sync.Once prevents the double-close.
59+
func TestTimer_CloseIsIdempotent(t *testing.T) {
60+
t.Parallel()
61+
tick := make(chan time.Time)
62+
timer := newTimer(tick)
63+
64+
timer.Close()
65+
timer.Close()
66+
}
67+
68+
// TestWakerChan_SetWakerConcurrentWithItems exercises the write of `waker`
69+
// from one goroutine while the `start` goroutine is reading it on every item.
70+
// With the pre-fix code (plain *waker field) this trips the race detector.
71+
func TestWakerChan_SetWakerConcurrentWithItems(t *testing.T) {
72+
t.Parallel()
73+
in := make(chan int, 100)
74+
wc := newWakerChan(in)
75+
76+
w1 := &waker{wakeCh: make(chan struct{}, 1000)}
77+
w2 := &waker{wakeCh: make(chan struct{}, 1000)}
78+
79+
// Drain the output channel so start() can keep making progress.
80+
drained := make(chan struct{})
81+
go func() {
82+
defer close(drained)
83+
for range wc.channel() {
84+
}
85+
}()
86+
87+
var wg sync.WaitGroup
88+
wg.Add(2)
89+
90+
go func() {
91+
defer wg.Done()
92+
for i := 0; i < 500; i++ {
93+
in <- i
94+
}
95+
}()
96+
97+
go func() {
98+
defer wg.Done()
99+
for i := 0; i < 500; i++ {
100+
if i%2 == 0 {
101+
wc.setWaker(w1)
102+
} else {
103+
wc.setWaker(w2)
104+
}
105+
}
106+
}()
107+
108+
wg.Wait()
109+
close(in)
110+
<-drained
111+
}

0 commit comments

Comments
 (0)