Skip to content

Commit e3729a6

Browse files
committed
🔀 Merge branch 'errgatherer-onfail' into 'dev'
[pkg/errors] Add Gatherer.OnFail hooks See merge request perun/go-perun!395 Signed-off-by: Steffen Rattay <steffen@perun.network>
2 parents 28e535b + 7e80d09 commit e3729a6

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

pkg/errors/gatherer.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ type Gatherer struct {
3535
errs accumulatedError
3636
wg sync.WaitGroup
3737

38-
failed chan struct{} // Closed when an error has occurred.
38+
onFails []func()
39+
failed chan struct{} // Closed when an error has occurred.
3940
}
4041

4142
// Failed returns a channel that is closed when an error occurs.
@@ -56,8 +57,14 @@ func (g *Gatherer) Add(err error) {
5657

5758
select {
5859
case <-g.failed:
60+
return
5961
default:
60-
close(g.failed)
62+
}
63+
64+
close(g.failed)
65+
66+
for _, fn := range g.onFails {
67+
fn()
6168
}
6269
}
6370

@@ -87,6 +94,15 @@ func (g *Gatherer) Err() error {
8794
return g.errs
8895
}
8996

97+
// OnFail adds fn to the list of functions that are executed right after any
98+
// non-nil error is added with Add (or any routine started with Go failed). The
99+
// functions are guaranteed to be executed in the order that they were added.
100+
//
101+
// The channel returned by Failed is closed before those functions are executed.
102+
func (g *Gatherer) OnFail(fn func()) {
103+
g.onFails = append(g.onFails, fn)
104+
}
105+
90106
// stackTracer is taken from the github.com/pkg/errors documentation.
91107
type stackTracer interface {
92108
StackTrace() errors.StackTrace

pkg/errors/gatherer_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,36 @@ func TestAccumulatedError_StackTrace(t *testing.T) {
112112
g.Add(pkgerrors.New("2"))
113113
assert.NotNil(t, g.Err().(stackTracer).StackTrace())
114114
}
115+
116+
func TestGatherer_OnFail(t *testing.T) {
117+
var (
118+
assert = assert.New(t)
119+
g = errors.NewGatherer()
120+
called bool
121+
called2 bool
122+
)
123+
124+
g.OnFail(func() {
125+
select {
126+
case <-g.Failed():
127+
case <-time.After(time.Second):
128+
assert.Fail("Failed not closed before OnFail hooks are executed")
129+
}
130+
131+
called = true
132+
assert.False(called2)
133+
})
134+
135+
g.OnFail(func() {
136+
assert.True(called)
137+
called2 = true
138+
})
139+
140+
g.Add(nil)
141+
assert.False(called)
142+
assert.False(called2)
143+
144+
g.Add(stderrors.New("error"))
145+
assert.True(called)
146+
assert.True(called2)
147+
}

0 commit comments

Comments
 (0)