1515package test
1616
1717import (
18+ "context"
1819 "runtime"
1920 "strconv"
2021 "strings"
@@ -130,17 +131,26 @@ type ConcurrentT struct {
130131 t require.TestingT
131132 failed bool
132133 failedCh chan struct {}
134+ ctx context.Context
133135
134136 mutex sync.Mutex
135137 stages map [string ]* stage
136138}
137139
138140// NewConcurrent creates a new concurrent testing object.
139141func NewConcurrent (t require.TestingT ) * ConcurrentT {
142+ return NewConcurrentCtx (t , context .Background ())
143+ }
144+
145+ // NewConcurrentCtx creates a new concurrent testing object controlled by a
146+ // context. If that context expires, any ongoing stages and wait calls will
147+ // fail.
148+ func NewConcurrentCtx (t require.TestingT , ctx context.Context ) * ConcurrentT {
140149 return & ConcurrentT {
141150 t : t ,
142151 stages : make (map [string ]* stage ),
143152 failedCh : make (chan struct {}),
153+ ctx : ctx ,
144154 }
145155}
146156
@@ -167,8 +177,10 @@ func (t *ConcurrentT) getStage(name string) *stage {
167177 return s
168178}
169179
170- // Wait waits until the stages and barriers with the requested names terminate.
171- // If any stage or barrier fails, terminates the current goroutine or test.
180+ // Wait waits until the stages and barriers with the requested names
181+ // terminate or the test's context expires. If the context expires, fails the
182+ // test. If any stage or barrier fails, terminates the current goroutine or
183+ // test.
172184func (t * ConcurrentT ) Wait (names ... string ) {
173185 if len (names ) == 0 {
174186 panic ("Wait(): called with 0 names" )
@@ -177,6 +189,11 @@ func (t *ConcurrentT) Wait(names ...string) {
177189 for _ , name := range names {
178190 stage := t .getStage (name )
179191 select {
192+ case <- t .ctx .Done ():
193+ t .failNowMutex .Lock ()
194+ t .t .Errorf ("Wait for stage %s: %v" , name , t .ctx .Err ())
195+ t .failNowMutex .Unlock ()
196+ t .FailNow ()
180197 case <- stage .wg .WaitCh ():
181198 if stage .failed .IsSet () {
182199 t .FailNow ()
@@ -209,28 +226,41 @@ func (t *ConcurrentT) FailNow() {
209226// fn must not spawn any goroutines or pass along the T object to goroutines
210227// that call T.Fatal. To achieve this, make other goroutines call
211228// ConcurrentT.StageN() instead.
229+ // If the test's context expires before the call returns, fails the test.
212230func (t * ConcurrentT ) StageN (name string , goroutines int , fn func (ConcT )) {
213231 stage := t .spawnStage (name , goroutines )
214232
215233 stageT := ConcT {TestingT : stage , ct : t }
216- abort := CheckAbort ( func () {
234+ abort , ok := CheckAbortCtx ( t . ctx , func () {
217235 fn (stageT )
218236 })
219237
220- if abort != nil {
221- // Fail the stage, if it had not been marked as such, yet.
222- if stage .failed .TrySet () {
223- defer stage .wg .Done ()
224- }
225- // If it is a panic or Goexit from certain contexts, print stack trace.
226- if _ , ok := abort .(* Panic ); ok || shouldPrintStack (abort .Stack ()) {
227- print ("\n " , abort .String ())
228- }
238+ if ok && abort == nil {
239+ stage .pass ()
240+ t .Wait (name )
241+ return
242+ }
243+
244+ // Fail the stage, if it had not been marked as such, yet.
245+ if stage .failed .TrySet () {
246+ defer stage .wg .Done ()
247+ }
248+
249+ // If it did not terminate, just abort the test.
250+ if ! ok {
251+ t .failNowMutex .Lock ()
252+ t .t .Errorf ("Stage %s: %v" , name , t .ctx .Err ())
253+ t .failNowMutex .Unlock ()
229254 t .FailNow ()
230255 }
231256
232- stage .pass ()
233- t .Wait (name )
257+ // If it is a panic or Goexit from certain contexts, print stack trace.
258+ if _ , ok := abort .(* Panic ); ok || shouldPrintStack (abort .Stack ()) {
259+ t .failNowMutex .Lock ()
260+ t .t .Errorf ("Stage %s: %s" , name , abort .String ())
261+ t .failNowMutex .Unlock ()
262+ }
263+ t .FailNow ()
234264}
235265
236266func shouldPrintStack (stack string ) bool {
0 commit comments