diff --git a/accounts/checkers.go b/accounts/checkers.go index 00a41d393..0b99bd68a 100644 --- a/accounts/checkers.go +++ b/accounts/checkers.go @@ -131,7 +131,7 @@ func NewAccountChecker(service Service, } return nil, service.AssociateInvoice( - acct.ID, hash, + ctx, acct.ID, hash, ) }, mid.PassThroughErrorHandler, ), @@ -615,12 +615,12 @@ func checkSend(ctx context.Context, chainParams *chaincfg.Params, fee := lnrpc.CalculateFeeLimit(limit, sendAmt) sendAmt += fee - err = service.CheckBalance(acct.ID, sendAmt) + err = service.CheckBalance(ctx, acct.ID, sendAmt) if err != nil { return fmt.Errorf("error validating account balance: %w", err) } - err = service.AssociatePayment(acct.ID, pHash, sendAmt) + err = service.AssociatePayment(ctx, acct.ID, pHash, sendAmt) if err != nil { return fmt.Errorf("error associating payment: %w", err) } @@ -661,11 +661,13 @@ func checkSendResponse(ctx context.Context, service Service, if status == lnrpc.Payment_FAILED { service.DeleteValues(reqID) - return nil, service.RemovePayment(hash) + return nil, service.RemovePayment(ctx, hash) } // If there is no immediate failure, make sure we track the payment. - err = service.TrackPayment(acct.ID, hash, lnwire.MilliSatoshi(fullAmt)) + err = service.TrackPayment( + ctx, acct.ID, hash, lnwire.MilliSatoshi(fullAmt), + ) if err != nil { return nil, err } @@ -713,12 +715,12 @@ func checkSendToRoute(ctx context.Context, service Service, paymentHash []byte, } sendAmt += fee - err = service.CheckBalance(acct.ID, sendAmt) + err = service.CheckBalance(ctx, acct.ID, sendAmt) if err != nil { return fmt.Errorf("error validating account balance: %w", err) } - err = service.AssociatePayment(acct.ID, hash, sendAmt) + err = service.AssociatePayment(ctx, acct.ID, hash, sendAmt) if err != nil { return fmt.Errorf("error associating payment with hash %s: %w", hash, err) @@ -749,7 +751,7 @@ func erroredPaymentHandler(service Service) mid.ErrorHandler { "hash: %s and amount: %d", reqVals.PaymentHash, reqVals.PaymentAmount) - err = service.PaymentErrored(acct.ID, reqVals.PaymentHash) + err = service.PaymentErrored(ctx, acct.ID, reqVals.PaymentHash) if err != nil { return nil, err } @@ -812,7 +814,7 @@ func sendToRouteHTLCResponseHandler(service Service) func(ctx context.Context, } err = service.TrackPayment( - acct.ID, reqValues.PaymentHash, + ctx, acct.ID, reqValues.PaymentHash, lnwire.MilliSatoshi(totalAmount), ) if err != nil { diff --git a/accounts/checkers_test.go b/accounts/checkers_test.go index 964142380..8c8c6c763 100644 --- a/accounts/checkers_test.go +++ b/accounts/checkers_test.go @@ -71,7 +71,7 @@ func newMockService() *mockService { } } -func (m *mockService) CheckBalance(_ AccountID, +func (m *mockService) CheckBalance(_ context.Context, _ AccountID, wantBalance lnwire.MilliSatoshi) error { if wantBalance > m.acctBalanceMsat { @@ -81,24 +81,28 @@ func (m *mockService) CheckBalance(_ AccountID, return nil } -func (m *mockService) AssociateInvoice(id AccountID, hash lntypes.Hash) error { +func (m *mockService) AssociateInvoice(_ context.Context, id AccountID, + hash lntypes.Hash) error { + m.trackedInvoices[hash] = id return nil } -func (m *mockService) AssociatePayment(id AccountID, paymentHash lntypes.Hash, - amt lnwire.MilliSatoshi) error { +func (m *mockService) AssociatePayment(_ context.Context, id AccountID, + paymentHash lntypes.Hash, amt lnwire.MilliSatoshi) error { return nil } -func (m *mockService) PaymentErrored(id AccountID, hash lntypes.Hash) error { +func (m *mockService) PaymentErrored(_ context.Context, id AccountID, + hash lntypes.Hash) error { + return nil } -func (m *mockService) TrackPayment(_ AccountID, hash lntypes.Hash, - amt lnwire.MilliSatoshi) error { +func (m *mockService) TrackPayment(_ context.Context, _ AccountID, + hash lntypes.Hash, amt lnwire.MilliSatoshi) error { m.trackedPayments[hash] = &PaymentEntry{ Status: lnrpc.Payment_UNKNOWN, @@ -108,7 +112,9 @@ func (m *mockService) TrackPayment(_ AccountID, hash lntypes.Hash, return nil } -func (m *mockService) RemovePayment(hash lntypes.Hash) error { +func (m *mockService) RemovePayment(_ context.Context, + hash lntypes.Hash) error { + delete(m.trackedPayments, hash) return nil @@ -517,14 +523,15 @@ func testSendPayment(t *testing.T, uri string) { errFunc := func(err error) { lndMock.mainErrChan <- err } - service, err := NewService(t.TempDir(), errFunc) + store := NewTestDB(t) + service, err := NewService(store, errFunc) require.NoError(t, err) err = service.Start(ctx, lndMock, routerMock, chainParams) require.NoError(t, err) assertBalance := func(id AccountID, expectedBalance int64) { - acct, err := service.Account(id) + acct, err := service.Account(ctx, id) require.NoError(t, err) require.Equal(t, expectedBalance, @@ -539,7 +546,7 @@ func testSendPayment(t *testing.T, uri string) { // Create an account and add it to the context. acct, err := service.NewAccount( - 5000, time.Now().Add(time.Hour), "test", + ctx, 5000, time.Now().Add(time.Hour), "test", ) require.NoError(t, err) @@ -713,14 +720,15 @@ func TestSendPaymentV2(t *testing.T) { errFunc := func(err error) { lndMock.mainErrChan <- err } - service, err := NewService(t.TempDir(), errFunc) + store := NewTestDB(t) + service, err := NewService(store, errFunc) require.NoError(t, err) err = service.Start(ctx, lndMock, routerMock, chainParams) require.NoError(t, err) assertBalance := func(id AccountID, expectedBalance int64) { - acct, err := service.Account(id) + acct, err := service.Account(ctx, id) require.NoError(t, err) require.Equal(t, expectedBalance, @@ -735,7 +743,7 @@ func TestSendPaymentV2(t *testing.T) { // Create an account and add it to the context. acct, err := service.NewAccount( - 5000, time.Now().Add(time.Hour), "test", + ctx, 5000, time.Now().Add(time.Hour), "test", ) require.NoError(t, err) @@ -900,14 +908,15 @@ func TestSendToRouteV2(t *testing.T) { errFunc := func(err error) { lndMock.mainErrChan <- err } - service, err := NewService(t.TempDir(), errFunc) + store := NewTestDB(t) + service, err := NewService(store, errFunc) require.NoError(t, err) err = service.Start(ctx, lndMock, routerMock, chainParams) require.NoError(t, err) assertBalance := func(id AccountID, expectedBalance int64) { - acct, err := service.Account(id) + acct, err := service.Account(ctx, id) require.NoError(t, err) require.Equal(t, expectedBalance, @@ -922,7 +931,7 @@ func TestSendToRouteV2(t *testing.T) { // Create an account and add it to the context. acct, err := service.NewAccount( - 5000, time.Now().Add(time.Hour), "test", + ctx, 5000, time.Now().Add(time.Hour), "test", ) require.NoError(t, err) diff --git a/accounts/errors.go b/accounts/errors.go new file mode 100644 index 000000000..8b3a59afb --- /dev/null +++ b/accounts/errors.go @@ -0,0 +1,11 @@ +package accounts + +import "errors" + +var ( + // ErrLabelAlreadyExists is returned by the CreateAccount method if the + // account label is already used by an existing account. + ErrLabelAlreadyExists = errors.New( + "account label uniqueness constraint violation", + ) +) diff --git a/accounts/interceptor.go b/accounts/interceptor.go index aa3d759f0..079f4ba07 100644 --- a/accounts/interceptor.go +++ b/accounts/interceptor.go @@ -84,7 +84,7 @@ func (s *InterceptorService) Intercept(ctx context.Context, "macaroon caveat") } - acct, err := s.Account(*acctID) + acct, err := s.Account(ctx, *acctID) if err != nil { return mid.RPCErrString( req, "error getting account %x: %v", acctID[:], err, diff --git a/accounts/interface.go b/accounts/interface.go index 68bb5ad5c..c9c1e78e0 100644 --- a/accounts/interface.go +++ b/accounts/interface.go @@ -1,6 +1,7 @@ package accounts import ( + "context" "encoding/hex" "errors" "fmt" @@ -201,30 +202,34 @@ var ( type Store interface { // NewAccount creates a new OffChainBalanceAccount with the given // balance and a randomly chosen ID. - NewAccount(balance lnwire.MilliSatoshi, expirationDate time.Time, - label string) (*OffChainBalanceAccount, error) + NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, + expirationDate time.Time, label string) ( + *OffChainBalanceAccount, error) // UpdateAccount writes an account to the database, overwriting the // existing one if it exists. - UpdateAccount(account *OffChainBalanceAccount) error + UpdateAccount(ctx context.Context, + account *OffChainBalanceAccount) error // Account retrieves an account from the Store and un-marshals it. If // the account cannot be found, then ErrAccNotFound is returned. - Account(id AccountID) (*OffChainBalanceAccount, error) + Account(ctx context.Context, id AccountID) (*OffChainBalanceAccount, + error) // Accounts retrieves all accounts from the store and un-marshals them. - Accounts() ([]*OffChainBalanceAccount, error) + Accounts(ctx context.Context) ([]*OffChainBalanceAccount, error) // RemoveAccount finds an account by its ID and removes it from the¨ // store. - RemoveAccount(id AccountID) error + RemoveAccount(ctx context.Context, id AccountID) error // LastIndexes returns the last invoice add and settle index or // ErrNoInvoiceIndexKnown if no indexes are known yet. - LastIndexes() (uint64, uint64, error) + LastIndexes(ctx context.Context) (uint64, uint64, error) // StoreLastIndexes stores the last invoice add and settle index. - StoreLastIndexes(addIndex, settleIndex uint64) error + StoreLastIndexes(ctx context.Context, addIndex, + settleIndex uint64) error // Close closes the underlying store. Close() error @@ -234,34 +239,37 @@ type Store interface { type Service interface { // CheckBalance ensures an account is valid and has a balance equal to // or larger than the amount that is required. - CheckBalance(id AccountID, requiredBalance lnwire.MilliSatoshi) error + CheckBalance(ctx context.Context, id AccountID, + requiredBalance lnwire.MilliSatoshi) error // AssociateInvoice associates a generated invoice with the given // account, making it possible for the account to be credited in case // the invoice is paid. - AssociateInvoice(id AccountID, hash lntypes.Hash) error + AssociateInvoice(ctx context.Context, id AccountID, + hash lntypes.Hash) error // TrackPayment adds a new payment to be tracked to the service. If the // payment is eventually settled, its amount needs to be debited from // the given account. - TrackPayment(id AccountID, hash lntypes.Hash, + TrackPayment(ctx context.Context, id AccountID, hash lntypes.Hash, fullAmt lnwire.MilliSatoshi) error // RemovePayment removes a failed payment from the service because it no // longer needs to be tracked. The payment is certain to never succeed, // so we never need to debit the amount from the account. - RemovePayment(hash lntypes.Hash) error + RemovePayment(ctx context.Context, hash lntypes.Hash) error // AssociatePayment associates a payment (hash) with the given account, // ensuring that the payment will be tracked for a user when LiT is // restarted. - AssociatePayment(id AccountID, paymentHash lntypes.Hash, - fullAmt lnwire.MilliSatoshi) error + AssociatePayment(ctx context.Context, id AccountID, + paymentHash lntypes.Hash, fullAmt lnwire.MilliSatoshi) error // PaymentErrored removes a pending payment from the accounts // registered payment list. This should only ever be called if we are // sure that the payment request errored out. - PaymentErrored(id AccountID, hash lntypes.Hash) error + PaymentErrored(ctx context.Context, id AccountID, + hash lntypes.Hash) error RequestValuesStore } diff --git a/accounts/rpcserver.go b/accounts/rpcserver.go index 22135f634..5c1314820 100644 --- a/accounts/rpcserver.go +++ b/accounts/rpcserver.go @@ -71,7 +71,7 @@ func (s *RPCServer) CreateAccount(ctx context.Context, // Create the actual account in the macaroon account store. account, err := s.service.NewAccount( - balanceMsat, expirationDate, req.Label, + ctx, balanceMsat, expirationDate, req.Label, ) if err != nil { return nil, fmt.Errorf("unable to create account: %w", err) @@ -109,20 +109,20 @@ func (s *RPCServer) CreateAccount(ctx context.Context, } // UpdateAccount updates an existing account in the account database. -func (s *RPCServer) UpdateAccount(_ context.Context, +func (s *RPCServer) UpdateAccount(ctx context.Context, req *litrpc.UpdateAccountRequest) (*litrpc.Account, error) { log.Infof("[updateaccount] id=%s, label=%v, balance=%d, expiration=%d", req.Id, req.Label, req.AccountBalance, req.ExpirationDate) - accountID, err := s.findAccount(req.Id, req.Label) + accountID, err := s.findAccount(ctx, req.Id, req.Label) if err != nil { return nil, err } // Ask the service to update the account. account, err := s.service.UpdateAccount( - accountID, req.AccountBalance, req.ExpirationDate, + ctx, accountID, req.AccountBalance, req.ExpirationDate, ) if err != nil { return nil, err @@ -133,13 +133,13 @@ func (s *RPCServer) UpdateAccount(_ context.Context, // ListAccounts returns all accounts that are currently stored in the account // database. -func (s *RPCServer) ListAccounts(context.Context, - *litrpc.ListAccountsRequest) (*litrpc.ListAccountsResponse, error) { +func (s *RPCServer) ListAccounts(ctx context.Context, + _ *litrpc.ListAccountsRequest) (*litrpc.ListAccountsResponse, error) { log.Info("[listaccounts]") // Retrieve all accounts from the macaroon account store. - accts, err := s.service.Accounts() + accts, err := s.service.Accounts(ctx) if err != nil { return nil, fmt.Errorf("unable to list accounts: %w", err) } @@ -158,17 +158,17 @@ func (s *RPCServer) ListAccounts(context.Context, } // AccountInfo returns the account with the given ID or label. -func (s *RPCServer) AccountInfo(_ context.Context, +func (s *RPCServer) AccountInfo(ctx context.Context, req *litrpc.AccountInfoRequest) (*litrpc.Account, error) { log.Infof("[accountinfo] id=%v, label=%v", req.Id, req.Label) - accountID, err := s.findAccount(req.Id, req.Label) + accountID, err := s.findAccount(ctx, req.Id, req.Label) if err != nil { return nil, err } - dbAccount, err := s.service.Account(accountID) + dbAccount, err := s.service.Account(ctx, accountID) if err != nil { return nil, fmt.Errorf("error retrieving account: %w", err) } @@ -177,19 +177,19 @@ func (s *RPCServer) AccountInfo(_ context.Context, } // RemoveAccount removes the given account from the account database. -func (s *RPCServer) RemoveAccount(_ context.Context, +func (s *RPCServer) RemoveAccount(ctx context.Context, req *litrpc.RemoveAccountRequest) (*litrpc.RemoveAccountResponse, error) { log.Infof("[removeaccount] id=%v, label=%v", req.Id, req.Label) - accountID, err := s.findAccount(req.Id, req.Label) + accountID, err := s.findAccount(ctx, req.Id, req.Label) if err != nil { return nil, err } // Now remove the account. - err = s.service.RemoveAccount(accountID) + err = s.service.RemoveAccount(ctx, accountID) if err != nil { return nil, fmt.Errorf("error removing account: %w", err) } @@ -198,7 +198,9 @@ func (s *RPCServer) RemoveAccount(_ context.Context, } // findAccount finds an account by its ID or label. -func (s *RPCServer) findAccount(id string, label string) (AccountID, error) { +func (s *RPCServer) findAccount(ctx context.Context, id string, label string) ( + AccountID, error) { + switch { case id != "" && label != "": return AccountID{}, fmt.Errorf("either account ID or label " + @@ -219,7 +221,7 @@ func (s *RPCServer) findAccount(id string, label string) (AccountID, error) { case label != "": // We need to find the account by its label. - accounts, err := s.service.Accounts() + accounts, err := s.service.Accounts(ctx) if err != nil { return AccountID{}, fmt.Errorf("unable to list "+ "accounts: %w", err) diff --git a/accounts/service.go b/accounts/service.go index 5db9ae872..820dad23e 100644 --- a/accounts/service.go +++ b/accounts/service.go @@ -78,16 +78,11 @@ type InterceptorService struct { // NewService returns a service backed by the macaroon Bolt DB stored in the // passed-in directory. -func NewService(dir string, - errCallback func(error)) (*InterceptorService, error) { - - accountStore, err := NewBoltStore(dir, DBFilename) - if err != nil { - return nil, err - } +func NewService(store Store, errCallback func(error)) (*InterceptorService, + error) { return &InterceptorService{ - store: accountStore, + store: store, invoiceToAccount: make(map[lntypes.Hash]AccountID), pendingPayments: make(map[lntypes.Hash]*trackedPayment), requestValuesStore: newRequestValuesStore(), @@ -114,7 +109,7 @@ func (s *InterceptorService) Start(ctx context.Context, // Let's first fill our cache that maps invoices to accounts, which // allows us to credit an account easily once an invoice is settled. We // also track payments that aren't in a final state yet. - existingAccounts, err := s.store.Accounts() + existingAccounts, err := s.store.Accounts(ctx) if err != nil { return s.disableAndErrorf("error querying existing "+ "accounts: %w", err) @@ -132,7 +127,7 @@ func (s *InterceptorService) Start(ctx context.Context, entry := entry if !successState(entry.Status) { err := s.TrackPayment( - acct.ID, hash, entry.FullAmount, + ctx, acct.ID, hash, entry.FullAmount, ) if err != nil { return s.disableAndErrorf("error "+ @@ -145,7 +140,7 @@ func (s *InterceptorService) Start(ctx context.Context, // First ask our DB about the highest indexes we know. If this is the // first startup then the ErrNoInvoiceIndexKnown error is returned, and // we know we need to do a lookup. - s.currentAddIndex, s.currentSettleIndex, err = s.store.LastIndexes() + s.currentAddIndex, s.currentSettleIndex, err = s.store.LastIndexes(ctx) switch err { case nil: // All good, we stored indexes in the DB, use those values. @@ -193,7 +188,8 @@ func (s *InterceptorService) Start(ctx context.Context, return } - if err := s.invoiceUpdate(invoice); err != nil { + err := s.invoiceUpdate(ctx, invoice) + if err != nil { log.Errorf("Error processing invoice "+ "update: %v", err) @@ -241,7 +237,7 @@ func (s *InterceptorService) Stop() error { close(s.quit) s.wg.Wait() - return s.store.Close() + return nil } // IsRunning checks if the account service is running, and returns a boolean @@ -289,19 +285,21 @@ func (s *InterceptorService) disableAndErrorfUnsafe(format string, // NewAccount creates a new OffChainBalanceAccount with the given balance and a // randomly chosen ID. -func (s *InterceptorService) NewAccount(balance lnwire.MilliSatoshi, +func (s *InterceptorService) NewAccount(ctx context.Context, + balance lnwire.MilliSatoshi, expirationDate time.Time, label string) (*OffChainBalanceAccount, error) { s.Lock() defer s.Unlock() - return s.store.NewAccount(balance, expirationDate, label) + return s.store.NewAccount(ctx, balance, expirationDate, label) } // UpdateAccount writes an account to the database, overwriting the existing one // if it exists. -func (s *InterceptorService) UpdateAccount(accountID AccountID, accountBalance, +func (s *InterceptorService) UpdateAccount(ctx context.Context, + accountID AccountID, accountBalance, expirationDate int64) (*OffChainBalanceAccount, error) { s.Lock() @@ -315,7 +313,7 @@ func (s *InterceptorService) UpdateAccount(accountID AccountID, accountBalance, return nil, ErrAccountServiceDisabled } - account, err := s.store.Account(accountID) + account, err := s.store.Account(ctx, accountID) if err != nil { return nil, fmt.Errorf("error fetching account: %w", err) } @@ -339,7 +337,7 @@ func (s *InterceptorService) UpdateAccount(accountID AccountID, accountBalance, } // Create the actual account in the macaroon account store. - err = s.store.UpdateAccount(account) + err = s.store.UpdateAccount(ctx, account) if err != nil { return nil, fmt.Errorf("unable to update account: %w", err) } @@ -349,25 +347,29 @@ func (s *InterceptorService) UpdateAccount(accountID AccountID, accountBalance, // Account retrieves an account from the bolt DB and un-marshals it. If the // account cannot be found, then ErrAccNotFound is returned. -func (s *InterceptorService) Account(id AccountID) (*OffChainBalanceAccount, - error) { +func (s *InterceptorService) Account(ctx context.Context, + id AccountID) (*OffChainBalanceAccount, error) { s.RLock() defer s.RUnlock() - return s.store.Account(id) + return s.store.Account(ctx, id) } // Accounts retrieves all accounts from the bolt DB and un-marshals them. -func (s *InterceptorService) Accounts() ([]*OffChainBalanceAccount, error) { +func (s *InterceptorService) Accounts(ctx context.Context) ( + []*OffChainBalanceAccount, error) { + s.RLock() defer s.RUnlock() - return s.store.Accounts() + return s.store.Accounts(ctx) } // RemoveAccount finds an account by its ID and removes it from the DB. -func (s *InterceptorService) RemoveAccount(id AccountID) error { +func (s *InterceptorService) RemoveAccount(ctx context.Context, + id AccountID) error { + s.Lock() defer s.Unlock() @@ -378,18 +380,18 @@ func (s *InterceptorService) RemoveAccount(id AccountID) error { } // Let's remove the payment (which also cancels the tracking). - err := s.removePayment(hash, lnrpc.Payment_FAILED) + err := s.removePayment(ctx, hash, lnrpc.Payment_FAILED) if err != nil { return err } } - return s.store.RemoveAccount(id) + return s.store.RemoveAccount(ctx, id) } // CheckBalance ensures an account is valid and has a balance equal to or larger // than the amount that is required. -func (s *InterceptorService) CheckBalance(id AccountID, +func (s *InterceptorService) CheckBalance(ctx context.Context, id AccountID, requiredBalance lnwire.MilliSatoshi) error { s.RLock() @@ -397,7 +399,7 @@ func (s *InterceptorService) CheckBalance(id AccountID, // Check that the account exists, it hasn't expired and has sufficient // balance. - account, err := s.store.Account(id) + account, err := s.store.Account(ctx, id) if err != nil { return err } @@ -431,26 +433,27 @@ func calcAvailableAccountBalance(account *OffChainBalanceAccount) int64 { // AssociateInvoice associates a generated invoice with the given account, // making it possible for the account to be credited in case the invoice is // paid. -func (s *InterceptorService) AssociateInvoice(id AccountID, +func (s *InterceptorService) AssociateInvoice(ctx context.Context, id AccountID, hash lntypes.Hash) error { s.Lock() defer s.Unlock() - account, err := s.store.Account(id) + account, err := s.store.Account(ctx, id) if err != nil { return err } account.Invoices[hash] = struct{}{} s.invoiceToAccount[hash] = id - return s.store.UpdateAccount(account) + + return s.store.UpdateAccount(ctx, account) } // PaymentErrored removes a pending payment from the account's registered // payment list. This should only ever be called if we are sure that the payment // request errored out. -func (s *InterceptorService) PaymentErrored(id AccountID, +func (s *InterceptorService) PaymentErrored(ctx context.Context, id AccountID, hash lntypes.Hash) error { s.Lock() @@ -464,7 +467,7 @@ func (s *InterceptorService) PaymentErrored(id AccountID, "has already started") } - account, err := s.store.Account(id) + account, err := s.store.Account(ctx, id) if err != nil { return err } @@ -479,7 +482,7 @@ func (s *InterceptorService) PaymentErrored(id AccountID, // Delete the payment and update the persisted account. delete(account.Payments, hash) - if err := s.store.UpdateAccount(account); err != nil { + if err := s.store.UpdateAccount(ctx, account); err != nil { return fmt.Errorf("error updating account: %w", err) } @@ -489,13 +492,13 @@ func (s *InterceptorService) PaymentErrored(id AccountID, // AssociatePayment associates a payment (hash) with the given account, // ensuring that the payment will be tracked for a user when LiT is // restarted. -func (s *InterceptorService) AssociatePayment(id AccountID, +func (s *InterceptorService) AssociatePayment(ctx context.Context, id AccountID, paymentHash lntypes.Hash, fullAmt lnwire.MilliSatoshi) error { s.Lock() defer s.Unlock() - account, err := s.store.Account(id) + account, err := s.store.Account(ctx, id) if err != nil { return err } @@ -528,7 +531,7 @@ func (s *InterceptorService) AssociatePayment(id AccountID, FullAmount: fullAmt, } - if err := s.store.UpdateAccount(account); err != nil { + if err := s.store.UpdateAccount(ctx, account); err != nil { return fmt.Errorf("error updating account: %w", err) } @@ -543,7 +546,9 @@ func (s *InterceptorService) AssociatePayment(id AccountID, // the same lock. Else we risk that other threads will try to update invoices // while the service should be disabled, which could lead to us missing invoice // updates on next startup. -func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { +func (s *InterceptorService) invoiceUpdate(ctx context.Context, + invoice *lndclient.Invoice) error { + s.Lock() defer s.Unlock() @@ -572,7 +577,7 @@ func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { if needUpdate { err := s.store.StoreLastIndexes( - s.currentAddIndex, s.currentSettleIndex, + ctx, s.currentAddIndex, s.currentSettleIndex, ) if err != nil { return s.disableAndErrorfUnsafe( @@ -594,7 +599,7 @@ func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { return nil } - account, err := s.store.Account(acctID) + account, err := s.store.Account(ctx, acctID) if err != nil { return s.disableAndErrorfUnsafe( "error fetching account: %w", err, @@ -605,7 +610,7 @@ func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { // it that was just paid. Credit the amount to the account and update it // in the DB. account.CurrentBalance += int64(invoice.AmountPaid) - if err := s.store.UpdateAccount(account); err != nil { + if err := s.store.UpdateAccount(ctx, account); err != nil { return s.disableAndErrorfUnsafe( "error updating account: %w", err, ) @@ -620,8 +625,8 @@ func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { // TrackPayment adds a new payment to be tracked to the service. If the payment // is eventually settled, its amount needs to be debited from the given account. -func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, - fullAmt lnwire.MilliSatoshi) error { +func (s *InterceptorService) TrackPayment(ctx context.Context, id AccountID, + hash lntypes.Hash, fullAmt lnwire.MilliSatoshi) error { s.Lock() defer s.Unlock() @@ -634,7 +639,7 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, // Similarly, if we've already processed the payment in the past, there // is a reference in the account with the given state. - account, err := s.store.Account(id) + account, err := s.store.Account(ctx, id) if err != nil { return fmt.Errorf("error fetching account: %w", err) } @@ -658,7 +663,7 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, FullAmount: fullAmt, } - if err := s.store.UpdateAccount(account); err != nil { + if err := s.store.UpdateAccount(ctx, account); err != nil { if !ok { // In the rare case that the payment isn't associated // with an account yet, and we fail to update the @@ -718,7 +723,7 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, select { case paymentUpdate := <-statusChan: terminalState, err := s.paymentUpdate( - hash, paymentUpdate, + s.mainCtx, hash, paymentUpdate, ) if err != nil { s.mainErrCallback(err) @@ -746,7 +751,7 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, // seen as in-flight balance when // calculating the account's available // balance. - err := s.RemovePayment(hash) + err := s.RemovePayment(ctx, hash) if err != nil { // We don't disable the service // here, as the worst that can @@ -789,8 +794,8 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, // NOTE: Any code that errors in this function MUST call disableAndErrorfUnsafe // while the store lock is held to ensure that the service is disabled under // the same lock. -func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, - status lndclient.PaymentStatus) (bool, error) { +func (s *InterceptorService) paymentUpdate(ctx context.Context, + hash lntypes.Hash, status lndclient.PaymentStatus) (bool, error) { // Are we still in-flight? Then we don't have to do anything just yet. // The unknown state should never happen in practice but if it ever did @@ -824,7 +829,7 @@ func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, // A failed payment can just be removed, no further action needed. if status.State == lnrpc.Payment_FAILED { - err := s.removePayment(hash, status.State) + err := s.removePayment(ctx, hash, status.State) if err != nil { err = s.disableAndErrorfUnsafe("error removing "+ "payment: %w", err) @@ -835,7 +840,7 @@ func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, // The payment went through! We now need to debit the full amount from // the account. - account, err := s.store.Account(pendingPayment.accountID) + account, err := s.store.Account(ctx, pendingPayment.accountID) if err != nil { err = s.disableAndErrorfUnsafe("error fetching account: %w", err) @@ -851,7 +856,7 @@ func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, Status: lnrpc.Payment_SUCCEEDED, FullAmount: fullAmount, } - if err := s.store.UpdateAccount(account); err != nil { + if err := s.store.UpdateAccount(ctx, account); err != nil { err = s.disableAndErrorfUnsafe("error updating account: %w", err) @@ -860,7 +865,7 @@ func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, // We've now fully processed the payment and don't need to keep it // mapped or tracked anymore. - err = s.removePayment(hash, lnrpc.Payment_SUCCEEDED) + err = s.removePayment(ctx, hash, lnrpc.Payment_SUCCEEDED) if err != nil { err = s.disableAndErrorfUnsafe("error removing payment: %w", err) @@ -872,19 +877,21 @@ func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, // RemovePayment removes a failed payment from the service because it no longer // needs to be tracked. The payment is certain to never succeed, so we never // need to debit the amount from the account. -func (s *InterceptorService) RemovePayment(hash lntypes.Hash) error { +func (s *InterceptorService) RemovePayment(ctx context.Context, + hash lntypes.Hash) error { + s.Lock() defer s.Unlock() - return s.removePayment(hash, lnrpc.Payment_FAILED) + return s.removePayment(ctx, hash, lnrpc.Payment_FAILED) } // removePayment stops tracking a payment and updates the status in the account // to the given status. // // NOTE: The store lock MUST be held when calling this method. -func (s *InterceptorService) removePayment(hash lntypes.Hash, - status lnrpc.Payment_PaymentStatus) error { +func (s *InterceptorService) removePayment(ctx context.Context, + hash lntypes.Hash, status lnrpc.Payment_PaymentStatus) error { // It could be that we haven't actually started tracking the payment // yet, so if we can't find it, we just do nothing. @@ -893,7 +900,7 @@ func (s *InterceptorService) removePayment(hash lntypes.Hash, return nil } - account, err := s.store.Account(pendingPayment.accountID) + account, err := s.store.Account(ctx, pendingPayment.accountID) if err != nil { return err } @@ -909,7 +916,7 @@ func (s *InterceptorService) removePayment(hash lntypes.Hash, // If we did, let's set the status correctly in the DB now. account.Payments[hash].Status = status - return s.store.UpdateAccount(account) + return s.store.UpdateAccount(ctx, account) } // successState returns true if a payment was completed successfully. diff --git a/accounts/service_test.go b/accounts/service_test.go index b38b119a4..2a28b9174 100644 --- a/accounts/service_test.go +++ b/accounts/service_test.go @@ -197,6 +197,7 @@ func (r *mockRouter) TrackPayment(_ context.Context, // invoices of account related calls correctly. func TestAccountService(t *testing.T) { t.Parallel() + ctx := context.Background() testCases := []struct { name string @@ -233,7 +234,7 @@ func TestAccountService(t *testing.T) { }, } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -242,7 +243,7 @@ func TestAccountService(t *testing.T) { // Start by closing the store. This should cause an // error once we make an invoice update, as the service // will fail when persisting the invoice update. - s.store.Close() + require.NoError(t, s.store.Close()) // Ensure that the service was started successfully and // still running though, despite the closing of the @@ -260,10 +261,9 @@ func TestAccountService(t *testing.T) { // Ensure that the service was eventually disabled. assertEventually(t, func() bool { - isRunning := s.IsRunning() - return isRunning == false + return !s.IsRunning() }) - lnd.assertMainErrContains(t, "database not open") + lnd.assertMainErrContains(t, ErrDBClosed.Error()) }, }, { name: "err in invoice err channel", @@ -279,7 +279,7 @@ func TestAccountService(t *testing.T) { }, } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -293,8 +293,7 @@ func TestAccountService(t *testing.T) { // Ensure that the service was eventually disabled. assertEventually(t, func() bool { - isRunning := s.IsRunning() - return isRunning == false + return !s.IsRunning() }) lnd.assertMainErrContains(t, testErr.Error()) @@ -314,7 +313,7 @@ func TestAccountService(t *testing.T) { Payments: make(AccountPayments), } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) s.mainErrCallback(testErr) @@ -331,7 +330,7 @@ func TestAccountService(t *testing.T) { s *InterceptorService) { acct, err := s.store.NewAccount( - 1234, testExpiration, "", + ctx, 1234, testExpiration, "", ) require.NoError(t, err) @@ -341,7 +340,7 @@ func TestAccountService(t *testing.T) { FullAmount: 1234, } - err = s.store.UpdateAccount(acct) + err = s.store.UpdateAccount(ctx, acct) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -373,7 +372,7 @@ func TestAccountService(t *testing.T) { }, } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) r.trackPaymentErr = testErr @@ -410,7 +409,7 @@ func TestAccountService(t *testing.T) { }, } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -439,8 +438,7 @@ func TestAccountService(t *testing.T) { // Ensure that the service was eventually disabled. assertEventually(t, func() bool { - isRunning := s.IsRunning() - return isRunning == false + return !s.IsRunning() }) lnd.assertMainErrContains( t, "not mapped to any account", @@ -463,7 +461,7 @@ func TestAccountService(t *testing.T) { }, } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -482,8 +480,7 @@ func TestAccountService(t *testing.T) { // Ensure that the service was eventually disabled. assertEventually(t, func() bool { - isRunning := s.IsRunning() - return isRunning == false + return !s.IsRunning() }) lnd.assertMainErrContains(t, testErr.Error()) @@ -516,7 +513,7 @@ func TestAccountService(t *testing.T) { }, } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -540,7 +537,7 @@ func TestAccountService(t *testing.T) { } assertEventually(t, func() bool { - acct, err := s.store.Account(testID) + acct, err := s.store.Account(ctx, testID) require.NoError(t, err) return acct.CurrentBalance == 3000 @@ -556,7 +553,7 @@ func TestAccountService(t *testing.T) { } assertEventually(t, func() bool { - acct, err := s.store.Account(testID) + acct, err := s.store.Account(ctx, testID) require.NoError(t, err) if len(acct.Payments) != 3 { @@ -582,10 +579,10 @@ func TestAccountService(t *testing.T) { // First check that the account has an available balance // of 1000. That means that the payment with testHash3 // and amount 2000 is still considered to be in-flight. - err := s.CheckBalance(testID, 1000) + err := s.CheckBalance(ctx, testID, 1000) require.NoError(t, err) - err = s.CheckBalance(testID, 1001) + err = s.CheckBalance(ctx, testID, 1001) require.ErrorIs(t, err, ErrAccBalanceInsufficient) // Now signal that the payment was non-initiated. @@ -595,8 +592,8 @@ func TestAccountService(t *testing.T) { // goroutine, and therefore free up the 2000 in-flight // balance. assertEventually(t, func() bool { - bal3000Err := s.CheckBalance(testID, 3000) - bal3001Err := s.CheckBalance(testID, 3001) + bal3000Err := s.CheckBalance(ctx, testID, 3000) + bal3001Err := s.CheckBalance(ctx, testID, 3001) require.ErrorIs( t, bal3001Err, ErrAccBalanceInsufficient, @@ -606,7 +603,7 @@ func TestAccountService(t *testing.T) { // Ensure that the payment is also set to the // failed status. - acct, err := s.store.Account(testID) + acct, err := s.store.Account(ctx, testID) require.NoError(t, err) p, ok := acct.Payments[testHash3] @@ -626,7 +623,7 @@ func TestAccountService(t *testing.T) { setup: func(t *testing.T, lnd *mockLnd, r *mockRouter, s *InterceptorService) { - err := s.store.StoreLastIndexes(987_654, 555_555) + err := s.store.StoreLastIndexes(ctx, 987_654, 555_555) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -645,7 +642,9 @@ func TestAccountService(t *testing.T) { } assertEventually(t, func() bool { - addIdx, settleIdx, err := s.store.LastIndexes() + addIdx, settleIdx, err := s.store.LastIndexes( + ctx, + ) require.NoError(t, err) if addIdx != 987_654 { @@ -662,7 +661,9 @@ func TestAccountService(t *testing.T) { } assertEventually(t, func() bool { - addIdx, settleIdx, err := s.store.LastIndexes() + addIdx, settleIdx, err := s.store.LastIndexes( + ctx, + ) require.NoError(t, err) if addIdx != 1_000_000 { @@ -688,7 +689,7 @@ func TestAccountService(t *testing.T) { Payments: make(AccountPayments), } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -705,7 +706,7 @@ func TestAccountService(t *testing.T) { // Make sure the amount paid is eventually credited. assertEventually(t, func() bool { - acct, err := s.store.Account(testID) + acct, err := s.store.Account(ctx, testID) require.NoError(t, err) return acct.CurrentBalance == 1000 @@ -723,7 +724,7 @@ func TestAccountService(t *testing.T) { // Ensure that the balance now adds up to the sum of // both invoices. assertEventually(t, func() bool { - acct, err := s.store.Account(testID) + acct, err := s.store.Account(ctx, testID) require.NoError(t, err) return acct.CurrentBalance == (1000 + 777) @@ -757,7 +758,7 @@ func TestAccountService(t *testing.T) { }, } - err := s.store.UpdateAccount(acct) + err := s.store.UpdateAccount(ctx, acct) require.NoError(t, err) // The second account has one in-flight payment of 4k @@ -777,7 +778,7 @@ func TestAccountService(t *testing.T) { }, } - err = s.store.UpdateAccount(acct2) + err = s.store.UpdateAccount(ctx, acct2) require.NoError(t, err) }, validate: func(t *testing.T, lnd *mockLnd, r *mockRouter, @@ -787,11 +788,11 @@ func TestAccountService(t *testing.T) { // with an amount smaller or equal to 2k msats. This // also asserts that the second accounts in-flight // payment doesn't affect the first account. - err := s.CheckBalance(testID, 2000) + err := s.CheckBalance(ctx, testID, 2000) require.NoError(t, err) // But exactly one sat over it should fail. - err = s.CheckBalance(testID, 2001) + err = s.CheckBalance(ctx, testID, 2001) require.ErrorIs(t, err, ErrAccBalanceInsufficient) // Remove one of the payments (to simulate it failed) @@ -802,17 +803,17 @@ func TestAccountService(t *testing.T) { // We should now have up to 4k msats available. assertEventually(t, func() bool { - err = s.CheckBalance(testID, 4000) + err = s.CheckBalance(ctx, testID, 4000) return err == nil }) // The second account should be able to initiate a // payment of 1k msats. - err = s.CheckBalance(testID2, 1000) + err = s.CheckBalance(ctx, testID2, 1000) require.NoError(t, err) // But exactly one sat over it should fail. - err = s.CheckBalance(testID2, 1001) + err = s.CheckBalance(ctx, testID2, 1001) require.ErrorIs(t, err, ErrAccBalanceInsufficient) }, }} @@ -828,7 +829,8 @@ func TestAccountService(t *testing.T) { errFunc := func(err error) { lndMock.mainErrChan <- err } - service, err := NewService(t.TempDir(), errFunc) + store := NewTestDB(t) + service, err := NewService(store, errFunc) require.NoError(t, err) // Is a setup call required to initialize initial @@ -839,8 +841,7 @@ func TestAccountService(t *testing.T) { // Any errors during startup expected? err = service.Start( - context.Background(), lndMock, routerMock, - chainParams, + ctx, lndMock, routerMock, chainParams, ) if tc.startupErr != "" { require.ErrorContains(tt, err, tc.startupErr) diff --git a/accounts/store.go b/accounts/store_kvdb.go similarity index 89% rename from accounts/store.go rename to accounts/store_kvdb.go index ebaf937be..84f22f891 100644 --- a/accounts/store.go +++ b/accounts/store_kvdb.go @@ -2,6 +2,7 @@ package accounts import ( "bytes" + "context" "crypto/rand" "encoding/binary" "encoding/hex" @@ -97,13 +98,17 @@ func NewBoltStore(dir, fileName string) (*BoltStore, error) { } // Close closes the underlying bolt DB. +// +// NOTE: This is part of the Store interface. func (s *BoltStore) Close() error { return s.db.Close() } // NewAccount creates a new OffChainBalanceAccount with the given balance and a // randomly chosen ID. -func (s *BoltStore) NewAccount(balance lnwire.MilliSatoshi, +// +// NOTE: This is part of the Store interface. +func (s *BoltStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, expirationDate time.Time, label string) (*OffChainBalanceAccount, error) { @@ -120,7 +125,7 @@ func (s *BoltStore) NewAccount(balance lnwire.MilliSatoshi, label) } - accounts, err := s.Accounts() + accounts, err := s.Accounts(ctx) if err != nil { return nil, fmt.Errorf("error checking label "+ "uniqueness: %w", err) @@ -128,7 +133,8 @@ func (s *BoltStore) NewAccount(balance lnwire.MilliSatoshi, for _, account := range accounts { if account.Label == label { return nil, fmt.Errorf("an account with the "+ - "label '%s' already exists", label) + "label '%s' already exists: %w", label, + ErrLabelAlreadyExists) } } } @@ -140,7 +146,6 @@ func (s *BoltStore) NewAccount(balance lnwire.MilliSatoshi, InitialBalance: balance, CurrentBalance: int64(balance), ExpirationDate: expirationDate, - LastUpdate: time.Now(), Invoices: make(AccountInvoices), Payments: make(AccountPayments), Label: label, @@ -174,14 +179,17 @@ func (s *BoltStore) NewAccount(balance lnwire.MilliSatoshi, // UpdateAccount writes an account to the database, overwriting the existing one // if it exists. -func (s *BoltStore) UpdateAccount(account *OffChainBalanceAccount) error { +// +// NOTE: This is part of the Store interface. +func (s *BoltStore) UpdateAccount(_ context.Context, + account *OffChainBalanceAccount) error { + return s.db.Update(func(tx kvdb.RwTx) error { bucket := tx.ReadWriteBucket(accountBucketName) if bucket == nil { return ErrAccountBucketNotFound } - account.LastUpdate = time.Now() return storeAccount(bucket, account) }, func() {}) } @@ -191,6 +199,8 @@ func (s *BoltStore) UpdateAccount(account *OffChainBalanceAccount) error { func storeAccount(accountBucket kvdb.RwBucket, account *OffChainBalanceAccount) error { + account.LastUpdate = time.Now() + accountBinary, err := serializeAccount(account) if err != nil { return err @@ -225,7 +235,11 @@ func uniqueRandomAccountID(accountBucket kvdb.RBucket) (AccountID, error) { // Account retrieves an account from the bolt DB and un-marshals it. If the // account cannot be found, then ErrAccNotFound is returned. -func (s *BoltStore) Account(id AccountID) (*OffChainBalanceAccount, error) { +// +// NOTE: This is part of the Store interface. +func (s *BoltStore) Account(_ context.Context, id AccountID) ( + *OffChainBalanceAccount, error) { + // Try looking up and reading the account by its ID from the local // bolt DB. var accountBinary []byte @@ -259,7 +273,11 @@ func (s *BoltStore) Account(id AccountID) (*OffChainBalanceAccount, error) { } // Accounts retrieves all accounts from the bolt DB and un-marshals them. -func (s *BoltStore) Accounts() ([]*OffChainBalanceAccount, error) { +// +// NOTE: This is part of the Store interface. +func (s *BoltStore) Accounts(_ context.Context) ([]*OffChainBalanceAccount, + error) { + var accounts []*OffChainBalanceAccount err := s.db.View(func(tx kvdb.RTx) error { // This function will be called in the ForEach and receive @@ -302,7 +320,9 @@ func (s *BoltStore) Accounts() ([]*OffChainBalanceAccount, error) { } // RemoveAccount finds an account by its ID and removes it from the DB. -func (s *BoltStore) RemoveAccount(id AccountID) error { +// +// NOTE: This is part of the Store interface. +func (s *BoltStore) RemoveAccount(_ context.Context, id AccountID) error { return s.db.Update(func(tx kvdb.RwTx) error { bucket := tx.ReadWriteBucket(accountBucketName) if bucket == nil { @@ -320,7 +340,9 @@ func (s *BoltStore) RemoveAccount(id AccountID) error { // LastIndexes returns the last invoice add and settle index or // ErrNoInvoiceIndexKnown if no indexes are known yet. -func (s *BoltStore) LastIndexes() (uint64, uint64, error) { +// +// NOTE: This is part of the Store interface. +func (s *BoltStore) LastIndexes(_ context.Context) (uint64, uint64, error) { var ( addValue, settleValue []byte ) @@ -352,7 +374,11 @@ func (s *BoltStore) LastIndexes() (uint64, uint64, error) { } // StoreLastIndexes stores the last invoice add and settle index. -func (s *BoltStore) StoreLastIndexes(addIndex, settleIndex uint64) error { +// +// NOTE: This is part of the Store interface. +func (s *BoltStore) StoreLastIndexes(_ context.Context, addIndex, + settleIndex uint64) error { + addValue := make([]byte, 8) settleValue := make([]byte, 8) byteOrder.PutUint64(addValue, addIndex) diff --git a/accounts/store_test.go b/accounts/store_test.go index 2f661febc..b7167c3cf 100644 --- a/accounts/store_test.go +++ b/accounts/store_test.go @@ -1,6 +1,7 @@ package accounts import ( + "context" "testing" "time" @@ -12,26 +13,26 @@ import ( // TestAccountStore tests that accounts can be stored and retrieved correctly. func TestAccountStore(t *testing.T) { t.Parallel() + ctx := context.Background() - store, err := NewBoltStore(t.TempDir(), DBFilename) - require.NoError(t, err) + store := NewTestDB(t) // Create an account that does not expire. - acct1, err := store.NewAccount(0, time.Time{}, "foo") + acct1, err := store.NewAccount(ctx, 0, time.Time{}, "foo") require.NoError(t, err) require.False(t, acct1.HasExpired()) - dbAccount, err := store.Account(acct1.ID) + dbAccount, err := store.Account(ctx, acct1.ID) require.NoError(t, err) assertEqualAccounts(t, acct1, dbAccount) // Make sure we cannot create a second account with the same label. - _, err = store.NewAccount(123, time.Time{}, "foo") - require.ErrorContains(t, err, "account with the label 'foo' already") + _, err = store.NewAccount(ctx, 123, time.Time{}, "foo") + require.ErrorIs(t, err, ErrLabelAlreadyExists) // Make sure we cannot set a label that looks like an account ID. - _, err = store.NewAccount(123, time.Time{}, "0011223344556677") + _, err = store.NewAccount(ctx, 123, time.Time{}, "0011223344556677") require.ErrorContains(t, err, "is not allowed as it can be mistaken") // Update all values of the account that we can modify. @@ -47,10 +48,10 @@ func TestAccountStore(t *testing.T) { } acct1.Invoices[lntypes.Hash{12, 34, 56, 78}] = struct{}{} acct1.Invoices[lntypes.Hash{34, 56, 78, 90}] = struct{}{} - err = store.UpdateAccount(acct1) + err = store.UpdateAccount(ctx, acct1) require.NoError(t, err) - dbAccount, err = store.Account(acct1.ID) + dbAccount, err = store.Account(ctx, acct1.ID) require.NoError(t, err) assertEqualAccounts(t, acct1, dbAccount) @@ -62,18 +63,18 @@ func TestAccountStore(t *testing.T) { require.True(t, acct1.HasExpired()) // Test listing and deleting accounts. - accounts, err := store.Accounts() + accounts, err := store.Accounts(ctx) require.NoError(t, err) require.Len(t, accounts, 1) - err = store.RemoveAccount(acct1.ID) + err = store.RemoveAccount(ctx, acct1.ID) require.NoError(t, err) - accounts, err = store.Accounts() + accounts, err = store.Accounts(ctx) require.NoError(t, err) require.Len(t, accounts, 0) - _, err = store.Account(acct1.ID) + _, err = store.Account(ctx, acct1.ID) require.ErrorIs(t, err, ErrAccNotFound) } @@ -108,16 +109,16 @@ func assertEqualAccounts(t *testing.T, expected, // stored and retrieved correctly. func TestLastInvoiceIndexes(t *testing.T) { t.Parallel() + ctx := context.Background() - store, err := NewBoltStore(t.TempDir(), DBFilename) - require.NoError(t, err) + store := NewTestDB(t) - _, _, err = store.LastIndexes() + _, _, err := store.LastIndexes(ctx) require.ErrorIs(t, err, ErrNoInvoiceIndexKnown) - require.NoError(t, store.StoreLastIndexes(7, 99)) + require.NoError(t, store.StoreLastIndexes(ctx, 7, 99)) - add, settle, err := store.LastIndexes() + add, settle, err := store.LastIndexes(ctx) require.NoError(t, err) require.EqualValues(t, 7, add) require.EqualValues(t, 99, settle) diff --git a/accounts/test_kvdb.go b/accounts/test_kvdb.go new file mode 100644 index 000000000..b050d149c --- /dev/null +++ b/accounts/test_kvdb.go @@ -0,0 +1,30 @@ +package accounts + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +// ErrDBClosed is an error that is returned when a database operation is +// performed on a closed database. +var ErrDBClosed = errors.New("database not open") + +// NewTestDB is a helper function that creates an BBolt database for testing. +func NewTestDB(t *testing.T) *BoltStore { + return NewTestDBFromPath(t, t.TempDir()) +} + +// NewTestDBFromPath is a helper function that creates a new BoltStore with a +// connection to an existing BBolt database for testing. +func NewTestDBFromPath(t *testing.T, dbPath string) *BoltStore { + store, err := NewBoltStore(dbPath, DBFilename) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, store.db.Close()) + }) + + return store +} diff --git a/terminal.go b/terminal.go index c92827ad4..84a7f42a4 100644 --- a/terminal.go +++ b/terminal.go @@ -214,6 +214,7 @@ type LightningTerminal struct { middleware *mid.Manager middlewareStarted bool + accountsStore *accounts.BoltStore accountService *accounts.InterceptorService accountServiceStarted bool @@ -412,8 +413,15 @@ func (g *LightningTerminal) start(ctx context.Context) error { ) } + g.accountsStore, err = accounts.NewBoltStore( + filepath.Dir(g.cfg.MacaroonPath), accounts.DBFilename, + ) + if err != nil { + return fmt.Errorf("error creating accounts store: %w", err) + } + g.accountService, err = accounts.NewService( - filepath.Dir(g.cfg.MacaroonPath), accountServiceErrCallback, + g.accountsStore, accountServiceErrCallback, ) if err != nil { return fmt.Errorf("error creating account service: %v", err) @@ -1421,6 +1429,14 @@ func (g *LightningTerminal) shutdownSubServers() error { } } + if g.accountsStore != nil { + err = g.accountsStore.Close() + if err != nil { + log.Errorf("Error closing accounts store: %v", err) + returnErr = err + } + } + if g.middlewareStarted { g.middleware.Stop() }