Skip to content

refactor: allow cleanly stopping TxSubmission server #1076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 10, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 40 additions & 22 deletions protocol/txsubmission/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"errors"
"fmt"
"math"
"sync"

"github.com/blinklabs-io/gouroboros/ledger/common"
"github.com/blinklabs-io/gouroboros/protocol"
Expand All @@ -31,9 +30,13 @@ type Server struct {
callbackContext CallbackContext
protoOptions protocol.ProtocolOptions
ackCount int
requestTxIdsResultChan chan []TxIdAndSize
requestTxIdsResultChan chan requestTxIdsResult
requestTxsResultChan chan []TxBody
onceStart sync.Once
}

type requestTxIdsResult struct {
txIds []TxIdAndSize
err error
}

// NewServer returns a new TxSubmission server object
Expand All @@ -42,7 +45,7 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
config: cfg,
// Save this for re-use later
protoOptions: protoOptions,
requestTxIdsResultChan: make(chan []TxIdAndSize),
requestTxIdsResultChan: make(chan requestTxIdsResult),
requestTxsResultChan: make(chan []TxBody),
}
s.callbackContext = CallbackContext{
Expand Down Expand Up @@ -71,15 +74,22 @@ func (s *Server) initProtocol() {
}

func (s *Server) Start() {
s.onceStart.Do(func() {
s.Protocol.Logger().
Debug("starting server protocol",
"component", "network",
"protocol", ProtocolName,
"connection_id", s.callbackContext.ConnectionId.String(),
)
s.Protocol.Start()
})
s.Protocol.Logger().
Debug("starting server protocol",
"component", "network",
"protocol", ProtocolName,
"connection_id", s.callbackContext.ConnectionId.String(),
)
s.Protocol.Start()
// Start goroutine to cleanup resources on protocol shutdown
go func() {
// We create our own vars for these channels since they get replaced on restart
requestTxIdsResultChan := s.requestTxIdsResultChan
requestTxsResultChan := s.requestTxsResultChan
<-s.DoneChan()
close(requestTxIdsResultChan)
close(requestTxsResultChan)
}()
}

// RequestTxIds requests the next set of TX identifiers from the remote node's mempool
Expand Down Expand Up @@ -111,14 +121,16 @@ func (s *Server) RequestTxIds(
return nil, err
}
// Wait for result
select {
case <-s.DoneChan():
result, ok := <-s.requestTxIdsResultChan
if !ok {
return nil, protocol.ErrProtocolShuttingDown
case txIds := <-s.requestTxIdsResultChan:
// Update ack count for next call
s.ackCount = len(txIds)
return txIds, nil
}
if result.err != nil {
return nil, result.err
}
// Update ack count for next call
s.ackCount = len(result.txIds)
return result.txIds, nil
}

// RequestTxs requests the content of the requested TX identifiers from the remote node's mempool
Expand Down Expand Up @@ -182,7 +194,9 @@ func (s *Server) handleReplyTxIds(msg protocol.Message) error {
"connection_id", s.callbackContext.ConnectionId.String(),
)
msgReplyTxIds := msg.(*MsgReplyTxIds)
s.requestTxIdsResultChan <- msgReplyTxIds.TxIds
s.requestTxIdsResultChan <- requestTxIdsResult{
txIds: msgReplyTxIds.TxIds,
}
return nil
}

Expand All @@ -207,6 +221,10 @@ func (s *Server) handleDone() error {
"role", "server",
"connection_id", s.callbackContext.ConnectionId.String(),
)
// Signal the RequestTxIds function to stop waiting
s.requestTxIdsResultChan <- requestTxIdsResult{
err: ErrStopServerProcess,
}
// Call the user callback function
if s.config != nil && s.config.DoneFunc != nil {
if err := s.config.DoneFunc(s.callbackContext); err != nil {
Expand All @@ -216,9 +234,9 @@ func (s *Server) handleDone() error {
// Restart protocol
s.Stop()
s.initProtocol()
s.requestTxIdsResultChan = make(chan []TxIdAndSize)
s.requestTxIdsResultChan = make(chan requestTxIdsResult)
s.requestTxsResultChan = make(chan []TxBody)
s.Protocol.Start()
s.Start()
return nil
}

Expand Down
Loading