@@ -18,7 +18,6 @@ import (
18
18
"errors"
19
19
"fmt"
20
20
"math"
21
- "sync"
22
21
23
22
"github.com/blinklabs-io/gouroboros/ledger/common"
24
23
"github.com/blinklabs-io/gouroboros/protocol"
@@ -31,9 +30,13 @@ type Server struct {
31
30
callbackContext CallbackContext
32
31
protoOptions protocol.ProtocolOptions
33
32
ackCount int
34
- requestTxIdsResultChan chan [] TxIdAndSize
33
+ requestTxIdsResultChan chan requestTxIdsResult
35
34
requestTxsResultChan chan []TxBody
36
- onceStart sync.Once
35
+ }
36
+
37
+ type requestTxIdsResult struct {
38
+ txIds []TxIdAndSize
39
+ err error
37
40
}
38
41
39
42
// NewServer returns a new TxSubmission server object
@@ -42,7 +45,7 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
42
45
config : cfg ,
43
46
// Save this for re-use later
44
47
protoOptions : protoOptions ,
45
- requestTxIdsResultChan : make (chan [] TxIdAndSize ),
48
+ requestTxIdsResultChan : make (chan requestTxIdsResult ),
46
49
requestTxsResultChan : make (chan []TxBody ),
47
50
}
48
51
s .callbackContext = CallbackContext {
@@ -71,15 +74,22 @@ func (s *Server) initProtocol() {
71
74
}
72
75
73
76
func (s * Server ) Start () {
74
- s .onceStart .Do (func () {
75
- s .Protocol .Logger ().
76
- Debug ("starting server protocol" ,
77
- "component" , "network" ,
78
- "protocol" , ProtocolName ,
79
- "connection_id" , s .callbackContext .ConnectionId .String (),
80
- )
81
- s .Protocol .Start ()
82
- })
77
+ s .Protocol .Logger ().
78
+ Debug ("starting server protocol" ,
79
+ "component" , "network" ,
80
+ "protocol" , ProtocolName ,
81
+ "connection_id" , s .callbackContext .ConnectionId .String (),
82
+ )
83
+ s .Protocol .Start ()
84
+ // Start goroutine to cleanup resources on protocol shutdown
85
+ go func () {
86
+ // We create our own vars for these channels since they get replaced on restart
87
+ requestTxIdsResultChan := s .requestTxIdsResultChan
88
+ requestTxsResultChan := s .requestTxsResultChan
89
+ <- s .DoneChan ()
90
+ close (requestTxIdsResultChan )
91
+ close (requestTxsResultChan )
92
+ }()
83
93
}
84
94
85
95
// RequestTxIds requests the next set of TX identifiers from the remote node's mempool
@@ -112,12 +122,16 @@ func (s *Server) RequestTxIds(
112
122
}
113
123
// Wait for result
114
124
select {
115
- case <- s .DoneChan ():
116
- return nil , protocol .ErrProtocolShuttingDown
117
- case txIds := <- s .requestTxIdsResultChan :
125
+ case result , ok := <- s .requestTxIdsResultChan :
126
+ if ! ok {
127
+ return nil , protocol .ErrProtocolShuttingDown
128
+ }
129
+ if result .err != nil {
130
+ return nil , result .err
131
+ }
118
132
// Update ack count for next call
119
- s .ackCount = len (txIds )
120
- return txIds , nil
133
+ s .ackCount = len (result . txIds )
134
+ return result . txIds , nil
121
135
}
122
136
}
123
137
@@ -182,7 +196,9 @@ func (s *Server) handleReplyTxIds(msg protocol.Message) error {
182
196
"connection_id" , s .callbackContext .ConnectionId .String (),
183
197
)
184
198
msgReplyTxIds := msg .(* MsgReplyTxIds )
185
- s .requestTxIdsResultChan <- msgReplyTxIds .TxIds
199
+ s .requestTxIdsResultChan <- requestTxIdsResult {
200
+ txIds : msgReplyTxIds .TxIds ,
201
+ }
186
202
return nil
187
203
}
188
204
@@ -207,6 +223,10 @@ func (s *Server) handleDone() error {
207
223
"role" , "server" ,
208
224
"connection_id" , s .callbackContext .ConnectionId .String (),
209
225
)
226
+ // Signal the RequestTxIds function to stop waiting
227
+ s .requestTxIdsResultChan <- requestTxIdsResult {
228
+ err : ErrStopServerProcess ,
229
+ }
210
230
// Call the user callback function
211
231
if s .config != nil && s .config .DoneFunc != nil {
212
232
if err := s .config .DoneFunc (s .callbackContext ); err != nil {
@@ -216,9 +236,9 @@ func (s *Server) handleDone() error {
216
236
// Restart protocol
217
237
s .Stop ()
218
238
s .initProtocol ()
219
- s .requestTxIdsResultChan = make (chan [] TxIdAndSize )
239
+ s .requestTxIdsResultChan = make (chan requestTxIdsResult )
220
240
s .requestTxsResultChan = make (chan []TxBody )
221
- s .Protocol . Start ()
241
+ s .Start ()
222
242
return nil
223
243
}
224
244
0 commit comments