1515package test
1616
1717import (
18+ "context"
1819 "runtime"
1920 "strconv"
2021 "strings"
@@ -130,17 +131,26 @@ type ConcurrentT struct {
130131 t require.TestingT
131132 failed bool
132133 failedCh chan struct {}
134+ ctx context.Context
133135
134136 mutex sync.Mutex
135137 stages map [string ]* stage
136138}
137139
138140// NewConcurrent creates a new concurrent testing object.
139141func NewConcurrent (t require.TestingT ) * ConcurrentT {
142+ return NewConcurrentCtx (t , context .Background ())
143+ }
144+
145+ // NewConcurrentCtx creates a new concurrent testing object controlled by a
146+ // context. If that context expires, any ongoing stages and wait calls will
147+ // fail.
148+ func NewConcurrentCtx (t require.TestingT , ctx context.Context ) * ConcurrentT {
140149 return & ConcurrentT {
141150 t : t ,
142151 stages : make (map [string ]* stage ),
143152 failedCh : make (chan struct {}),
153+ ctx : ctx ,
144154 }
145155}
146156
@@ -167,8 +177,10 @@ func (t *ConcurrentT) getStage(name string) *stage {
167177 return s
168178}
169179
170- // Wait waits until the stages and barriers with the requested names terminate.
171- // If any stage or barrier fails, terminates the current goroutine or test.
180+ // Wait waits until the stages and barriers with the requested names
181+ // terminate or the test's context expires. If the context expires, fails the
182+ // test. If any stage or barrier fails, terminates the current goroutine or
183+ // test.
172184func (t * ConcurrentT ) Wait (names ... string ) {
173185 if len (names ) == 0 {
174186 panic ("Wait(): called with 0 names" )
@@ -177,6 +189,8 @@ func (t *ConcurrentT) Wait(names ...string) {
177189 for _ , name := range names {
178190 stage := t .getStage (name )
179191 select {
192+ case <- t .ctx .Done ():
193+ t .FailNow ()
180194 case <- stage .wg .WaitCh ():
181195 if stage .failed .IsSet () {
182196 t .FailNow ()
@@ -209,28 +223,36 @@ func (t *ConcurrentT) FailNow() {
209223// fn must not spawn any goroutines or pass along the T object to goroutines
210224// that call T.Fatal. To achieve this, make other goroutines call
211225// ConcurrentT.StageN() instead.
226+ // If the test's context expires before the call returns, fails the test.
212227func (t * ConcurrentT ) StageN (name string , goroutines int , fn func (ConcT )) {
213228 stage := t .spawnStage (name , goroutines )
214229
215230 stageT := ConcT {TestingT : stage , ct : t }
216- abort := CheckAbort ( func () {
231+ abort , ok := CheckAbortCtx ( t . ctx , func () {
217232 fn (stageT )
218233 })
219234
220- if abort != nil {
221- // Fail the stage, if it had not been marked as such, yet.
222- if stage .failed .TrySet () {
223- defer stage .wg .Done ()
224- }
225- // If it is a panic or Goexit from certain contexts, print stack trace.
226- if _ , ok := abort .(* Panic ); ok || shouldPrintStack (abort .Stack ()) {
227- print ("\n " , abort .String ())
228- }
235+ if ok && abort == nil {
236+ stage .pass ()
237+ t .Wait (name )
238+ return
239+ }
240+
241+ // Fail the stage, if it had not been marked as such, yet.
242+ if stage .failed .TrySet () {
243+ defer stage .wg .Done ()
244+ }
245+
246+ // If it did not terminate, just abort the test.
247+ if ! ok {
229248 t .FailNow ()
230249 }
231250
232- stage .pass ()
233- t .Wait (name )
251+ // If it is a panic or Goexit from certain contexts, print stack trace.
252+ if _ , ok := abort .(* Panic ); ok || shouldPrintStack (abort .Stack ()) {
253+ print ("\n " , abort .String ())
254+ }
255+ t .FailNow ()
234256}
235257
236258func shouldPrintStack (stack string ) bool {
0 commit comments