@@ -18,7 +18,6 @@ import (
18
18
"errors"
19
19
"fmt"
20
20
"math"
21
- "sync"
22
21
23
22
"github.com/blinklabs-io/gouroboros/ledger/common"
24
23
"github.com/blinklabs-io/gouroboros/protocol"
@@ -31,9 +30,13 @@ type Server struct {
31
30
callbackContext CallbackContext
32
31
protoOptions protocol.ProtocolOptions
33
32
ackCount int
34
- requestTxIdsResultChan chan [] TxIdAndSize
33
+ requestTxIdsResultChan chan requestTxIdsResult
35
34
requestTxsResultChan chan []TxBody
36
- onceStart sync.Once
35
+ }
36
+
37
+ type requestTxIdsResult struct {
38
+ txIds []TxIdAndSize
39
+ err error
37
40
}
38
41
39
42
// NewServer returns a new TxSubmission server object
@@ -42,7 +45,7 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
42
45
config : cfg ,
43
46
// Save this for re-use later
44
47
protoOptions : protoOptions ,
45
- requestTxIdsResultChan : make (chan [] TxIdAndSize ),
48
+ requestTxIdsResultChan : make (chan requestTxIdsResult ),
46
49
requestTxsResultChan : make (chan []TxBody ),
47
50
}
48
51
s .callbackContext = CallbackContext {
@@ -71,15 +74,22 @@ func (s *Server) initProtocol() {
71
74
}
72
75
73
76
func (s * Server ) Start () {
74
- s .onceStart .Do (func () {
75
- s .Protocol .Logger ().
76
- Debug ("starting server protocol" ,
77
- "component" , "network" ,
78
- "protocol" , ProtocolName ,
79
- "connection_id" , s .callbackContext .ConnectionId .String (),
80
- )
81
- s .Protocol .Start ()
82
- })
77
+ s .Protocol .Logger ().
78
+ Debug ("starting server protocol" ,
79
+ "component" , "network" ,
80
+ "protocol" , ProtocolName ,
81
+ "connection_id" , s .callbackContext .ConnectionId .String (),
82
+ )
83
+ s .Protocol .Start ()
84
+ // Start goroutine to cleanup resources on protocol shutdown
85
+ go func () {
86
+ // We create our own vars for these channels since they get replaced on restart
87
+ requestTxIdsResultChan := s .requestTxIdsResultChan
88
+ requestTxsResultChan := s .requestTxsResultChan
89
+ <- s .DoneChan ()
90
+ close (requestTxIdsResultChan )
91
+ close (requestTxsResultChan )
92
+ }()
83
93
}
84
94
85
95
// RequestTxIds requests the next set of TX identifiers from the remote node's mempool
@@ -111,14 +121,16 @@ func (s *Server) RequestTxIds(
111
121
return nil , err
112
122
}
113
123
// Wait for result
114
- select {
115
- case <- s . DoneChan ():
124
+ result , ok := <- s . requestTxIdsResultChan
125
+ if ! ok {
116
126
return nil , protocol .ErrProtocolShuttingDown
117
- case txIds := <- s .requestTxIdsResultChan :
118
- // Update ack count for next call
119
- s .ackCount = len (txIds )
120
- return txIds , nil
121
127
}
128
+ if result .err != nil {
129
+ return nil , result .err
130
+ }
131
+ // Update ack count for next call
132
+ s .ackCount = len (result .txIds )
133
+ return result .txIds , nil
122
134
}
123
135
124
136
// RequestTxs requests the content of the requested TX identifiers from the remote node's mempool
@@ -182,7 +194,9 @@ func (s *Server) handleReplyTxIds(msg protocol.Message) error {
182
194
"connection_id" , s .callbackContext .ConnectionId .String (),
183
195
)
184
196
msgReplyTxIds := msg .(* MsgReplyTxIds )
185
- s .requestTxIdsResultChan <- msgReplyTxIds .TxIds
197
+ s .requestTxIdsResultChan <- requestTxIdsResult {
198
+ txIds : msgReplyTxIds .TxIds ,
199
+ }
186
200
return nil
187
201
}
188
202
@@ -207,6 +221,10 @@ func (s *Server) handleDone() error {
207
221
"role" , "server" ,
208
222
"connection_id" , s .callbackContext .ConnectionId .String (),
209
223
)
224
+ // Signal the RequestTxIds function to stop waiting
225
+ s .requestTxIdsResultChan <- requestTxIdsResult {
226
+ err : ErrStopServerProcess ,
227
+ }
210
228
// Call the user callback function
211
229
if s .config != nil && s .config .DoneFunc != nil {
212
230
if err := s .config .DoneFunc (s .callbackContext ); err != nil {
@@ -216,9 +234,9 @@ func (s *Server) handleDone() error {
216
234
// Restart protocol
217
235
s .Stop ()
218
236
s .initProtocol ()
219
- s .requestTxIdsResultChan = make (chan [] TxIdAndSize )
237
+ s .requestTxIdsResultChan = make (chan requestTxIdsResult )
220
238
s .requestTxsResultChan = make (chan []TxBody )
221
- s .Protocol . Start ()
239
+ s .Start ()
222
240
return nil
223
241
}
224
242
0 commit comments