Skip to content

Commit cf5d407

Browse files
committed
refactor: allow cleanly stopping TxSubmission server
Fixes #1075 Signed-off-by: Aurora Gaffney <[email protected]>
1 parent f221d8c commit cf5d407

File tree

1 file changed

+41
-21
lines changed

1 file changed

+41
-21
lines changed

protocol/txsubmission/server.go

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"errors"
1919
"fmt"
2020
"math"
21-
"sync"
2221

2322
"github.com/blinklabs-io/gouroboros/ledger/common"
2423
"github.com/blinklabs-io/gouroboros/protocol"
@@ -31,9 +30,13 @@ type Server struct {
3130
callbackContext CallbackContext
3231
protoOptions protocol.ProtocolOptions
3332
ackCount int
34-
requestTxIdsResultChan chan []TxIdAndSize
33+
requestTxIdsResultChan chan requestTxIdsResult
3534
requestTxsResultChan chan []TxBody
36-
onceStart sync.Once
35+
}
36+
37+
type requestTxIdsResult struct {
38+
txIds []TxIdAndSize
39+
err error
3740
}
3841

3942
// NewServer returns a new TxSubmission server object
@@ -42,7 +45,7 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
4245
config: cfg,
4346
// Save this for re-use later
4447
protoOptions: protoOptions,
45-
requestTxIdsResultChan: make(chan []TxIdAndSize),
48+
requestTxIdsResultChan: make(chan requestTxIdsResult),
4649
requestTxsResultChan: make(chan []TxBody),
4750
}
4851
s.callbackContext = CallbackContext{
@@ -71,15 +74,22 @@ func (s *Server) initProtocol() {
7174
}
7275

7376
func (s *Server) Start() {
74-
s.onceStart.Do(func() {
75-
s.Protocol.Logger().
76-
Debug("starting server protocol",
77-
"component", "network",
78-
"protocol", ProtocolName,
79-
"connection_id", s.callbackContext.ConnectionId.String(),
80-
)
81-
s.Protocol.Start()
82-
})
77+
s.Protocol.Logger().
78+
Debug("starting server protocol",
79+
"component", "network",
80+
"protocol", ProtocolName,
81+
"connection_id", s.callbackContext.ConnectionId.String(),
82+
)
83+
s.Protocol.Start()
84+
// Start goroutine to cleanup resources on protocol shutdown
85+
go func() {
86+
// We create our own vars for these channels since they get replaced on restart
87+
requestTxIdsResultChan := s.requestTxIdsResultChan
88+
requestTxsResultChan := s.requestTxsResultChan
89+
<-s.DoneChan()
90+
close(requestTxIdsResultChan)
91+
close(requestTxsResultChan)
92+
}()
8393
}
8494

8595
// RequestTxIds requests the next set of TX identifiers from the remote node's mempool
@@ -112,12 +122,16 @@ func (s *Server) RequestTxIds(
112122
}
113123
// Wait for result
114124
select {
115-
case <-s.DoneChan():
116-
return nil, protocol.ErrProtocolShuttingDown
117-
case txIds := <-s.requestTxIdsResultChan:
125+
case result, ok := <-s.requestTxIdsResultChan:
126+
if !ok {
127+
return nil, protocol.ErrProtocolShuttingDown
128+
}
129+
if result.err != nil {
130+
return nil, result.err
131+
}
118132
// Update ack count for next call
119-
s.ackCount = len(txIds)
120-
return txIds, nil
133+
s.ackCount = len(result.txIds)
134+
return result.txIds, nil
121135
}
122136
}
123137

@@ -182,7 +196,9 @@ func (s *Server) handleReplyTxIds(msg protocol.Message) error {
182196
"connection_id", s.callbackContext.ConnectionId.String(),
183197
)
184198
msgReplyTxIds := msg.(*MsgReplyTxIds)
185-
s.requestTxIdsResultChan <- msgReplyTxIds.TxIds
199+
s.requestTxIdsResultChan <- requestTxIdsResult{
200+
txIds: msgReplyTxIds.TxIds,
201+
}
186202
return nil
187203
}
188204

@@ -207,6 +223,10 @@ func (s *Server) handleDone() error {
207223
"role", "server",
208224
"connection_id", s.callbackContext.ConnectionId.String(),
209225
)
226+
// Signal the RequestTxIds function to stop waiting
227+
s.requestTxIdsResultChan <- requestTxIdsResult{
228+
err: ErrStopServerProcess,
229+
}
210230
// Call the user callback function
211231
if s.config != nil && s.config.DoneFunc != nil {
212232
if err := s.config.DoneFunc(s.callbackContext); err != nil {
@@ -216,9 +236,9 @@ func (s *Server) handleDone() error {
216236
// Restart protocol
217237
s.Stop()
218238
s.initProtocol()
219-
s.requestTxIdsResultChan = make(chan []TxIdAndSize)
239+
s.requestTxIdsResultChan = make(chan requestTxIdsResult)
220240
s.requestTxsResultChan = make(chan []TxBody)
221-
s.Protocol.Start()
241+
s.Start()
222242
return nil
223243
}
224244

0 commit comments

Comments
 (0)