diff --git a/config_dev.go b/config_dev.go index fd62c5b86..c4ffa81d3 100644 --- a/config_dev.go +++ b/config_dev.go @@ -3,10 +3,12 @@ package terminal import ( + "fmt" "path/filepath" "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/firewalldb" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" ) @@ -87,7 +89,7 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { networkDir = filepath.Join(cfg.LitDir, cfg.Network) acctStore accounts.Store sessStore session.Store - closeFn func() error + closeFns = make(map[string]func() error) ) switch cfg.DatabaseBackend { @@ -106,7 +108,7 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { acctStore = accounts.NewSQLStore(sqlStore.BaseDB, clock) sessStore = session.NewSQLStore(sqlStore.BaseDB, clock) - closeFn = sqlStore.BaseDB.Close + closeFns["sqlite"] = sqlStore.BaseDB.Close case DatabaseBackendPostgres: sqlStore, err := db.NewPostgresStore(cfg.Postgres) @@ -116,7 +118,7 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { acctStore = accounts.NewSQLStore(sqlStore.BaseDB, clock) sessStore = session.NewSQLStore(sqlStore.BaseDB, clock) - closeFn = sqlStore.BaseDB.Close + closeFns["postgres"] = sqlStore.BaseDB.Close default: accountStore, err := accounts.NewBoltStore( @@ -126,6 +128,7 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { if err != nil { return nil, err } + closeFns["bbolt-accounts"] = accountStore.Close sessionStore, err := session.NewDB( networkDir, session.DBFilename, clock, accountStore, @@ -133,28 +136,38 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { if err != nil { return nil, err } + closeFns["bbolt-sessions"] = sessionStore.Close acctStore = accountStore sessStore = sessionStore - closeFn = func() error { - var returnErr error - err = accountStore.Close() - if err != nil { - returnErr = err - } - - err = sessionStore.Close() - if err != nil { - returnErr = err - } + } - return returnErr - } + firewallBoltDB, err := firewalldb.NewBoltDB( + networkDir, firewalldb.DBFilename, sessStore, + ) + if err != nil { + return nil, fmt.Errorf("error creating firewall BoltDB: %v", + err) } + closeFns["bbolt-firewalldb"] = firewallBoltDB.Close return &stores{ - accounts: acctStore, - sessions: sessStore, - close: closeFn, + accounts: acctStore, + sessions: sessStore, + firewall: firewalldb.NewDB(firewallBoltDB), + firewallBolt: firewallBoltDB, + close: func() error { + var returnErr error + for storeName, fn := range closeFns { + err := fn() + if err != nil { + log.Errorf("error closing %s store: %v", + storeName, err) + returnErr = err + } + } + + return returnErr + }, }, nil } diff --git a/config_prod.go b/config_prod.go index 4385c2b64..5493d2b3a 100644 --- a/config_prod.go +++ b/config_prod.go @@ -7,6 +7,7 @@ import ( "path/filepath" "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/firewalldb" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" ) @@ -46,18 +47,37 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { err) } + firewallDB, err := firewalldb.NewBoltDB( + networkDir, firewalldb.DBFilename, sessStore, + ) + if err != nil { + return nil, fmt.Errorf("error creating firewall DB: %v", err) + } + return &stores{ - accounts: acctStore, - sessions: sessStore, + accounts: acctStore, + sessions: sessStore, + firewallBolt: firewallDB, + firewall: firewalldb.NewDB(firewallDB), close: func() error { var returnErr error if err := acctStore.Close(); err != nil { returnErr = fmt.Errorf("error closing "+ "account store: %v", err) + + log.Error(returnErr.Error()) } if err := sessStore.Close(); err != nil { returnErr = fmt.Errorf("error closing "+ "session store: %v", err) + + log.Error(returnErr.Error()) + } + if err := firewallDB.Close(); err != nil { + returnErr = fmt.Errorf("error closing "+ + "firewall DB: %v", err) + + log.Error(returnErr.Error()) } return returnErr diff --git a/firewalldb/actions.go b/firewalldb/actions.go index 657f19cea..72e141556 100644 --- a/firewalldb/actions.go +++ b/firewalldb/actions.go @@ -117,7 +117,9 @@ type Action struct { } // AddAction serialises and adds an Action to the DB under the given sessionID. -func (db *DB) AddAction(sessionID session.ID, action *Action) (uint64, error) { +func (db *BoltDB) AddAction(sessionID session.ID, action *Action) (uint64, + error) { + var buf bytes.Buffer if err := SerializeAction(&buf, action); err != nil { return 0, err @@ -231,7 +233,7 @@ func getAction(actionsBkt *bbolt.Bucket, al *ActionLocator) (*Action, error) { // SetActionState finds the action specified by the ActionLocator and sets its // state to the given state. -func (db *DB) SetActionState(al *ActionLocator, state ActionState, +func (db *BoltDB) SetActionState(al *ActionLocator, state ActionState, errorReason string) error { if errorReason != "" && state != ActionStateError { @@ -293,7 +295,7 @@ type ListActionsFilterFn func(a *Action, reversed bool) (bool, bool) // The indexOffset and maxNum params can be used to control the number of // actions returned. The return values are the list of actions, the last index // and the total count (iff query.CountTotal is set). -func (db *DB) ListActions(filterFn ListActionsFilterFn, +func (db *BoltDB) ListActions(filterFn ListActionsFilterFn, query *ListActionsQuery) ([]*Action, uint64, uint64, error) { var ( @@ -345,7 +347,7 @@ func (db *DB) ListActions(filterFn ListActionsFilterFn, // ListSessionActions returns a list of the given session's Actions that pass // the filterFn requirements. -func (db *DB) ListSessionActions(sessionID session.ID, +func (db *BoltDB) ListSessionActions(sessionID session.ID, filterFn ListActionsFilterFn, query *ListActionsQuery) ([]*Action, uint64, uint64, error) { @@ -391,7 +393,7 @@ func (db *DB) ListSessionActions(sessionID session.ID, // pass the filterFn requirements. // // TODO: update to allow for pagination. -func (db *DB) ListGroupActions(ctx context.Context, groupID session.ID, +func (db *BoltDB) ListGroupActions(ctx context.Context, groupID session.ID, filterFn ListActionsFilterFn) ([]*Action, error) { if filterFn == nil { @@ -589,7 +591,7 @@ type ActionReadDBGetter interface { } // GetActionsReadDB is a method on DB that constructs an ActionsReadDB. -func (db *DB) GetActionsReadDB(groupID session.ID, +func (db *BoltDB) GetActionsReadDB(groupID session.ID, featureName string) ActionsReadDB { return &allActionsReadDB{ @@ -601,7 +603,7 @@ func (db *DB) GetActionsReadDB(groupID session.ID, // allActionsReadDb is an implementation of the ActionsReadDB. type allActionsReadDB struct { - db *DB + db *BoltDB groupID session.ID featureName string } diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go index 8b66529f4..63a77ec45 100644 --- a/firewalldb/actions_test.go +++ b/firewalldb/actions_test.go @@ -43,7 +43,7 @@ var ( func TestActionStorage(t *testing.T) { tmpDir := t.TempDir() - db, err := NewDB(tmpDir, "test.db", nil) + db, err := NewBoltDB(tmpDir, "test.db", nil) require.NoError(t, err) t.Cleanup(func() { _ = db.Close() @@ -151,7 +151,7 @@ func TestActionStorage(t *testing.T) { func TestListActions(t *testing.T) { tmpDir := t.TempDir() - db, err := NewDB(tmpDir, "test.db", nil) + db, err := NewBoltDB(tmpDir, "test.db", nil) require.NoError(t, err) t.Cleanup(func() { _ = db.Close() @@ -353,7 +353,7 @@ func TestListGroupActions(t *testing.T) { index.AddPair(sessionID1, group1) index.AddPair(sessionID2, group1) - db, err := NewDB(t.TempDir(), "test.db", index) + db, err := NewBoltDB(t.TempDir(), "test.db", index) require.NoError(t, err) t.Cleanup(func() { _ = db.Close() diff --git a/firewalldb/db.go b/firewalldb/db.go index 2c5a1d7a8..fe18cbb70 100644 --- a/firewalldb/db.go +++ b/firewalldb/db.go @@ -1,152 +1,50 @@ package firewalldb import ( - "encoding/binary" - "errors" + "context" "fmt" - "os" - "path/filepath" - "time" + "sync" - "go.etcd.io/bbolt" -) - -const ( - // DBFilename is the default filename of the rules' database. - DBFilename = "rules.db" - - // dbFilePermission is the default permission the rules' database file - // is created with. - dbFilePermission = 0600 - - // DefaultRulesDBTimeout is the default maximum time we wait for the - // db bbolt database to be opened. If the database is already - // opened by another process, the unique lock cannot be obtained. With - // the timeout we error out after the given time instead of just - // blocking for forever. - DefaultRulesDBTimeout = 5 * time.Second + "github.com/lightningnetwork/lnd/fn" ) var ( - // byteOrder is the default byte order we'll use for serialization - // within the database. - byteOrder = binary.BigEndian - // ErrNoSuchKeyFound is returned when there is no key-value pair found // for the given key. ErrNoSuchKeyFound = fmt.Errorf("no such key found") ) -// DB is a bolt-backed persistent store. +// DB manages the firewall rules database. type DB struct { - *bbolt.DB - - sessionIDIndex SessionDB -} + started sync.Once + stopped sync.Once -// NewDB creates a new bolt database that can be found at the given directory. -func NewDB(dir, fileName string, sessionIDIndex SessionDB) (*DB, error) { - firstInit := false - path := filepath.Join(dir, fileName) - - // If the database file does not exist yet, create its directory. - if !fileExists(path) { - if err := os.MkdirAll(dir, 0700); err != nil { - return nil, err - } - firstInit = true - } - - db, err := initDB(path, firstInit) - if err != nil { - return nil, err - } - - // Attempt to sync the database's current version with the latest known - // version available. - if err := syncVersions(db); err != nil { - return nil, err - } + RulesDB - return &DB{ - DB: db, - sessionIDIndex: sessionIDIndex, - }, nil + cancel fn.Option[context.CancelFunc] } -// fileExists reports whether the named file or directory exists. -func fileExists(path string) bool { - if _, err := os.Stat(path); err != nil { - if os.IsNotExist(err) { - return false - } +// NewDB creates a new firewall database. For now, it only contains the +// underlying rules' database. +func NewDB(kvdb RulesDB) *DB { + return &DB{ + RulesDB: kvdb, } - return true } -// initDB initializes all the required top-level buckets for the database. -func initDB(filepath string, firstInit bool) (*bbolt.DB, error) { - db, err := bbolt.Open(filepath, dbFilePermission, &bbolt.Options{ - Timeout: DefaultRulesDBTimeout, +// Start starts the firewall database. +func (db *DB) Start(ctx context.Context) error { + db.started.Do(func() { + _, cancel := context.WithCancel(ctx) + db.cancel = fn.Some(cancel) }) - if err == bbolt.ErrTimeout { - return nil, fmt.Errorf("error while trying to open %s: timed "+ - "out after %v when trying to obtain exclusive lock", - filepath, DefaultRulesDBTimeout) - } - if err != nil { - return nil, err - } - - err = db.Update(func(tx *bbolt.Tx) error { - if firstInit { - metadataBucket, err := tx.CreateBucketIfNotExists( - metadataBucketKey, - ) - if err != nil { - return err - } - err = setDBVersion(metadataBucket, latestDBVersion) - if err != nil { - return err - } - } - - rulesBucket, err := tx.CreateBucketIfNotExists(rulesBucketKey) - if err != nil { - return err - } - // Delete everything under the "temp" key if such a bucket - // exists. - err = rulesBucket.DeleteBucket(tempBucketKey) - if err != nil && !errors.Is(err, bbolt.ErrBucketNotFound) { - return err - } - - actionsBucket, err := tx.CreateBucketIfNotExists( - actionsBucketKey, - ) - if err != nil { - return err - } - - _, err = actionsBucket.CreateBucketIfNotExists(actionsKey) - if err != nil { - return err - } - - _, err = actionsBucket.CreateBucketIfNotExists(actionsIndex) - if err != nil { - return err - } + return db.DeleteTempKVStores(ctx) +} - _, err = tx.CreateBucketIfNotExists(privacyBucketKey) - return err - }) - if err != nil { - return nil, err - } +// Stop stops the firewall database operations. +func (db *DB) Stop() error { + db.stopped.Do(func() {}) - return db, nil + return nil } diff --git a/firewalldb/interface.go b/firewalldb/interface.go index 1dc6951b7..3a0c4ddca 100644 --- a/firewalldb/interface.go +++ b/firewalldb/interface.go @@ -35,3 +35,60 @@ type DBExecutor[T any] interface { View(ctx context.Context, f func(ctx context.Context, tx T) error) error } + +// KVStores provides an Update and View method that will allow the caller to +// perform atomic read and write transactions on and of the key value stores +// offered the KVStoreTx. +type KVStores = DBExecutor[KVStoreTx] + +// KVStoreTx represents a database transaction that can be used for both read +// and writes of the various different key value stores offered for the rule. +type KVStoreTx interface { + // Global returns a persisted global, rule-name indexed, kv store. A + // rule with a given name will have access to this store independent of + // group ID or feature. + Global() KVStore + + // Local returns a persisted local kv store for the rule. Depending on + // how the implementation is initialised, this will either be under the + // group ID namespace or the group ID _and_ feature name namespace. + Local() KVStore + + // GlobalTemp is similar to the Global store except that its contents + // is cleared upon restart of the database. The reason persisting the + // temporary store changes instead of just keeping an in-memory store is + // that we can then guarantee atomicity if changes are made to both + // the permanent and temporary stores. + GlobalTemp() KVStore + + // LocalTemp is similar to the Local store except that its contents is + // cleared upon restart of the database. The reason persisting the + // temporary store changes instead of just keeping an in-memory store is + // that we can then guarantee atomicity if changes are made to both + // the permanent and temporary stores. + LocalTemp() KVStore +} + +// KVStore is in interface representing a key value store. It allows us to +// abstract away the details of the data storage method. +type KVStore interface { + // Get fetches the value under the given key from the underlying kv + // store. If no value is found, nil is returned. + Get(ctx context.Context, key string) ([]byte, error) + + // Set sets the given key-value pair in the underlying kv store. + Set(ctx context.Context, key string, value []byte) error + + // Del deletes the value under the given key in the underlying kv store. + Del(ctx context.Context, key string) error +} + +// RulesDB can be used to initialise a new rules.KVStores. +type RulesDB interface { + // GetKVStores constructs a new rules.KVStores in a namespace defined + // by the rule name, group ID and feature name. + GetKVStores(rule string, groupID session.ID, feature string) KVStores + + // DeleteTempKVStores deletes all temporary kv stores. + DeleteTempKVStores(ctx context.Context) error +} diff --git a/firewalldb/kvdb_store.go b/firewalldb/kvdb_store.go index d4ce79f20..99497a27d 100644 --- a/firewalldb/kvdb_store.go +++ b/firewalldb/kvdb_store.go @@ -2,10 +2,147 @@ package firewalldb import ( "context" + "encoding/binary" + "fmt" + "os" + "path/filepath" + "time" "go.etcd.io/bbolt" ) +const ( + // DBFilename is the default filename of the rules' database. + DBFilename = "rules.db" + + // dbFilePermission is the default permission the rules' database file + // is created with. + dbFilePermission = 0600 + + // DefaultRulesDBTimeout is the default maximum time we wait for the + // db bbolt database to be opened. If the database is already + // opened by another process, the unique lock cannot be obtained. With + // the timeout we error out after the given time instead of just + // blocking for forever. + DefaultRulesDBTimeout = 5 * time.Second +) + +var ( + // byteOrder is the default byte order we'll use for serialization + // within the database. + byteOrder = binary.BigEndian +) + +// BoltDB is a bolt-backed persistent store. +type BoltDB struct { + *bbolt.DB + + sessionIDIndex SessionDB +} + +// NewBoltDB creates a new bolt database that can be found at the given +// directory. +func NewBoltDB(dir, fileName string, sessionIDIndex SessionDB) (*BoltDB, + error) { + + firstInit := false + path := filepath.Join(dir, fileName) + + // If the database file does not exist yet, create its directory. + if !fileExists(path) { + if err := os.MkdirAll(dir, 0700); err != nil { + return nil, err + } + firstInit = true + } + + db, err := initDB(path, firstInit) + if err != nil { + return nil, err + } + + // Attempt to sync the database's current version with the latest known + // version available. + if err := syncVersions(db); err != nil { + return nil, err + } + + return &BoltDB{ + DB: db, + sessionIDIndex: sessionIDIndex, + }, nil +} + +// fileExists reports whether the named file or directory exists. +func fileExists(path string) bool { + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + return false + } + } + return true +} + +// initDB initializes all the required top-level buckets for the database. +func initDB(filepath string, firstInit bool) (*bbolt.DB, error) { + db, err := bbolt.Open(filepath, dbFilePermission, &bbolt.Options{ + Timeout: DefaultRulesDBTimeout, + }) + if err == bbolt.ErrTimeout { + return nil, fmt.Errorf("error while trying to open %s: timed "+ + "out after %v when trying to obtain exclusive lock", + filepath, DefaultRulesDBTimeout) + } + if err != nil { + return nil, err + } + + err = db.Update(func(tx *bbolt.Tx) error { + if firstInit { + metadataBucket, err := tx.CreateBucketIfNotExists( + metadataBucketKey, + ) + if err != nil { + return err + } + err = setDBVersion(metadataBucket, latestDBVersion) + if err != nil { + return err + } + } + + _, err := tx.CreateBucketIfNotExists(rulesBucketKey) + if err != nil { + return err + } + + actionsBucket, err := tx.CreateBucketIfNotExists( + actionsBucketKey, + ) + if err != nil { + return err + } + + _, err = actionsBucket.CreateBucketIfNotExists(actionsKey) + if err != nil { + return err + } + + _, err = actionsBucket.CreateBucketIfNotExists(actionsIndex) + if err != nil { + return err + } + + _, err = tx.CreateBucketIfNotExists(privacyBucketKey) + return err + }) + if err != nil { + return nil, err + } + + return db, nil +} + // kvdbExecutor is a concrete implementation of the DBExecutor interface that // uses a bbolt database as its backing store. type kvdbExecutor[T any] struct { diff --git a/firewalldb/kvstores.go b/firewalldb/kvstores_kvdb.go similarity index 79% rename from firewalldb/kvstores.go rename to firewalldb/kvstores_kvdb.go index 9dad0a0cc..a7b9d4765 100644 --- a/firewalldb/kvstores.go +++ b/firewalldb/kvstores_kvdb.go @@ -2,6 +2,7 @@ package firewalldb import ( "context" + "errors" "github.com/lightninglabs/lightning-terminal/session" "go.etcd.io/bbolt" @@ -51,60 +52,8 @@ var ( featureKVStoreBucketKey = []byte("feature-kv-store") ) -// KVStores provides an Update and View method that will allow the caller to -// perform atomic read and write transactions on and of the key value stores -// offered the KVStoreTx. -type KVStores = DBExecutor[KVStoreTx] - -// KVStoreTx represents a database transaction that can be used for both read -// and writes of the various different key value stores offered for the rule. -type KVStoreTx interface { - // Global returns a persisted global, rule-name indexed, kv store. A - // rule with a given name will have access to this store independent of - // group ID or feature. - Global() KVStore - - // Local returns a persisted local kv store for the rule. Depending on - // how the implementation is initialised, this will either be under the - // group ID namespace or the group ID _and_ feature name namespace. - Local() KVStore - - // GlobalTemp is similar to the Global store except that its contents - // is cleared upon restart of the database. The reason persisting the - // temporary store changes instead of just keeping an in-memory store is - // that we can then guarantee atomicity if changes are made to both - // the permanent and temporary stores. - GlobalTemp() KVStore - - // LocalTemp is similar to the Local store except that its contents is - // cleared upon restart of the database. The reason persisting the - // temporary store changes instead of just keeping an in-memory store is - // that we can then guarantee atomicity if changes are made to both - // the permanent and temporary stores. - LocalTemp() KVStore -} - -// KVStore is in interface representing a key value store. It allows us to -// abstract away the details of the data storage method. -type KVStore interface { - // Get fetches the value under the given key from the underlying kv - // store. If no value is found, nil is returned. - Get(ctx context.Context, key string) ([]byte, error) - - // Set sets the given key-value pair in the underlying kv store. - Set(ctx context.Context, key string, value []byte) error - - // Del deletes the value under the given key in the underlying kv store. - Del(ctx context.Context, key string) error -} - -// RulesDB can be used to initialise a new rules.KVStores. -type RulesDB interface { - GetKVStores(rule string, groupID session.ID, feature string) KVStores -} - // GetKVStores constructs a new rules.KVStores backed by a bbolt db. -func (db *DB) GetKVStores(rule string, groupID session.ID, +func (db *BoltDB) GetKVStores(rule string, groupID session.ID, feature string) KVStores { return &kvdbExecutor[KVStoreTx]{ @@ -122,6 +71,25 @@ func (db *DB) GetKVStores(rule string, groupID session.ID, } } +// DeleteTempKVStores deletes all kv-stores in the temporary namespace. +func (db *BoltDB) DeleteTempKVStores(_ context.Context) error { + return db.Update(func(tx *bbolt.Tx) error { + rulesBucket, err := tx.CreateBucketIfNotExists(rulesBucketKey) + if err != nil { + return err + } + + // Delete everything under the "temp" key if such a bucket + // exists. + err = rulesBucket.DeleteBucket(tempBucketKey) + if err != nil && !errors.Is(err, bbolt.ErrBucketNotFound) { + return err + } + + return nil + }) +} + // kvStores implements the rules.KVStores interface. type kvStores struct { ruleName string diff --git a/firewalldb/kvstores_test.go b/firewalldb/kvstores_test.go index 252ccfde4..0742a40a7 100644 --- a/firewalldb/kvstores_test.go +++ b/firewalldb/kvstores_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "fmt" - "os" "testing" "github.com/lightninglabs/lightning-terminal/session" @@ -15,20 +14,15 @@ import ( // atomic access to the db. If anything fails in the middle of an `Update` // function, then all the changes prior should be rolled back. func TestKVStoreTxs(t *testing.T) { - ctx := context.Background() - tmpDir := t.TempDir() - - db, err := NewDB(tmpDir, "test.db", nil) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) + t.Parallel() + ctx := context.Background() + db := NewTestDB(t) store := db.GetKVStores("AutoFees", [4]byte{1, 1, 1, 1}, "auto-fees") // Test that if an action fails midway through the transaction, then // it is rolled back. - err = store.Update(ctx, func(ctx context.Context, tx KVStoreTx) error { + err := store.Update(ctx, func(ctx context.Context, tx KVStoreTx) error { err := tx.Global().Set(ctx, "test", []byte{1}) if err != nil { return err @@ -64,10 +58,14 @@ func TestKVStoreTxs(t *testing.T) { // KV stores and the session feature level stores. func TestTempAndPermStores(t *testing.T) { t.Run("session level kv store", func(t *testing.T) { + t.Parallel() + testTempAndPermStores(t, false) }) t.Run("session feature level kv store", func(t *testing.T) { + t.Parallel() + testTempAndPermStores(t, true) }) } @@ -79,22 +77,23 @@ func TestTempAndPermStores(t *testing.T) { // session level KV stores. func testTempAndPermStores(t *testing.T, featureSpecificStore bool) { ctx := context.Background() - tmpDir := t.TempDir() var featureName string if featureSpecificStore { featureName = "auto-fees" } - db, err := NewDB(tmpDir, "test.db", nil) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) + store := NewTestDB(t) + db := NewDB(store) + require.NoError(t, db.Start(ctx)) + + kvstores := db.GetKVStores( + "test-rule", [4]byte{1, 1, 1, 1}, featureName, + ) - store := db.GetKVStores("test-rule", [4]byte{1, 1, 1, 1}, featureName) + err := kvstores.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { - err = store.Update(ctx, func(ctx context.Context, tx KVStoreTx) error { // Set an item in the temp store. err := tx.LocalTemp().Set(ctx, "test", []byte{4, 3, 2}) if err != nil { @@ -112,7 +111,7 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) { v1 []byte v2 []byte ) - err = store.View(ctx, func(ctx context.Context, tx KVStoreTx) error { + err = kvstores.View(ctx, func(ctx context.Context, tx KVStoreTx) error { b, err := tx.LocalTemp().Get(ctx, "test") if err != nil { return err @@ -130,21 +129,19 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) { require.True(t, bytes.Equal(v1, []byte{4, 3, 2})) require.True(t, bytes.Equal(v2, []byte{6, 5, 4})) - // Close the db. - require.NoError(t, db.Close()) - - // Restart it. - db, err = NewDB(tmpDir, "test.db", nil) - require.NoError(t, err) + // Re-init the DB. + require.NoError(t, db.Stop()) + db = NewDB(store) + require.NoError(t, db.Start(ctx)) t.Cleanup(func() { - _ = db.Close() - _ = os.RemoveAll(tmpDir) + require.NoError(t, db.Stop()) }) - store = db.GetKVStores("test-rule", [4]byte{1, 1, 1, 1}, featureName) + + kvstores = db.GetKVStores("test-rule", [4]byte{1, 1, 1, 1}, featureName) // The temp store should no longer have the stored value but the perm // store should . - err = store.View(ctx, func(ctx context.Context, tx KVStoreTx) error { + err = kvstores.View(ctx, func(ctx context.Context, tx KVStoreTx) error { b, err := tx.LocalTemp().Get(ctx, "test") if err != nil { return err @@ -165,14 +162,9 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) { // TestKVStoreNameSpaces tests that the various name spaces are used correctly. func TestKVStoreNameSpaces(t *testing.T) { + t.Parallel() ctx := context.Background() - tmpDir := t.TempDir() - - db, err := NewDB(tmpDir, "test.db", nil) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) + db := NewTestDB(t) var ( groupID1 = intToSessionID(1) @@ -188,7 +180,7 @@ func TestKVStoreNameSpaces(t *testing.T) { rulesDB3 := db.GetKVStores("test-rule", groupID2, "re-balance") // Test that the three ruleDBs share the same global space. - err = rulesDB1.Update(ctx, func(ctx context.Context, + err := rulesDB1.Update(ctx, func(ctx context.Context, tx KVStoreTx) error { return tx.Global().Set( diff --git a/firewalldb/privacy_mapper.go b/firewalldb/privacy_mapper.go index ab8e60e40..428169519 100644 --- a/firewalldb/privacy_mapper.go +++ b/firewalldb/privacy_mapper.go @@ -41,7 +41,7 @@ type NewPrivacyMapDB func(groupID session.ID) PrivacyMapDB // PrivacyDB constructs a PrivacyMapDB that will be indexed under the given // group ID key. -func (db *DB) PrivacyDB(groupID session.ID) PrivacyMapDB { +func (db *BoltDB) PrivacyDB(groupID session.ID) PrivacyMapDB { return &kvdbExecutor[PrivacyMapTx]{ db: db.DB, wrapTx: func(tx *bbolt.Tx) PrivacyMapTx { diff --git a/firewalldb/privacy_mapper_test.go b/firewalldb/privacy_mapper_test.go index 7be4d3b64..8242f2da6 100644 --- a/firewalldb/privacy_mapper_test.go +++ b/firewalldb/privacy_mapper_test.go @@ -14,7 +14,7 @@ func TestPrivacyMapStorage(t *testing.T) { ctx := context.Background() tmpDir := t.TempDir() - db, err := NewDB(tmpDir, "test.db", nil) + db, err := NewBoltDB(tmpDir, "test.db", nil) require.NoError(t, err) t.Cleanup(func() { _ = db.Close() @@ -188,7 +188,7 @@ func TestPrivacyMapTxs(t *testing.T) { ctx := context.Background() tmpDir := t.TempDir() - db, err := NewDB(tmpDir, "test.db", nil) + db, err := NewBoltDB(tmpDir, "test.db", nil) require.NoError(t, err) t.Cleanup(func() { _ = db.Close() diff --git a/firewalldb/test_kvdb.go b/firewalldb/test_kvdb.go new file mode 100644 index 000000000..0757786eb --- /dev/null +++ b/firewalldb/test_kvdb.go @@ -0,0 +1,25 @@ +package firewalldb + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// NewTestDB is a helper function that creates an BBolt database for testing. +func NewTestDB(t *testing.T) *BoltDB { + return NewTestDBFromPath(t, t.TempDir()) +} + +// NewTestDBFromPath is a helper function that creates a new BoltStore with a +// connection to an existing BBolt database for testing. +func NewTestDBFromPath(t *testing.T, dbPath string) *BoltDB { + store, err := NewBoltDB(dbPath, DBFilename, nil) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, store.DB.Close()) + }) + + return store +} diff --git a/session_rpcserver.go b/session_rpcserver.go index 652196f59..b20700948 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -63,7 +63,7 @@ type sessionRpcServerConfig struct { superMacBaker litmac.Baker firstConnectionDeadline time.Duration permMgr *perms.Manager - actionsDB *firewalldb.DB + actionsDB *firewalldb.BoltDB autopilot autopilotserver.Autopilot ruleMgrs rules.ManagerSet privMap firewalldb.NewPrivacyMapDB diff --git a/terminal.go b/terminal.go index 37dcc48b4..4992f4158 100644 --- a/terminal.go +++ b/terminal.go @@ -223,8 +223,6 @@ type LightningTerminal struct { stores *stores - firewallDB *firewalldb.DB - restHandler http.Handler restCancel func() } @@ -241,6 +239,9 @@ type stores struct { accounts accounts.Store sessions session.Store + firewall *firewalldb.DB + firewallBolt *firewalldb.BoltDB + // close is a callback that can be used to close all the stores in the // stores struct. close func() error @@ -436,6 +437,10 @@ func (g *LightningTerminal) start(ctx context.Context) error { return fmt.Errorf("could not create stores: %v", err) } + if err := g.stores.firewall.Start(ctx); err != nil { + return fmt.Errorf("could not start firewall DB: %v", err) + } + g.accountService, err = accounts.NewService( g.stores.accounts, accountServiceErrCallback, ) @@ -457,13 +462,6 @@ func (g *LightningTerminal) start(ctx context.Context) error { g.ruleMgrs = rules.NewRuleManagerSet() - g.firewallDB, err = firewalldb.NewDB( - networkDir, firewalldb.DBFilename, g.stores.sessions, - ) - if err != nil { - return fmt.Errorf("error creating firewall DB: %v", err) - } - if !g.cfg.Autopilot.Disable { if g.cfg.Autopilot.Address == "" && len(g.cfg.Autopilot.DialOpts) == 0 { @@ -517,10 +515,10 @@ func (g *LightningTerminal) start(ctx context.Context) error { superMacBaker: superMacBaker, firstConnectionDeadline: g.cfg.FirstLNCConnDeadline, permMgr: g.permsMgr, - actionsDB: g.firewallDB, + actionsDB: g.stores.firewallBolt, autopilot: g.autopilotClient, ruleMgrs: g.ruleMgrs, - privMap: g.firewallDB.PrivacyDB, + privMap: g.stores.firewallBolt.PrivacyDB, }) if err != nil { return fmt.Errorf("could not create new session rpc "+ @@ -1079,14 +1077,14 @@ func (g *LightningTerminal) startInternalSubServers(ctx context.Context, } requestLogger, err := firewall.NewRequestLogger( - g.cfg.Firewall.RequestLogger, g.firewallDB, + g.cfg.Firewall.RequestLogger, g.stores.firewallBolt, ) if err != nil { return fmt.Errorf("error creating new request logger") } privacyMapper := firewall.NewPrivacyMapper( - g.firewallDB.PrivacyDB, firewall.CryptoRandIntn, + g.stores.firewallBolt.PrivacyDB, firewall.CryptoRandIntn, g.stores.sessions, ) @@ -1098,7 +1096,8 @@ func (g *LightningTerminal) startInternalSubServers(ctx context.Context, if !g.cfg.Autopilot.Disable { ruleEnforcer := firewall.NewRuleEnforcer( - g.firewallDB, g.firewallDB, g.stores.sessions, + g.stores.firewall, g.stores.firewallBolt, + g.stores.sessions, g.autopilotClient.ListFeaturePerms, g.permsMgr, g.lndClient.NodePubkey, g.lndClient.Router, @@ -1108,7 +1107,7 @@ func (g *LightningTerminal) startInternalSubServers(ctx context.Context, reqID, firewalldb.ActionStateError, reason, ) - }, g.firewallDB.PrivacyDB, + }, g.stores.firewallBolt.PrivacyDB, ) mw = append(mw, ruleEnforcer) @@ -1443,13 +1442,6 @@ func (g *LightningTerminal) shutdownSubServers() error { g.middleware.Stop() } - if g.firewallDB != nil { - if err := g.firewallDB.Close(); err != nil { - log.Errorf("Error closing rules DB: %v", err) - returnErr = err - } - } - if g.ruleMgrs != nil { if err := g.ruleMgrs.Stop(); err != nil { log.Errorf("Error stopping rule manager set: %v", err) @@ -1458,6 +1450,11 @@ func (g *LightningTerminal) shutdownSubServers() error { } if g.stores != nil { + if err := g.stores.firewall.Stop(); err != nil { + log.Errorf("Error stoppint firewall DB: %v", err) + returnErr = err + } + err = g.stores.close() if err != nil { log.Errorf("Error closing stores: %v", err)