Skip to content

[sql-23] firewalldb: thread contexts through for kv-store interfaces #1001

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions firewalldb/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,24 @@ type SessionDB interface {
// GetSession returns the session for a specific id.
GetSession(context.Context, session.ID) (*session.Session, error)
}

// DBExecutor provides an Update and View method that will allow the caller
// to perform atomic read and write transactions defined by PrivacyMapTx on the
// underlying BoltDB.
type DBExecutor[T any] interface {
// Update opens a database read/write transaction and executes the
// function f with the transaction passed as a parameter. After f exits,
// if f did not error, the transaction is committed. Otherwise, if f did
// error, the transaction is rolled back. If the rollback fails, the
// original error returned by f is still returned. If the commit fails,
// the commit error is returned.
Update(ctx context.Context, f func(ctx context.Context,
tx T) error) error

// View opens a database read transaction and executes the function f
// with the transaction passed as a parameter. After f exits, the
// transaction is rolled back. If f errors, its error is returned, not a
// rollback error (if any occur).
View(ctx context.Context, f func(ctx context.Context,
tx T) error) error
}
88 changes: 19 additions & 69 deletions firewalldb/kvstores.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,7 @@ var (
// KVStores provides an Update and View method that will allow the caller to
// perform atomic read and write transactions on and of the key value stores
// offered the KVStoreTx.
type KVStores interface {
// Update opens a database read/write transaction and executes the
// function f with the transaction passed as a parameter. After f exits,
// if f did not error, the transaction is committed. Otherwise, if f did
// error, the transaction is rolled back. If the rollback fails, the
// original error returned by f is still returned. If the commit fails,
// the commit error is returned.
Update(f func(tx KVStoreTx) error) error

// View opens a database read transaction and executes the function f
// with the transaction passed as a parameter. After f exits, the
// transaction is rolled back. If f errors, its error is returned, not a
// rollback error (if any occur).
View(f func(tx KVStoreTx) error) error
}
type KVStores = DBExecutor[KVStoreTx]

// KVStoreTx represents a database transaction that can be used for both read
// and writes of the various different key value stores offered for the rule.
Expand Down Expand Up @@ -122,7 +108,7 @@ func (db *DB) GetKVStores(rule string, groupID session.ID,
feature string) KVStores {

return &kvStores{
DB: db,
db: db.DB,
ruleName: rule,
groupID: groupID,
featureName: feature,
Expand All @@ -131,25 +117,12 @@ func (db *DB) GetKVStores(rule string, groupID session.ID,

// kvStores implements the rules.KVStores interface.
type kvStores struct {
*DB
db *bbolt.DB
ruleName string
groupID session.ID
featureName string
}

// beginTx starts db transaction. The transaction will be a read or read-write
// transaction depending on the value of the `writable` parameter.
func (s *kvStores) beginTx(writable bool) (*kvStoreTx, error) {
boltTx, err := s.Begin(writable)
if err != nil {
return nil, err
}
return &kvStoreTx{
kvStores: s,
boltTx: boltTx,
}, nil
}

// Update opens a database read/write transaction and executes the function f
// with the transaction passed as a parameter. After f exits, if f did not
// error, the transaction is committed. Otherwise, if f did error, the
Expand All @@ -158,28 +131,17 @@ func (s *kvStores) beginTx(writable bool) (*kvStoreTx, error) {
// returned.
//
// NOTE: this is part of the KVStores interface.
func (s *kvStores) Update(f func(tx KVStoreTx) error) error {
tx, err := s.beginTx(true)
if err != nil {
return err
}
func (s *kvStores) Update(ctx context.Context, fn func(ctx context.Context,
tx KVStoreTx) error) error {

// Make sure the transaction rolls back in the event of a panic.
defer func() {
if tx != nil {
_ = tx.boltTx.Rollback()
return s.db.Update(func(tx *bbolt.Tx) error {
boltTx := &kvStoreTx{
boltTx: tx,
kvStores: s,
}
}()

err = f(tx)
if err != nil {
// Want to return the original error, not a rollback error if
// any occur.
_ = tx.boltTx.Rollback()
return err
}

return tx.boltTx.Commit()
return fn(ctx, boltTx)
})
}

// View opens a database read transaction and executes the function f with the
Expand All @@ -188,29 +150,17 @@ func (s *kvStores) Update(f func(tx KVStoreTx) error) error {
// occur).
//
// NOTE: this is part of the KVStores interface.
func (s *kvStores) View(f func(tx KVStoreTx) error) error {
tx, err := s.beginTx(false)
if err != nil {
return err
}
func (s *kvStores) View(ctx context.Context, fn func(ctx context.Context,
tx KVStoreTx) error) error {

// Make sure the transaction rolls back in the event of a panic.
defer func() {
if tx != nil {
_ = tx.boltTx.Rollback()
return s.db.View(func(tx *bbolt.Tx) error {
boltTx := &kvStoreTx{
boltTx: tx,
kvStores: s,
}
}()

err = f(tx)
rollbackErr := tx.boltTx.Rollback()
if err != nil {
return err
}

if rollbackErr != nil {
return rollbackErr
}
return nil
return fn(ctx, boltTx)
})
}

// getBucketFunc defines the signature of the bucket creation/fetching function
Expand Down
82 changes: 59 additions & 23 deletions firewalldb/kvstores_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestKVStoreTxs(t *testing.T) {

// Test that if an action fails midway through the transaction, then
// it is rolled back.
err = store.Update(func(tx KVStoreTx) error {
err = store.Update(ctx, func(ctx context.Context, tx KVStoreTx) error {
err := tx.Global().Set(ctx, "test", []byte{1})
if err != nil {
return err
Expand All @@ -46,7 +46,7 @@ func TestKVStoreTxs(t *testing.T) {
require.Error(t, err)

var v []byte
err = store.View(func(tx KVStoreTx) error {
err = store.View(ctx, func(ctx context.Context, tx KVStoreTx) error {
b, err := tx.Global().Get(ctx, "test")
if err != nil {
return err
Expand Down Expand Up @@ -94,7 +94,7 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) {

store := db.GetKVStores("test-rule", [4]byte{1, 1, 1, 1}, featureName)

err = store.Update(func(tx KVStoreTx) error {
err = store.Update(ctx, func(ctx context.Context, tx KVStoreTx) error {
// Set an item in the temp store.
err := tx.LocalTemp().Set(ctx, "test", []byte{4, 3, 2})
if err != nil {
Expand All @@ -112,7 +112,7 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) {
v1 []byte
v2 []byte
)
err = store.View(func(tx KVStoreTx) error {
err = store.View(ctx, func(ctx context.Context, tx KVStoreTx) error {
b, err := tx.LocalTemp().Get(ctx, "test")
if err != nil {
return err
Expand Down Expand Up @@ -144,7 +144,7 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) {

// The temp store should no longer have the stored value but the perm
// store should .
err = store.View(func(tx KVStoreTx) error {
err = store.View(ctx, func(ctx context.Context, tx KVStoreTx) error {
b, err := tx.LocalTemp().Get(ctx, "test")
if err != nil {
return err
Expand Down Expand Up @@ -188,29 +188,37 @@ func TestKVStoreNameSpaces(t *testing.T) {
rulesDB3 := db.GetKVStores("test-rule", groupID2, "re-balance")

// Test that the three ruleDBs share the same global space.
err = rulesDB1.Update(func(tx KVStoreTx) error {
err = rulesDB1.Update(ctx, func(ctx context.Context,
tx KVStoreTx) error {

return tx.Global().Set(
ctx, "test-global", []byte("global thing!"),
)
})
require.NoError(t, err)

err = rulesDB2.Update(func(tx KVStoreTx) error {
err = rulesDB2.Update(ctx, func(ctx context.Context,
tx KVStoreTx) error {

return tx.Global().Set(
ctx, "test-global", []byte("different global thing!"),
)
})
require.NoError(t, err)

err = rulesDB3.Update(func(tx KVStoreTx) error {
err = rulesDB3.Update(ctx, func(ctx context.Context,
tx KVStoreTx) error {

return tx.Global().Set(
ctx, "test-global", []byte("yet another global thing"),
)
})
require.NoError(t, err)

var v []byte
err = rulesDB1.View(func(tx KVStoreTx) error {
err = rulesDB1.View(ctx, func(ctx context.Context,
tx KVStoreTx) error {

b, err := tx.Global().Get(ctx, "test-global")
if err != nil {
return err
Expand All @@ -221,7 +229,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
require.NoError(t, err)
require.True(t, bytes.Equal(v, []byte("yet another global thing")))

err = rulesDB2.View(func(tx KVStoreTx) error {
err = rulesDB2.View(ctx, func(ctx context.Context,
tx KVStoreTx) error {

b, err := tx.Global().Get(ctx, "test-global")
if err != nil {
return err
Expand All @@ -232,7 +242,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
require.NoError(t, err)
require.True(t, bytes.Equal(v, []byte("yet another global thing")))

err = rulesDB3.View(func(tx KVStoreTx) error {
err = rulesDB3.View(ctx, func(ctx context.Context,
tx KVStoreTx) error {

b, err := tx.Global().Get(ctx, "test-global")
if err != nil {
return err
Expand All @@ -244,22 +256,30 @@ func TestKVStoreNameSpaces(t *testing.T) {
require.True(t, bytes.Equal(v, []byte("yet another global thing")))

// Test that the feature space is not shared by any of the dbs.
err = rulesDB1.Update(func(tx KVStoreTx) error {
err = rulesDB1.Update(ctx, func(ctx context.Context,
tx KVStoreTx) error {

return tx.Local().Set(ctx, "count", []byte("1"))
})
require.NoError(t, err)

err = rulesDB2.Update(func(tx KVStoreTx) error {
err = rulesDB2.Update(ctx, func(ctx context.Context,
tx KVStoreTx) error {

return tx.Local().Set(ctx, "count", []byte("2"))
})
require.NoError(t, err)

err = rulesDB3.Update(func(tx KVStoreTx) error {
err = rulesDB3.Update(ctx, func(ctx context.Context,
tx KVStoreTx) error {

return tx.Local().Set(ctx, "count", []byte("3"))
})
require.NoError(t, err)

err = rulesDB1.View(func(tx KVStoreTx) error {
err = rulesDB1.View(ctx, func(ctx context.Context,
tx KVStoreTx) error {

b, err := tx.Local().Get(ctx, "count")
if err != nil {
return err
Expand All @@ -270,7 +290,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
require.NoError(t, err)
require.True(t, bytes.Equal(v, []byte("1")))

err = rulesDB2.View(func(tx KVStoreTx) error {
err = rulesDB2.View(ctx, func(ctx context.Context,
tx KVStoreTx) error {

b, err := tx.Local().Get(ctx, "count")
if err != nil {
return err
Expand All @@ -281,7 +303,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
require.NoError(t, err)
require.True(t, bytes.Equal(v, []byte("2")))

err = rulesDB3.View(func(tx KVStoreTx) error {
err = rulesDB3.View(ctx, func(ctx context.Context,
tx KVStoreTx) error {

b, err := tx.Local().Get(ctx, "count")
if err != nil {
return err
Expand All @@ -299,22 +323,30 @@ func TestKVStoreNameSpaces(t *testing.T) {
rulesDB2 = db.GetKVStores("test-rule", groupID1, "")
rulesDB3 = db.GetKVStores("test-rule", groupID2, "")

err = rulesDB1.Update(func(tx KVStoreTx) error {
err = rulesDB1.Update(ctx, func(ctx context.Context,
tx KVStoreTx) error {

return tx.Local().Set(ctx, "test", []byte("thing 1"))
})
require.NoError(t, err)

err = rulesDB2.Update(func(tx KVStoreTx) error {
err = rulesDB2.Update(ctx, func(ctx context.Context,
tx KVStoreTx) error {

return tx.Local().Set(ctx, "test", []byte("thing 2"))
})
require.NoError(t, err)

err = rulesDB3.Update(func(tx KVStoreTx) error {
err = rulesDB3.Update(ctx, func(ctx context.Context,
tx KVStoreTx) error {

return tx.Local().Set(ctx, "test", []byte("thing 3"))
})
require.NoError(t, err)

err = rulesDB1.View(func(tx KVStoreTx) error {
err = rulesDB1.View(ctx, func(ctx context.Context,
tx KVStoreTx) error {

b, err := tx.Local().Get(ctx, "test")
if err != nil {
return err
Expand All @@ -325,7 +357,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
require.NoError(t, err)
require.True(t, bytes.Equal(v, []byte("thing 2")))

err = rulesDB2.View(func(tx KVStoreTx) error {
err = rulesDB2.View(ctx, func(ctx context.Context,
tx KVStoreTx) error {

b, err := tx.Local().Get(ctx, "test")
if err != nil {
return err
Expand All @@ -336,7 +370,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
require.NoError(t, err)
require.True(t, bytes.Equal(v, []byte("thing 2")))

err = rulesDB3.View(func(tx KVStoreTx) error {
err = rulesDB3.View(ctx, func(ctx context.Context,
tx KVStoreTx) error {

b, err := tx.Local().Get(ctx, "test")
if err != nil {
return err
Expand Down
Loading
Loading