Skip to content

Commit c325f46

Browse files
committed
fix: make assert.CollectT concurrency safe
1 parent 486eb6f commit c325f46

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

assert/assertions.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"runtime"
1414
"runtime/debug"
1515
"strings"
16+
"sync"
1617
"time"
1718
"unicode"
1819
"unicode/utf8"
@@ -1862,10 +1863,13 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t
18621863
// CollectT implements the TestingT interface and collects all errors.
18631864
type CollectT struct {
18641865
errors []error
1866+
mu sync.RWMutex
18651867
}
18661868

18671869
// Errorf collects the error.
18681870
func (c *CollectT) Errorf(format string, args ...interface{}) {
1871+
c.mu.Lock()
1872+
defer c.mu.Unlock()
18691873
c.errors = append(c.errors, fmt.Errorf(format, args...))
18701874
}
18711875

@@ -1876,6 +1880,8 @@ func (c *CollectT) FailNow() {
18761880

18771881
// Reset clears the collected errors.
18781882
func (c *CollectT) Reset() {
1883+
c.mu.Lock()
1884+
defer c.mu.Unlock()
18791885
c.errors = nil
18801886
}
18811887

@@ -1884,11 +1890,20 @@ func (c *CollectT) Copy(t TestingT) {
18841890
if tt, ok := t.(tHelper); ok {
18851891
tt.Helper()
18861892
}
1893+
c.mu.RLock()
1894+
defer c.mu.RUnlock()
18871895
for _, err := range c.errors {
18881896
t.Errorf("%v", err)
18891897
}
18901898
}
18911899

1900+
// hasErrors returns true if any errors were collected.
1901+
func (c *CollectT) hasErrors() bool {
1902+
c.mu.RLock()
1903+
defer c.mu.RUnlock()
1904+
return len(c.errors) > 0
1905+
}
1906+
18921907
// EventuallyWithT asserts that given condition will be met in waitFor time,
18931908
// periodically checking target function each tick. In contrast to Eventually,
18941909
// it supplies a CollectT to the condition function, so that the condition
@@ -1931,10 +1946,10 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
19311946
collect.Reset()
19321947
go func() {
19331948
condition(collect)
1934-
ch <- len(collect.errors) == 0
1949+
ch <- collect.hasErrors()
19351950
}()
1936-
case v := <-ch:
1937-
if v {
1951+
case hasErrors := <-ch:
1952+
if !hasErrors {
19381953
return true
19391954
}
19401955
tick = ticker.C

assert/assertions_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2786,6 +2786,13 @@ func TestEventuallyWithTTrue(t *testing.T) {
27862786
Len(t, mockT.errors, 0)
27872787
}
27882788

2789+
func TestEventuallyWithT_ConcurrencySafe(t *testing.T) {
2790+
mockT := new(CollectT)
2791+
EventuallyWithT(mockT, func(c *CollectT) {
2792+
NoError(c, AnError)
2793+
}, time.Millisecond, time.Nanosecond)
2794+
}
2795+
27892796
func TestNeverFalse(t *testing.T) {
27902797
condition := func() bool {
27912798
return false

0 commit comments

Comments
 (0)