From 9316a591a468865254855806f4abff2fb2b42e61 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 11 Mar 2025 14:31:18 -0500 Subject: [PATCH 1/3] firewalldb+rules: thread contexts through KVStores methods --- firewalldb/kvstores.go | 18 +++++--- firewalldb/kvstores_test.go | 82 ++++++++++++++++++++++++++----------- rules/mock.go | 12 ++++-- rules/onchain_budget.go | 12 ++++-- 4 files changed, 88 insertions(+), 36 deletions(-) diff --git a/firewalldb/kvstores.go b/firewalldb/kvstores.go index 26ced86e7..5a1b9b180 100644 --- a/firewalldb/kvstores.go +++ b/firewalldb/kvstores.go @@ -61,13 +61,15 @@ type KVStores interface { // error, the transaction is rolled back. If the rollback fails, the // original error returned by f is still returned. If the commit fails, // the commit error is returned. - Update(f func(tx KVStoreTx) error) error + Update(ctx context.Context, f func(ctx context.Context, + tx KVStoreTx) error) error // View opens a database read transaction and executes the function f // with the transaction passed as a parameter. After f exits, the // transaction is rolled back. If f errors, its error is returned, not a // rollback error (if any occur). - View(f func(tx KVStoreTx) error) error + View(ctx context.Context, f func(ctx context.Context, + tx KVStoreTx) error) error } // KVStoreTx represents a database transaction that can be used for both read @@ -158,7 +160,9 @@ func (s *kvStores) beginTx(writable bool) (*kvStoreTx, error) { // returned. // // NOTE: this is part of the KVStores interface. -func (s *kvStores) Update(f func(tx KVStoreTx) error) error { +func (s *kvStores) Update(ctx context.Context, f func(ctx context.Context, + tx KVStoreTx) error) error { + tx, err := s.beginTx(true) if err != nil { return err @@ -171,7 +175,7 @@ func (s *kvStores) Update(f func(tx KVStoreTx) error) error { } }() - err = f(tx) + err = f(ctx, tx) if err != nil { // Want to return the original error, not a rollback error if // any occur. @@ -188,7 +192,9 @@ func (s *kvStores) Update(f func(tx KVStoreTx) error) error { // occur). // // NOTE: this is part of the KVStores interface. -func (s *kvStores) View(f func(tx KVStoreTx) error) error { +func (s *kvStores) View(ctx context.Context, f func(ctx context.Context, + tx KVStoreTx) error) error { + tx, err := s.beginTx(false) if err != nil { return err @@ -201,7 +207,7 @@ func (s *kvStores) View(f func(tx KVStoreTx) error) error { } }() - err = f(tx) + err = f(ctx, tx) rollbackErr := tx.boltTx.Rollback() if err != nil { return err diff --git a/firewalldb/kvstores_test.go b/firewalldb/kvstores_test.go index 86f650f76..252ccfde4 100644 --- a/firewalldb/kvstores_test.go +++ b/firewalldb/kvstores_test.go @@ -28,7 +28,7 @@ func TestKVStoreTxs(t *testing.T) { // Test that if an action fails midway through the transaction, then // it is rolled back. - err = store.Update(func(tx KVStoreTx) error { + err = store.Update(ctx, func(ctx context.Context, tx KVStoreTx) error { err := tx.Global().Set(ctx, "test", []byte{1}) if err != nil { return err @@ -46,7 +46,7 @@ func TestKVStoreTxs(t *testing.T) { require.Error(t, err) var v []byte - err = store.View(func(tx KVStoreTx) error { + err = store.View(ctx, func(ctx context.Context, tx KVStoreTx) error { b, err := tx.Global().Get(ctx, "test") if err != nil { return err @@ -94,7 +94,7 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) { store := db.GetKVStores("test-rule", [4]byte{1, 1, 1, 1}, featureName) - err = store.Update(func(tx KVStoreTx) error { + err = store.Update(ctx, func(ctx context.Context, tx KVStoreTx) error { // Set an item in the temp store. err := tx.LocalTemp().Set(ctx, "test", []byte{4, 3, 2}) if err != nil { @@ -112,7 +112,7 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) { v1 []byte v2 []byte ) - err = store.View(func(tx KVStoreTx) error { + err = store.View(ctx, func(ctx context.Context, tx KVStoreTx) error { b, err := tx.LocalTemp().Get(ctx, "test") if err != nil { return err @@ -144,7 +144,7 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) { // The temp store should no longer have the stored value but the perm // store should . - err = store.View(func(tx KVStoreTx) error { + err = store.View(ctx, func(ctx context.Context, tx KVStoreTx) error { b, err := tx.LocalTemp().Get(ctx, "test") if err != nil { return err @@ -188,21 +188,27 @@ func TestKVStoreNameSpaces(t *testing.T) { rulesDB3 := db.GetKVStores("test-rule", groupID2, "re-balance") // Test that the three ruleDBs share the same global space. - err = rulesDB1.Update(func(tx KVStoreTx) error { + err = rulesDB1.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + return tx.Global().Set( ctx, "test-global", []byte("global thing!"), ) }) require.NoError(t, err) - err = rulesDB2.Update(func(tx KVStoreTx) error { + err = rulesDB2.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + return tx.Global().Set( ctx, "test-global", []byte("different global thing!"), ) }) require.NoError(t, err) - err = rulesDB3.Update(func(tx KVStoreTx) error { + err = rulesDB3.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + return tx.Global().Set( ctx, "test-global", []byte("yet another global thing"), ) @@ -210,7 +216,9 @@ func TestKVStoreNameSpaces(t *testing.T) { require.NoError(t, err) var v []byte - err = rulesDB1.View(func(tx KVStoreTx) error { + err = rulesDB1.View(ctx, func(ctx context.Context, + tx KVStoreTx) error { + b, err := tx.Global().Get(ctx, "test-global") if err != nil { return err @@ -221,7 +229,9 @@ func TestKVStoreNameSpaces(t *testing.T) { require.NoError(t, err) require.True(t, bytes.Equal(v, []byte("yet another global thing"))) - err = rulesDB2.View(func(tx KVStoreTx) error { + err = rulesDB2.View(ctx, func(ctx context.Context, + tx KVStoreTx) error { + b, err := tx.Global().Get(ctx, "test-global") if err != nil { return err @@ -232,7 +242,9 @@ func TestKVStoreNameSpaces(t *testing.T) { require.NoError(t, err) require.True(t, bytes.Equal(v, []byte("yet another global thing"))) - err = rulesDB3.View(func(tx KVStoreTx) error { + err = rulesDB3.View(ctx, func(ctx context.Context, + tx KVStoreTx) error { + b, err := tx.Global().Get(ctx, "test-global") if err != nil { return err @@ -244,22 +256,30 @@ func TestKVStoreNameSpaces(t *testing.T) { require.True(t, bytes.Equal(v, []byte("yet another global thing"))) // Test that the feature space is not shared by any of the dbs. - err = rulesDB1.Update(func(tx KVStoreTx) error { + err = rulesDB1.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + return tx.Local().Set(ctx, "count", []byte("1")) }) require.NoError(t, err) - err = rulesDB2.Update(func(tx KVStoreTx) error { + err = rulesDB2.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + return tx.Local().Set(ctx, "count", []byte("2")) }) require.NoError(t, err) - err = rulesDB3.Update(func(tx KVStoreTx) error { + err = rulesDB3.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + return tx.Local().Set(ctx, "count", []byte("3")) }) require.NoError(t, err) - err = rulesDB1.View(func(tx KVStoreTx) error { + err = rulesDB1.View(ctx, func(ctx context.Context, + tx KVStoreTx) error { + b, err := tx.Local().Get(ctx, "count") if err != nil { return err @@ -270,7 +290,9 @@ func TestKVStoreNameSpaces(t *testing.T) { require.NoError(t, err) require.True(t, bytes.Equal(v, []byte("1"))) - err = rulesDB2.View(func(tx KVStoreTx) error { + err = rulesDB2.View(ctx, func(ctx context.Context, + tx KVStoreTx) error { + b, err := tx.Local().Get(ctx, "count") if err != nil { return err @@ -281,7 +303,9 @@ func TestKVStoreNameSpaces(t *testing.T) { require.NoError(t, err) require.True(t, bytes.Equal(v, []byte("2"))) - err = rulesDB3.View(func(tx KVStoreTx) error { + err = rulesDB3.View(ctx, func(ctx context.Context, + tx KVStoreTx) error { + b, err := tx.Local().Get(ctx, "count") if err != nil { return err @@ -299,22 +323,30 @@ func TestKVStoreNameSpaces(t *testing.T) { rulesDB2 = db.GetKVStores("test-rule", groupID1, "") rulesDB3 = db.GetKVStores("test-rule", groupID2, "") - err = rulesDB1.Update(func(tx KVStoreTx) error { + err = rulesDB1.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + return tx.Local().Set(ctx, "test", []byte("thing 1")) }) require.NoError(t, err) - err = rulesDB2.Update(func(tx KVStoreTx) error { + err = rulesDB2.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + return tx.Local().Set(ctx, "test", []byte("thing 2")) }) require.NoError(t, err) - err = rulesDB3.Update(func(tx KVStoreTx) error { + err = rulesDB3.Update(ctx, func(ctx context.Context, + tx KVStoreTx) error { + return tx.Local().Set(ctx, "test", []byte("thing 3")) }) require.NoError(t, err) - err = rulesDB1.View(func(tx KVStoreTx) error { + err = rulesDB1.View(ctx, func(ctx context.Context, + tx KVStoreTx) error { + b, err := tx.Local().Get(ctx, "test") if err != nil { return err @@ -325,7 +357,9 @@ func TestKVStoreNameSpaces(t *testing.T) { require.NoError(t, err) require.True(t, bytes.Equal(v, []byte("thing 2"))) - err = rulesDB2.View(func(tx KVStoreTx) error { + err = rulesDB2.View(ctx, func(ctx context.Context, + tx KVStoreTx) error { + b, err := tx.Local().Get(ctx, "test") if err != nil { return err @@ -336,7 +370,9 @@ func TestKVStoreNameSpaces(t *testing.T) { require.NoError(t, err) require.True(t, bytes.Equal(v, []byte("thing 2"))) - err = rulesDB3.View(func(tx KVStoreTx) error { + err = rulesDB3.View(ctx, func(ctx context.Context, + tx KVStoreTx) error { + b, err := tx.Local().Get(ctx, "test") if err != nil { return err diff --git a/rules/mock.go b/rules/mock.go index 6068a8510..a55022df9 100644 --- a/rules/mock.go +++ b/rules/mock.go @@ -47,12 +47,16 @@ type mockKVStores struct { tx *mockKVStoresTX } -func (m *mockKVStores) Update(f func(tx firewalldb.KVStoreTx) error) error { - return f(m.tx) +func (m *mockKVStores) Update(ctx context.Context, f func(ctx context.Context, + tx firewalldb.KVStoreTx) error) error { + + return f(ctx, m.tx) } -func (m *mockKVStores) View(f func(tx firewalldb.KVStoreTx) error) error { - return f(m.tx) +func (m *mockKVStores) View(ctx context.Context, f func(ctx context.Context, + tx firewalldb.KVStoreTx) error) error { + + return f(ctx, m.tx) } var _ firewalldb.KVStores = (*mockKVStores)(nil) diff --git a/rules/onchain_budget.go b/rules/onchain_budget.go index e4798cb69..248b2f699 100644 --- a/rules/onchain_budget.go +++ b/rules/onchain_budget.go @@ -531,7 +531,9 @@ func (o *OnChainBudgetEnforcer) handleBatchOpenChannelRequest( func (o *OnChainBudgetEnforcer) handlePendingPayment(ctx context.Context, request *onChainAction, reqID string) error { - return o.GetStores().Update(func(tx firewalldb.KVStoreTx) error { + return o.GetStores().Update(ctx, func(ctx context.Context, + tx firewalldb.KVStoreTx) error { + // First, we fetch the current state of the budget. spent, pending, err := o.getBudgetState(ctx, tx) if err != nil { @@ -586,7 +588,9 @@ type onChainAction struct { func (o *OnChainBudgetEnforcer) cancelPendingPayment( ctx context.Context) error { - return o.GetStores().Update(func(tx firewalldb.KVStoreTx) error { + return o.GetStores().Update(ctx, func(ctx context.Context, + tx firewalldb.KVStoreTx) error { + // First, we get our current budget state. _, pending, err := o.getBudgetState(ctx, tx) if err != nil { @@ -643,7 +647,9 @@ func (o *OnChainBudgetEnforcer) cancelPendingPayment( func (o *OnChainBudgetEnforcer) handlePaymentConfirmed( ctx context.Context) error { - return o.GetStores().Update(func(tx firewalldb.KVStoreTx) error { + return o.GetStores().Update(ctx, func(ctx context.Context, + tx firewalldb.KVStoreTx) error { + // First, we get our current budget state. complete, pending, err := o.getBudgetState(ctx, tx) if err != nil { From 575045191bff4b79e6dd4b170f27322c0dc4f39b Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 11 Mar 2025 14:33:49 -0500 Subject: [PATCH 2/3] firewalldb: introduce DBExecutor and use it for KVStores Here we introduce a generic DBExecutor interface with methods that take a generic T transaction. This is in preparation for the introducing a SQL transaction instead of a bolt one. --- firewalldb/interface.go | 21 +++++++++++++++++++++ firewalldb/kvstores.go | 18 +----------------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/firewalldb/interface.go b/firewalldb/interface.go index ff82eab68..1dc6951b7 100644 --- a/firewalldb/interface.go +++ b/firewalldb/interface.go @@ -14,3 +14,24 @@ type SessionDB interface { // GetSession returns the session for a specific id. GetSession(context.Context, session.ID) (*session.Session, error) } + +// DBExecutor provides an Update and View method that will allow the caller +// to perform atomic read and write transactions defined by PrivacyMapTx on the +// underlying BoltDB. +type DBExecutor[T any] interface { + // Update opens a database read/write transaction and executes the + // function f with the transaction passed as a parameter. After f exits, + // if f did not error, the transaction is committed. Otherwise, if f did + // error, the transaction is rolled back. If the rollback fails, the + // original error returned by f is still returned. If the commit fails, + // the commit error is returned. + Update(ctx context.Context, f func(ctx context.Context, + tx T) error) error + + // View opens a database read transaction and executes the function f + // with the transaction passed as a parameter. After f exits, the + // transaction is rolled back. If f errors, its error is returned, not a + // rollback error (if any occur). + View(ctx context.Context, f func(ctx context.Context, + tx T) error) error +} diff --git a/firewalldb/kvstores.go b/firewalldb/kvstores.go index 5a1b9b180..430abfe0c 100644 --- a/firewalldb/kvstores.go +++ b/firewalldb/kvstores.go @@ -54,23 +54,7 @@ var ( // KVStores provides an Update and View method that will allow the caller to // perform atomic read and write transactions on and of the key value stores // offered the KVStoreTx. -type KVStores interface { - // Update opens a database read/write transaction and executes the - // function f with the transaction passed as a parameter. After f exits, - // if f did not error, the transaction is committed. Otherwise, if f did - // error, the transaction is rolled back. If the rollback fails, the - // original error returned by f is still returned. If the commit fails, - // the commit error is returned. - Update(ctx context.Context, f func(ctx context.Context, - tx KVStoreTx) error) error - - // View opens a database read transaction and executes the function f - // with the transaction passed as a parameter. After f exits, the - // transaction is rolled back. If f errors, its error is returned, not a - // rollback error (if any occur). - View(ctx context.Context, f func(ctx context.Context, - tx KVStoreTx) error) error -} +type KVStores = DBExecutor[KVStoreTx] // KVStoreTx represents a database transaction that can be used for both read // and writes of the various different key value stores offered for the rule. From 627e7c4e5b872185dca63aa0b5fef9a12cda06dd Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 11 Mar 2025 14:40:57 -0500 Subject: [PATCH 3/3] firewalldb: use Bolt's existing Update/View methods --- firewalldb/kvstores.go | 72 ++++++++++-------------------------------- 1 file changed, 16 insertions(+), 56 deletions(-) diff --git a/firewalldb/kvstores.go b/firewalldb/kvstores.go index 430abfe0c..1dffd54cf 100644 --- a/firewalldb/kvstores.go +++ b/firewalldb/kvstores.go @@ -108,7 +108,7 @@ func (db *DB) GetKVStores(rule string, groupID session.ID, feature string) KVStores { return &kvStores{ - DB: db, + db: db.DB, ruleName: rule, groupID: groupID, featureName: feature, @@ -117,25 +117,12 @@ func (db *DB) GetKVStores(rule string, groupID session.ID, // kvStores implements the rules.KVStores interface. type kvStores struct { - *DB + db *bbolt.DB ruleName string groupID session.ID featureName string } -// beginTx starts db transaction. The transaction will be a read or read-write -// transaction depending on the value of the `writable` parameter. -func (s *kvStores) beginTx(writable bool) (*kvStoreTx, error) { - boltTx, err := s.Begin(writable) - if err != nil { - return nil, err - } - return &kvStoreTx{ - kvStores: s, - boltTx: boltTx, - }, nil -} - // Update opens a database read/write transaction and executes the function f // with the transaction passed as a parameter. After f exits, if f did not // error, the transaction is committed. Otherwise, if f did error, the @@ -144,30 +131,17 @@ func (s *kvStores) beginTx(writable bool) (*kvStoreTx, error) { // returned. // // NOTE: this is part of the KVStores interface. -func (s *kvStores) Update(ctx context.Context, f func(ctx context.Context, +func (s *kvStores) Update(ctx context.Context, fn func(ctx context.Context, tx KVStoreTx) error) error { - tx, err := s.beginTx(true) - if err != nil { - return err - } - - // Make sure the transaction rolls back in the event of a panic. - defer func() { - if tx != nil { - _ = tx.boltTx.Rollback() + return s.db.Update(func(tx *bbolt.Tx) error { + boltTx := &kvStoreTx{ + boltTx: tx, + kvStores: s, } - }() - err = f(ctx, tx) - if err != nil { - // Want to return the original error, not a rollback error if - // any occur. - _ = tx.boltTx.Rollback() - return err - } - - return tx.boltTx.Commit() + return fn(ctx, boltTx) + }) } // View opens a database read transaction and executes the function f with the @@ -176,31 +150,17 @@ func (s *kvStores) Update(ctx context.Context, f func(ctx context.Context, // occur). // // NOTE: this is part of the KVStores interface. -func (s *kvStores) View(ctx context.Context, f func(ctx context.Context, +func (s *kvStores) View(ctx context.Context, fn func(ctx context.Context, tx KVStoreTx) error) error { - tx, err := s.beginTx(false) - if err != nil { - return err - } - - // Make sure the transaction rolls back in the event of a panic. - defer func() { - if tx != nil { - _ = tx.boltTx.Rollback() + return s.db.View(func(tx *bbolt.Tx) error { + boltTx := &kvStoreTx{ + boltTx: tx, + kvStores: s, } - }() - err = f(ctx, tx) - rollbackErr := tx.boltTx.Rollback() - if err != nil { - return err - } - - if rollbackErr != nil { - return rollbackErr - } - return nil + return fn(ctx, boltTx) + }) } // getBucketFunc defines the signature of the bucket creation/fetching function