diff --git a/firewall/privacy_mapper.go b/firewall/privacy_mapper.go index 26e053500..cbd8a8da4 100644 --- a/firewall/privacy_mapper.go +++ b/firewall/privacy_mapper.go @@ -325,17 +325,19 @@ func handleGetInfoResponse(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.GetInfoResponse) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.GetInfoResponse) ( + return func(ctx context.Context, r *lnrpc.GetInfoResponse) ( proto.Message, error) { // We hide the pubkey unless it is disabled. pseudoPubKey := r.IdentityPubkey if !flags.Contains(session.ClearPubkeys) { - err := db.Update( - func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, + func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + var err error - pseudoPubKey, err = firewalldb.HideString( - tx, r.IdentityPubkey, + pseudoPubKey, err = firewalldb.HideString( //nolint:lll + ctx, tx, r.IdentityPubkey, ) return err @@ -377,14 +379,16 @@ func handleFwdHistoryResponse(db firewalldb.PrivacyMapDB, randIntn func(int) (int, error)) func(ctx context.Context, r *lnrpc.ForwardingHistoryResponse) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.ForwardingHistoryResponse) ( + return func(ctx context.Context, r *lnrpc.ForwardingHistoryResponse) ( proto.Message, error) { fwdEvents := make( []*lnrpc.ForwardingEvent, len(r.ForwardingEvents), ) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, fe := range r.ForwardingEvents { var err error @@ -393,14 +397,14 @@ func handleFwdHistoryResponse(db firewalldb.PrivacyMapDB, if !flags.Contains(session.ClearChanIDs) { // Deterministically hide channel ids. chanIn, err = firewalldb.HideUint64( - tx, chanIn, + ctx, tx, chanIn, ) if err != nil { return err } chanOut, err = firewalldb.HideUint64( - tx, chanOut, + ctx, tx, chanOut, ) if err != nil { return err @@ -487,14 +491,16 @@ func handleFeeReportResponse(db firewalldb.PrivacyMapDB, chanFees := make([]*lnrpc.ChannelFeeReport, len(r.ChannelFees)) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + var err error for i, c := range r.ChannelFees { chanID := c.ChanId if !flags.Contains(session.ClearChanIDs) { chanID, err = firewalldb.HideUint64( - tx, chanID, + ctx, tx, chanID, ) if err != nil { return err @@ -504,7 +510,7 @@ func handleFeeReportResponse(db firewalldb.PrivacyMapDB, chanPoint := c.ChannelPoint if !flags.Contains(session.ClearChanIDs) { chanPoint, err = firewalldb.HideChanPointStr( - tx, chanPoint, + ctx, tx, chanPoint, ) if err != nil { return err @@ -550,8 +556,10 @@ func handleListChannelsRequest(db firewalldb.PrivacyMapDB, return r, nil } - err := db.View(func(tx firewalldb.PrivacyMapTx) error { - peer, err := firewalldb.RevealBytes(tx, r.Peer) + err := db.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + + peer, err := firewalldb.RevealBytes(ctx, tx, r.Peer) if err != nil { return err } @@ -572,7 +580,7 @@ func handleListChannelsResponse(db firewalldb.PrivacyMapDB, randIntn func(int) (int, error)) func(ctx context.Context, r *lnrpc.ListChannelsResponse) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.ListChannelsResponse) ( + return func(ctx context.Context, r *lnrpc.ListChannelsResponse) ( proto.Message, error) { hidePubkeys := !flags.Contains(session.ClearPubkeys) @@ -580,7 +588,9 @@ func handleListChannelsResponse(db firewalldb.PrivacyMapDB, channels := make([]*lnrpc.Channel, len(r.Channels)) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, c := range r.Channels { var err error @@ -589,7 +599,7 @@ func handleListChannelsResponse(db firewalldb.PrivacyMapDB, remotePub := c.RemotePubkey if hidePubkeys { remotePub, err = firewalldb.HideString( - tx, c.RemotePubkey, + ctx, tx, c.RemotePubkey, ) if err != nil { return err @@ -600,14 +610,14 @@ func handleListChannelsResponse(db firewalldb.PrivacyMapDB, chanID := c.ChanId if hideChanIds { chanPoint, err = firewalldb.HideChanPointStr( - tx, c.ChannelPoint, + ctx, tx, c.ChannelPoint, ) if err != nil { return err } chanID, err = firewalldb.HideUint64( - tx, c.ChanId, + ctx, tx, c.ChanId, ) if err != nil { return err @@ -745,7 +755,7 @@ func handleUpdatePolicyRequest(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.PolicyUpdateRequest) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.PolicyUpdateRequest) ( + return func(ctx context.Context, r *lnrpc.PolicyUpdateRequest) ( proto.Message, error) { chanPoint := r.GetChanPoint() @@ -764,10 +774,12 @@ func handleUpdatePolicyRequest(db firewalldb.PrivacyMapDB, newTxid := txid.String() newIndex := chanPoint.GetOutputIndex() if !flags.Contains(session.ClearChanIDs) { - err = db.View(func(tx firewalldb.PrivacyMapTx) error { + err = db.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + var err error - newTxid, newIndex, err = firewalldb.RevealChanPoint( - tx, newTxid, newIndex, + newTxid, newIndex, err = firewalldb.RevealChanPoint( //nolint:lll + ctx, tx, newTxid, newIndex, ) return err }) @@ -793,7 +805,7 @@ func handleUpdatePolicyResponse(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.PolicyUpdateResponse) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.PolicyUpdateResponse) ( + return func(ctx context.Context, r *lnrpc.PolicyUpdateResponse) ( proto.Message, error) { if flags.Contains(session.ClearChanIDs) { @@ -804,7 +816,9 @@ func handleUpdatePolicyResponse(db firewalldb.PrivacyMapDB, []*lnrpc.FailedUpdate, len(r.FailedUpdates), ) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, u := range r.FailedUpdates { failedUpdates[i] = &lnrpc.FailedUpdate{ Reason: u.Reason, @@ -816,7 +830,7 @@ func handleUpdatePolicyResponse(db firewalldb.PrivacyMapDB, } txid, index, err := firewalldb.HideChanPoint( - tx, u.Outpoint.TxidStr, + ctx, tx, u.Outpoint.TxidStr, u.Outpoint.OutputIndex, ) if err != nil { @@ -926,7 +940,7 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB, randIntn func(int) (int, error)) func(ctx context.Context, r *lnrpc.ClosedChannelsResponse) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.ClosedChannelsResponse) ( + return func(ctx context.Context, r *lnrpc.ClosedChannelsResponse) ( proto.Message, error) { closedChannels := make( @@ -934,14 +948,16 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB, len(r.Channels), ) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, c := range r.Channels { var err error remotePub := c.RemotePubkey if !flags.Contains(session.ClearPubkeys) { remotePub, err = firewalldb.HideString( - tx, remotePub, + ctx, tx, remotePub, ) if err != nil { return err @@ -969,7 +985,7 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB, channelPoint := c.ChannelPoint if !flags.Contains(session.ClearChanIDs) { channelPoint, err = firewalldb.HideChanPointStr( - tx, c.ChannelPoint, + ctx, tx, c.ChannelPoint, ) if err != nil { return err @@ -979,7 +995,7 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB, chanID := c.ChanId if !flags.Contains(session.ClearChanIDs) { chanID, err = firewalldb.HideUint64( - tx, c.ChanId, + ctx, tx, c.ChanId, ) if err != nil { return err @@ -989,7 +1005,7 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB, closingTxid := c.ClosingTxHash if !flags.Contains(session.ClearClosingTxIds) { closingTxid, err = firewalldb.HideString( - tx, c.ClosingTxHash, + ctx, tx, c.ClosingTxHash, ) if err != nil { return err @@ -1036,7 +1052,8 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB, // obfuscatePendingChannel is a helper to obfuscate the fields of a pending // channel. -func obfuscatePendingChannel(c *lnrpc.PendingChannelsResponse_PendingChannel, +func obfuscatePendingChannel(ctx context.Context, + c *lnrpc.PendingChannelsResponse_PendingChannel, tx firewalldb.PrivacyMapTx, randIntn func(int) (int, error), flags session.PrivacyFlags) ( *lnrpc.PendingChannelsResponse_PendingChannel, error) { @@ -1046,7 +1063,7 @@ func obfuscatePendingChannel(c *lnrpc.PendingChannelsResponse_PendingChannel, remotePub := c.RemoteNodePub if !flags.Contains(session.ClearPubkeys) { remotePub, err = firewalldb.HideString( - tx, remotePub, + ctx, tx, remotePub, ) if err != nil { return nil, err @@ -1083,7 +1100,7 @@ func obfuscatePendingChannel(c *lnrpc.PendingChannelsResponse_PendingChannel, chanPoint := c.ChannelPoint if !flags.Contains(session.ClearChanIDs) { chanPoint, err = firewalldb.HideChanPointStr( - tx, c.ChannelPoint, + ctx, tx, c.ChannelPoint, ) if err != nil { return nil, err @@ -1117,7 +1134,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, randIntn func(int) (int, error)) func(ctx context.Context, r *lnrpc.PendingChannelsResponse) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.PendingChannelsResponse) ( + return func(ctx context.Context, r *lnrpc.PendingChannelsResponse) ( proto.Message, error) { pendingOpens := make( @@ -1140,12 +1157,14 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, len(r.WaitingCloseChannels), ) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, c := range r.PendingOpenChannels { var err error pendingChannel, err := obfuscatePendingChannel( - c.Channel, tx, randIntn, flags, + ctx, c.Channel, tx, randIntn, flags, ) if err != nil { return err @@ -1169,7 +1188,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, var err error pendingChannel, err := obfuscatePendingChannel( - c.Channel, tx, randIntn, flags, + ctx, c.Channel, tx, randIntn, flags, ) if err != nil { return err @@ -1177,8 +1196,8 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, closingTxid := c.ClosingTxid if !flags.Contains(session.ClearClosingTxIds) { - closingTxid, err = firewalldb.HideString( - tx, c.ClosingTxid, + closingTxid, err = firewalldb.HideString( //nolint:lll + ctx, tx, c.ClosingTxid, ) if err != nil { return err @@ -1198,7 +1217,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, var err error pendingChannel, err := obfuscatePendingChannel( - c.Channel, tx, randIntn, flags, + ctx, c.Channel, tx, randIntn, flags, ) if err != nil { return err @@ -1207,7 +1226,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, closingTxid := c.ClosingTxid if !flags.Contains(session.ClearClosingTxIds) { closingTxid, err = firewalldb.HideString( - tx, c.ClosingTxid, + ctx, tx, c.ClosingTxid, ) if err != nil { return err @@ -1259,7 +1278,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, var err error pendingChannel, err := obfuscatePendingChannel( - c.Channel, tx, randIntn, flags, + ctx, c.Channel, tx, randIntn, flags, ) if err != nil { return err @@ -1279,7 +1298,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, closingTxid := c.ClosingTxid if !flags.Contains(session.ClearClosingTxIds) { closingTxid, err = firewalldb.HideString( - tx, closingTxid, + ctx, tx, closingTxid, ) if err != nil { return err @@ -1296,7 +1315,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB, ) { closingTxHex, err = firewalldb.HideString( - tx, closingTxHex, + ctx, tx, closingTxHex, ) if err != nil { return err @@ -1343,12 +1362,14 @@ func handleBatchOpenChannelRequest(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.BatchOpenChannelRequest) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.BatchOpenChannelRequest) ( + return func(ctx context.Context, r *lnrpc.BatchOpenChannelRequest) ( proto.Message, error) { var reqs = make([]*lnrpc.BatchOpenChannel, len(r.Channels)) - err := db.View(func(tx firewalldb.PrivacyMapTx) error { + err := db.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, c := range r.Channels { var err error @@ -1359,7 +1380,7 @@ func handleBatchOpenChannelRequest(db firewalldb.PrivacyMapDB, nodePubkey := c.NodePubkey if !flags.Contains(session.ClearPubkeys) { nodePubkey, err = firewalldb.RevealBytes( - tx, c.NodePubkey, + ctx, tx, c.NodePubkey, ) if err != nil { return err @@ -1414,12 +1435,14 @@ func handleBatchOpenChannelResponse(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.BatchOpenChannelResponse) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.BatchOpenChannelResponse) ( + return func(ctx context.Context, r *lnrpc.BatchOpenChannelResponse) ( proto.Message, error) { resps := make([]*lnrpc.PendingUpdate, len(r.PendingChannels)) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, p := range r.PendingChannels { var ( txIdBytes = p.Txid @@ -1432,8 +1455,9 @@ func handleBatchOpenChannelResponse(db firewalldb.PrivacyMapDB, return err } - txID, outIdx, err := firewalldb.HideChanPoint( - tx, txId.String(), p.OutputIndex, + txID, outIdx, err := firewalldb.HideChanPoint( //nolint:lll + ctx, tx, txId.String(), + p.OutputIndex, ) if err != nil { return err @@ -1471,14 +1495,15 @@ func handleChannelOpenRequest(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.OpenChannelRequest) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.OpenChannelRequest) ( + return func(ctx context.Context, r *lnrpc.OpenChannelRequest) ( proto.Message, error) { var nodePubkey []byte - err := db.View(func(tx firewalldb.PrivacyMapTx) error { - var err error + err := db.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + var err error // We use the byte slice representation of the // pubkey and fall back to the hex string if present. nodePubkey = r.NodePubkey @@ -1493,7 +1518,7 @@ func handleChannelOpenRequest(db firewalldb.PrivacyMapDB, if !flags.Contains(session.ClearPubkeys) { nodePubkey, err = firewalldb.RevealBytes( - tx, nodePubkey, + ctx, tx, nodePubkey, ) if err != nil { return err @@ -1548,7 +1573,7 @@ func handleChannelOpenResponse(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.ChannelPoint) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.ChannelPoint) ( + return func(ctx context.Context, r *lnrpc.ChannelPoint) ( proto.Message, error) { var ( @@ -1556,7 +1581,9 @@ func handleChannelOpenResponse(db firewalldb.PrivacyMapDB, index uint32 ) - err := db.Update(func(tx firewalldb.PrivacyMapTx) error { + err := db.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + var err error txid = r.GetFundingTxidStr() @@ -1575,7 +1602,7 @@ func handleChannelOpenResponse(db firewalldb.PrivacyMapDB, if !flags.Contains(session.ClearChanIDs) { txid, index, err = firewalldb.HideChanPoint( - tx, txid, index, + ctx, tx, txid, index, ) if err != nil { return err @@ -1622,12 +1649,14 @@ func handleConnectPeerRequest(db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) func(ctx context.Context, r *lnrpc.ConnectPeerRequest) (proto.Message, error) { - return func(_ context.Context, r *lnrpc.ConnectPeerRequest) ( + return func(ctx context.Context, r *lnrpc.ConnectPeerRequest) ( proto.Message, error) { var addr *lnrpc.LightningAddress - err := db.View(func(tx firewalldb.PrivacyMapTx) error { + err := db.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + var err error // Note, this only works if the pubkey alias was @@ -1636,7 +1665,7 @@ func handleConnectPeerRequest(db firewalldb.PrivacyMapDB, pubkey := r.Addr.Pubkey if !flags.Contains(session.ClearPubkeys) { pubkey, err = firewalldb.RevealString( - tx, r.Addr.Pubkey, + ctx, tx, r.Addr.Pubkey, ) if err != nil { return err @@ -1646,7 +1675,7 @@ func handleConnectPeerRequest(db firewalldb.PrivacyMapDB, host := r.Addr.Host if !flags.Contains(session.ClearNetworkAddresses) { host, err = firewalldb.RevealString( - tx, r.Addr.Host, + ctx, tx, r.Addr.Host, ) if err != nil { return err diff --git a/firewall/privacy_mapper_test.go b/firewall/privacy_mapper_test.go index 1b67068e0..1998d1280 100644 --- a/firewall/privacy_mapper_test.go +++ b/firewall/privacy_mapper_test.go @@ -1073,9 +1073,11 @@ func newMockDB(t *testing.T, preloadRealToPseudo map[string]string, db := mockDB{privDB: make(map[string]*mockPrivacyMapDB)} sessDB := db.NewSessionDB(sessID) - _ = sessDB.Update(func(tx firewalldb.PrivacyMapTx) error { + _ = sessDB.Update(context.Background(), func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for r, p := range preloadRealToPseudo { - require.NoError(t, tx.NewPair(r, p)) + require.NoError(t, tx.NewPair(ctx, r, p)) } return nil }) @@ -1107,25 +1109,29 @@ type mockPrivacyMapDB struct { p2r map[string]string } -func (m *mockPrivacyMapDB) Update( - f func(tx firewalldb.PrivacyMapTx) error) error { +func (m *mockPrivacyMapDB) Update(ctx context.Context, + f func(ctx context.Context, tx firewalldb.PrivacyMapTx) error) error { - return f(m) + return f(ctx, m) } -func (m *mockPrivacyMapDB) View( - f func(tx firewalldb.PrivacyMapTx) error) error { +func (m *mockPrivacyMapDB) View(ctx context.Context, + f func(ctx context.Context, tx firewalldb.PrivacyMapTx) error) error { - return f(m) + return f(ctx, m) } -func (m *mockPrivacyMapDB) NewPair(real, pseudo string) error { +func (m *mockPrivacyMapDB) NewPair(_ context.Context, real, + pseudo string) error { + m.r2p[real] = pseudo m.p2r[pseudo] = real return nil } -func (m *mockPrivacyMapDB) PseudoToReal(pseudo string) (string, error) { +func (m *mockPrivacyMapDB) PseudoToReal(_ context.Context, pseudo string) ( + string, error) { + r, ok := m.p2r[pseudo] if !ok { return "", firewalldb.ErrNoSuchKeyFound @@ -1134,7 +1140,9 @@ func (m *mockPrivacyMapDB) PseudoToReal(pseudo string) (string, error) { return r, nil } -func (m *mockPrivacyMapDB) RealToPseudo(real string) (string, error) { +func (m *mockPrivacyMapDB) RealToPseudo(_ context.Context, real string) (string, + error) { + p, ok := m.r2p[real] if !ok { return "", firewalldb.ErrNoSuchKeyFound @@ -1143,8 +1151,8 @@ func (m *mockPrivacyMapDB) RealToPseudo(real string) (string, error) { return p, nil } -func (m *mockPrivacyMapDB) FetchAllPairs() (*firewalldb.PrivacyMapPairs, - error) { +func (m *mockPrivacyMapDB) FetchAllPairs(_ context.Context) ( + *firewalldb.PrivacyMapPairs, error) { return firewalldb.NewPrivacyMapPairs(m.r2p), nil } diff --git a/firewall/rule_enforcer.go b/firewall/rule_enforcer.go index c99671cdf..7914965ed 100644 --- a/firewall/rule_enforcer.go +++ b/firewall/rule_enforcer.go @@ -395,7 +395,7 @@ func (r *RuleEnforcer) initRule(ctx context.Context, reqID uint64, name string, privMap := r.newPrivMap(session.GroupID) ruleValues, err = ruleValues.PseudoToReal( - privMap, session.PrivacyFlags, + ctx, privMap, session.PrivacyFlags, ) if err != nil { return nil, fmt.Errorf("could not prepare rule "+ diff --git a/firewalldb/kvdb_store.go b/firewalldb/kvdb_store.go new file mode 100644 index 000000000..d4ce79f20 --- /dev/null +++ b/firewalldb/kvdb_store.go @@ -0,0 +1,44 @@ +package firewalldb + +import ( + "context" + + "go.etcd.io/bbolt" +) + +// kvdbExecutor is a concrete implementation of the DBExecutor interface that +// uses a bbolt database as its backing store. +type kvdbExecutor[T any] struct { + db *bbolt.DB + wrapTx func(tx *bbolt.Tx) T +} + +// Update opens a database read/write transaction and executes the function f +// with the transaction passed as a parameter. After f exits, if f did not +// error, the transaction is committed. Otherwise, if f did error, the +// transaction is rolled back. If the rollback fails, the original error +// returned by f is still returned. If the commit fails, the commit error is +// returned. +// +// NOTE: this is part of the DBExecutor interface. +func (e *kvdbExecutor[T]) Update(ctx context.Context, + fn func(ctx context.Context, tx T) error) error { + + return e.db.Update(func(tx *bbolt.Tx) error { + return fn(ctx, e.wrapTx(tx)) + }) +} + +// View opens a database read transaction and executes the function f with the +// transaction passed as a parameter. After f exits, the transaction is rolled +// back. If f errors, its error is returned, not a rollback error (if any +// occur). +// +// NOTE: this is part of the DBExecutor interface. +func (e *kvdbExecutor[T]) View(ctx context.Context, + fn func(ctx context.Context, tx T) error) error { + + return e.db.View(func(tx *bbolt.Tx) error { + return fn(ctx, e.wrapTx(tx)) + }) +} diff --git a/firewalldb/kvstores.go b/firewalldb/kvstores.go index 1dffd54cf..9dad0a0cc 100644 --- a/firewalldb/kvstores.go +++ b/firewalldb/kvstores.go @@ -107,62 +107,28 @@ type RulesDB interface { func (db *DB) GetKVStores(rule string, groupID session.ID, feature string) KVStores { - return &kvStores{ - db: db.DB, - ruleName: rule, - groupID: groupID, - featureName: feature, + return &kvdbExecutor[KVStoreTx]{ + db: db.DB, + wrapTx: func(tx *bbolt.Tx) KVStoreTx { + return &kvStoreTx{ + boltTx: tx, + kvStores: &kvStores{ + ruleName: rule, + groupID: groupID, + featureName: feature, + }, + } + }, } } // kvStores implements the rules.KVStores interface. type kvStores struct { - db *bbolt.DB ruleName string groupID session.ID featureName string } -// Update opens a database read/write transaction and executes the function f -// with the transaction passed as a parameter. After f exits, if f did not -// error, the transaction is committed. Otherwise, if f did error, the -// transaction is rolled back. If the rollback fails, the original error -// returned by f is still returned. If the commit fails, the commit error is -// returned. -// -// NOTE: this is part of the KVStores interface. -func (s *kvStores) Update(ctx context.Context, fn func(ctx context.Context, - tx KVStoreTx) error) error { - - return s.db.Update(func(tx *bbolt.Tx) error { - boltTx := &kvStoreTx{ - boltTx: tx, - kvStores: s, - } - - return fn(ctx, boltTx) - }) -} - -// View opens a database read transaction and executes the function f with the -// transaction passed as a parameter. After f exits, the transaction is rolled -// back. If f errors, its error is returned, not a rollback error (if any -// occur). -// -// NOTE: this is part of the KVStores interface. -func (s *kvStores) View(ctx context.Context, fn func(ctx context.Context, - tx KVStoreTx) error) error { - - return s.db.View(func(tx *bbolt.Tx) error { - boltTx := &kvStoreTx{ - boltTx: tx, - kvStores: s, - } - - return fn(ctx, boltTx) - }) -} - // getBucketFunc defines the signature of the bucket creation/fetching function // required by kvStoreTx. If create is true, then all the bucket (and all // buckets leading up to the bucket) should be created if they do not already diff --git a/firewalldb/privacy_mapper.go b/firewalldb/privacy_mapper.go index e2f10f281..ab8e60e40 100644 --- a/firewalldb/privacy_mapper.go +++ b/firewalldb/privacy_mapper.go @@ -1,6 +1,7 @@ package firewalldb import ( + "context" "crypto/rand" "encoding/binary" "encoding/hex" @@ -41,132 +42,48 @@ type NewPrivacyMapDB func(groupID session.ID) PrivacyMapDB // PrivacyDB constructs a PrivacyMapDB that will be indexed under the given // group ID key. func (db *DB) PrivacyDB(groupID session.ID) PrivacyMapDB { - return &privacyMapDB{ - DB: db, - groupID: groupID, + return &kvdbExecutor[PrivacyMapTx]{ + db: db.DB, + wrapTx: func(tx *bbolt.Tx) PrivacyMapTx { + return &privacyMapTx{ + boltTx: tx, + privacyMapDB: &privacyMapDB{ + groupID: groupID, + }, + } + }, } } // PrivacyMapDB provides an Update and View method that will allow the caller // to perform atomic read and write transactions defined by PrivacyMapTx on the // underlying DB. -type PrivacyMapDB interface { - // Update opens a database read/write transaction and executes the - // function f with the transaction passed as a parameter. After f exits, - // if f did not error, the transaction is committed. Otherwise, if f did - // error, the transaction is rolled back. If the rollback fails, the - // original error returned by f is still returned. If the commit fails, - // the commit error is returned. - Update(f func(tx PrivacyMapTx) error) error - - // View opens a database read transaction and executes the function f - // with the transaction passed as a parameter. After f exits, the - // transaction is rolled back. If f errors, its error is returned, not a - // rollback error (if any occur). - View(f func(tx PrivacyMapTx) error) error -} +type PrivacyMapDB = DBExecutor[PrivacyMapTx] // PrivacyMapTx represents a db that can be used to create, store and fetch // real-pseudo pairs. type PrivacyMapTx interface { // NewPair persists a new real-pseudo pair. - NewPair(real, pseudo string) error + NewPair(ctx context.Context, real, pseudo string) error // PseudoToReal returns the real value associated with the given pseudo // value. If no such pair is found, then ErrNoSuchKeyFound is returned. - PseudoToReal(pseudo string) (string, error) + PseudoToReal(ctx context.Context, pseudo string) (string, error) // RealToPseudo returns the pseudo value associated with the given real // value. If no such pair is found, then ErrNoSuchKeyFound is returned. - RealToPseudo(real string) (string, error) + RealToPseudo(ctx context.Context, real string) (string, error) // FetchAllPairs loads and returns the real-to-pseudo pairs in the form // of a PrivacyMapPairs struct. - FetchAllPairs() (*PrivacyMapPairs, error) + FetchAllPairs(ctx context.Context) (*PrivacyMapPairs, error) } // privacyMapDB is an implementation of PrivacyMapDB. type privacyMapDB struct { - *DB groupID session.ID } -// beginTx starts db transaction. The transaction will be a read or read-write -// transaction depending on the value of the `writable` parameter. -func (p *privacyMapDB) beginTx(writable bool) (*privacyMapTx, error) { - boltTx, err := p.Begin(writable) - if err != nil { - return nil, err - } - return &privacyMapTx{ - privacyMapDB: p, - boltTx: boltTx, - }, nil -} - -// Update opens a database read/write transaction and executes the function f -// with the transaction passed as a parameter. After f exits, if f did not -// error, the transaction is committed. Otherwise, if f did error, the -// transaction is rolled back. If the rollback fails, the original error -// returned by f is still returned. If the commit fails, the commit error is -// returned. -// -// NOTE: this is part of the PrivacyMapDB interface. -func (p *privacyMapDB) Update(f func(tx PrivacyMapTx) error) error { - tx, err := p.beginTx(true) - if err != nil { - return err - } - - // Make sure the transaction rolls back in the event of a panic. - defer func() { - if tx != nil { - _ = tx.boltTx.Rollback() - } - }() - - err = f(tx) - if err != nil { - // Want to return the original error, not a rollback error if - // any occur. - _ = tx.boltTx.Rollback() - return err - } - - return tx.boltTx.Commit() -} - -// View opens a database read transaction and executes the function f with the -// transaction passed as a parameter. After f exits, the transaction is rolled -// back. If f errors, its error is returned, not a rollback error (if any -// occur). -// -// NOTE: this is part of the PrivacyMapDB interface. -func (p *privacyMapDB) View(f func(tx PrivacyMapTx) error) error { - tx, err := p.beginTx(false) - if err != nil { - return err - } - - // Make sure the transaction rolls back in the event of a panic. - defer func() { - if tx != nil { - _ = tx.boltTx.Rollback() - } - }() - - err = f(tx) - rollbackErr := tx.boltTx.Rollback() - if err != nil { - return err - } - - if rollbackErr != nil { - return rollbackErr - } - return nil -} - // privacyMapTx is an implementation of PrivacyMapTx. type privacyMapTx struct { *privacyMapDB @@ -176,7 +93,7 @@ type privacyMapTx struct { // NewPair inserts a new real-pseudo pair into the db. // // NOTE: this is part of the PrivacyMapTx interface. -func (p *privacyMapTx) NewPair(real, pseudo string) error { +func (p *privacyMapTx) NewPair(_ context.Context, real, pseudo string) error { privacyBucket, err := getBucket(p.boltTx, privacyBucketKey) if err != nil { return err @@ -223,7 +140,9 @@ func (p *privacyMapTx) NewPair(real, pseudo string) error { // it does then the real value is returned, else an error is returned. // // NOTE: this is part of the PrivacyMapTx interface. -func (p *privacyMapTx) PseudoToReal(pseudo string) (string, error) { +func (p *privacyMapTx) PseudoToReal(_ context.Context, pseudo string) (string, + error) { + privacyBucket, err := getBucket(p.boltTx, privacyBucketKey) if err != nil { return "", err @@ -251,7 +170,9 @@ func (p *privacyMapTx) PseudoToReal(pseudo string) (string, error) { // it does then the pseudo value is returned, else an error is returned. // // NOTE: this is part of the PrivacyMapTx interface. -func (p *privacyMapTx) RealToPseudo(real string) (string, error) { +func (p *privacyMapTx) RealToPseudo(_ context.Context, real string) (string, + error) { + privacyBucket, err := getBucket(p.boltTx, privacyBucketKey) if err != nil { return "", err @@ -278,7 +199,9 @@ func (p *privacyMapTx) RealToPseudo(real string) (string, error) { // FetchAllPairs loads and returns the real-to-pseudo pairs. // // NOTE: this is part of the PrivacyMapTx interface. -func (p *privacyMapTx) FetchAllPairs() (*PrivacyMapPairs, error) { +func (p *privacyMapTx) FetchAllPairs(_ context.Context) (*PrivacyMapPairs, + error) { + privacyBucket, err := getBucket(p.boltTx, privacyBucketKey) if err != nil { return nil, err @@ -309,8 +232,10 @@ func (p *privacyMapTx) FetchAllPairs() (*PrivacyMapPairs, error) { return NewPrivacyMapPairs(pairs), nil } -func HideString(tx PrivacyMapTx, real string) (string, error) { - pseudo, err := tx.RealToPseudo(real) +func HideString(ctx context.Context, tx PrivacyMapTx, real string) (string, + error) { + + pseudo, err := tx.RealToPseudo(ctx, real) if err != nil && err != ErrNoSuchKeyFound { return "", err } @@ -323,7 +248,7 @@ func HideString(tx PrivacyMapTx, real string) (string, error) { return "", err } - if err = tx.NewPair(real, pseudo); err != nil { + if err = tx.NewPair(ctx, real, pseudo); err != nil { return "", err } @@ -347,17 +272,21 @@ func NewPseudoStr(n int) (string, error) { return string(b), nil } -func RevealString(tx PrivacyMapTx, pseudo string) (string, error) { +func RevealString(ctx context.Context, tx PrivacyMapTx, pseudo string) (string, + error) { + if pseudo == "" { return pseudo, nil } - return tx.PseudoToReal(pseudo) + return tx.PseudoToReal(ctx, pseudo) } -func HideUint64(tx PrivacyMapTx, real uint64) (uint64, error) { +func HideUint64(ctx context.Context, tx PrivacyMapTx, real uint64) (uint64, + error) { + str := Uint64ToStr(real) - pseudo, err := tx.RealToPseudo(str) + pseudo, err := tx.RealToPseudo(ctx, str) if err != nil && err != ErrNoSuchKeyFound { return 0, err } @@ -366,19 +295,21 @@ func HideUint64(tx PrivacyMapTx, real uint64) (uint64, error) { } pseudoUint64, pseudoUint64Str := NewPseudoUint64() - if err := tx.NewPair(str, pseudoUint64Str); err != nil { + if err := tx.NewPair(ctx, str, pseudoUint64Str); err != nil { return 0, err } return pseudoUint64, nil } -func RevealUint64(tx PrivacyMapTx, pseudo uint64) (uint64, error) { +func RevealUint64(ctx context.Context, tx PrivacyMapTx, pseudo uint64) (uint64, + error) { + if pseudo == 0 { return 0, nil } - real, err := tx.PseudoToReal(Uint64ToStr(pseudo)) + real, err := tx.PseudoToReal(ctx, Uint64ToStr(pseudo)) if err != nil { return 0, err } @@ -386,11 +317,11 @@ func RevealUint64(tx PrivacyMapTx, pseudo uint64) (uint64, error) { return StrToUint64(real) } -func HideChanPoint(tx PrivacyMapTx, txid string, index uint32) (string, - uint32, error) { +func HideChanPoint(ctx context.Context, tx PrivacyMapTx, txid string, + index uint32) (string, uint32, error) { cp := fmt.Sprintf("%s:%d", txid, index) - pseudo, err := tx.RealToPseudo(cp) + pseudo, err := tx.RealToPseudo(ctx, cp) if err != nil && err != ErrNoSuchKeyFound { return "", 0, err } @@ -403,7 +334,7 @@ func HideChanPoint(tx PrivacyMapTx, txid string, index uint32) (string, return "", 0, err } - if err := tx.NewPair(cp, newCp); err != nil { + if err := tx.NewPair(ctx, cp, newCp); err != nil { return "", 0, err } @@ -420,11 +351,11 @@ func NewPseudoChanPoint() (string, error) { return fmt.Sprintf("%s:%d", pseudoTXID, pseudoIndex), nil } -func RevealChanPoint(tx PrivacyMapTx, txid string, index uint32) (string, - uint32, error) { +func RevealChanPoint(ctx context.Context, tx PrivacyMapTx, txid string, + index uint32) (string, uint32, error) { fakePoint := fmt.Sprintf("%s:%d", txid, index) - real, err := tx.PseudoToReal(fakePoint) + real, err := tx.PseudoToReal(ctx, fakePoint) if err != nil { return "", 0, err } @@ -439,13 +370,15 @@ func NewPseudoUint32() uint32 { return binary.BigEndian.Uint32(b) } -func HideChanPointStr(tx PrivacyMapTx, cp string) (string, error) { +func HideChanPointStr(ctx context.Context, tx PrivacyMapTx, cp string) (string, + error) { + txid, index, err := DecodeChannelPoint(cp) if err != nil { return "", err } - newTxid, newIndex, err := HideChanPoint(tx, txid, index) + newTxid, newIndex, err := HideChanPoint(ctx, tx, txid, index) if err != nil { return "", err } @@ -453,10 +386,12 @@ func HideChanPointStr(tx PrivacyMapTx, cp string) (string, error) { return fmt.Sprintf("%s:%d", newTxid, newIndex), nil } -func HideBytes(tx PrivacyMapTx, realBytes []byte) ([]byte, error) { +func HideBytes(ctx context.Context, tx PrivacyMapTx, realBytes []byte) ([]byte, + error) { + real := hex.EncodeToString(realBytes) - pseudo, err := HideString(tx, real) + pseudo, err := HideString(ctx, tx, real) if err != nil { return nil, err } @@ -464,13 +399,15 @@ func HideBytes(tx PrivacyMapTx, realBytes []byte) ([]byte, error) { return hex.DecodeString(pseudo) } -func RevealBytes(tx PrivacyMapTx, pseudoBytes []byte) ([]byte, error) { +func RevealBytes(ctx context.Context, tx PrivacyMapTx, + pseudoBytes []byte) ([]byte, error) { + if pseudoBytes == nil { return nil, nil } pseudo := hex.EncodeToString(pseudoBytes) - pseudo, err := RevealString(tx, pseudo) + pseudo, err := RevealString(ctx, tx, pseudo) if err != nil { return nil, err } diff --git a/firewalldb/privacy_mapper_test.go b/firewalldb/privacy_mapper_test.go index 5ba9d50fe..7be4d3b64 100644 --- a/firewalldb/privacy_mapper_test.go +++ b/firewalldb/privacy_mapper_test.go @@ -1,6 +1,7 @@ package firewalldb import ( + "context" "fmt" "testing" @@ -9,6 +10,9 @@ import ( // TestPrivacyMapStorage tests the privacy mapper CRUD logic. func TestPrivacyMapStorage(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() db, err := NewDB(tmpDir, "test.db", nil) require.NoError(t, err) @@ -18,25 +22,25 @@ func TestPrivacyMapStorage(t *testing.T) { pdb1 := db.PrivacyDB([4]byte{1, 1, 1, 1}) - _ = pdb1.Update(func(tx PrivacyMapTx) error { - _, err = tx.RealToPseudo("real") + _ = pdb1.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error { + _, err = tx.RealToPseudo(ctx, "real") require.ErrorIs(t, err, ErrNoSuchKeyFound) - _, err = tx.PseudoToReal("pseudo") + _, err = tx.PseudoToReal(ctx, "pseudo") require.ErrorIs(t, err, ErrNoSuchKeyFound) - err = tx.NewPair("real", "pseudo") + err = tx.NewPair(ctx, "real", "pseudo") require.NoError(t, err) - pseudo, err := tx.RealToPseudo("real") + pseudo, err := tx.RealToPseudo(ctx, "real") require.NoError(t, err) require.Equal(t, "pseudo", pseudo) - real, err := tx.PseudoToReal("pseudo") + real, err := tx.PseudoToReal(ctx, "pseudo") require.NoError(t, err) require.Equal(t, "real", real) - pairs, err := tx.FetchAllPairs() + pairs, err := tx.FetchAllPairs(ctx) require.NoError(t, err) require.EqualValues(t, pairs.pairs, map[string]string{ @@ -48,25 +52,25 @@ func TestPrivacyMapStorage(t *testing.T) { pdb2 := db.PrivacyDB([4]byte{2, 2, 2, 2}) - _ = pdb2.Update(func(tx PrivacyMapTx) error { - _, err = tx.RealToPseudo("real") + _ = pdb2.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error { + _, err = tx.RealToPseudo(ctx, "real") require.ErrorIs(t, err, ErrNoSuchKeyFound) - _, err = tx.PseudoToReal("pseudo") + _, err = tx.PseudoToReal(ctx, "pseudo") require.ErrorIs(t, err, ErrNoSuchKeyFound) - err = tx.NewPair("real 2", "pseudo 2") + err = tx.NewPair(ctx, "real 2", "pseudo 2") require.NoError(t, err) - pseudo, err := tx.RealToPseudo("real 2") + pseudo, err := tx.RealToPseudo(ctx, "real 2") require.NoError(t, err) require.Equal(t, "pseudo 2", pseudo) - real, err := tx.PseudoToReal("pseudo 2") + real, err := tx.PseudoToReal(ctx, "pseudo 2") require.NoError(t, err) require.Equal(t, "real 2", real) - pairs, err := tx.FetchAllPairs() + pairs, err := tx.FetchAllPairs(ctx) require.NoError(t, err) require.EqualValues(t, pairs.pairs, map[string]string{ @@ -78,41 +82,41 @@ func TestPrivacyMapStorage(t *testing.T) { pdb3 := db.PrivacyDB([4]byte{3, 3, 3, 3}) - _ = pdb3.Update(func(tx PrivacyMapTx) error { + _ = pdb3.Update(ctx, func(ctx context.Context, tx PrivacyMapTx) error { // Check that calling FetchAllPairs returns an empty map if // nothing exists in the DB yet. - m, err := tx.FetchAllPairs() + m, err := tx.FetchAllPairs(ctx) require.NoError(t, err) require.Empty(t, m.pairs) // Add a new pair. - err = tx.NewPair("real 1", "pseudo 1") + err = tx.NewPair(ctx, "real 1", "pseudo 1") require.NoError(t, err) // Try to add a new pair that has the same real value as the // first pair. This should fail. - err = tx.NewPair("real 1", "pseudo 2") + err = tx.NewPair(ctx, "real 1", "pseudo 2") require.ErrorContains(t, err, "an entry already exists for "+ "real value") // Try to add a new pair that has the same pseudo value as the // first pair. This should fail. - err = tx.NewPair("real 2", "pseudo 1") + err = tx.NewPair(ctx, "real 2", "pseudo 1") require.ErrorContains(t, err, "an entry already exists for "+ "pseudo value") // Add a few more pairs. - err = tx.NewPair("real 2", "pseudo 2") + err = tx.NewPair(ctx, "real 2", "pseudo 2") require.NoError(t, err) - err = tx.NewPair("real 3", "pseudo 3") + err = tx.NewPair(ctx, "real 3", "pseudo 3") require.NoError(t, err) - err = tx.NewPair("real 4", "pseudo 4") + err = tx.NewPair(ctx, "real 4", "pseudo 4") require.NoError(t, err) // Check that FetchAllPairs correctly returns all the pairs. - pairs, err := tx.FetchAllPairs() + pairs, err := tx.FetchAllPairs(ctx) require.NoError(t, err) require.EqualValues(t, pairs.pairs, map[string]string{ @@ -180,6 +184,9 @@ func TestPrivacyMapStorage(t *testing.T) { // provide atomic access to the db. If anything fails in the middle of an // `Update` function, then all the changes prior should be rolled back. func TestPrivacyMapTxs(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() db, err := NewDB(tmpDir, "test.db", nil) require.NoError(t, err) @@ -191,13 +198,15 @@ func TestPrivacyMapTxs(t *testing.T) { // Test that if an action fails midway through the transaction, then // it is rolled back. - err = pdb1.Update(func(tx PrivacyMapTx) error { - err := tx.NewPair("real", "pseudo") + err = pdb1.Update(ctx, func(ctx context.Context, + tx PrivacyMapTx) error { + + err := tx.NewPair(ctx, "real", "pseudo") if err != nil { return err } - p, err := tx.RealToPseudo("real") + p, err := tx.RealToPseudo(ctx, "real") if err != nil { return err } @@ -208,8 +217,8 @@ func TestPrivacyMapTxs(t *testing.T) { }) require.Error(t, err) - err = pdb1.View(func(tx PrivacyMapTx) error { - _, err := tx.RealToPseudo("real") + err = pdb1.View(ctx, func(ctx context.Context, tx PrivacyMapTx) error { + _, err := tx.RealToPseudo(ctx, "real") return err }) require.ErrorIs(t, err, ErrNoSuchKeyFound) diff --git a/rules/chan_policy_bounds.go b/rules/chan_policy_bounds.go index 9ba90ded6..55b79598e 100644 --- a/rules/chan_policy_bounds.go +++ b/rules/chan_policy_bounds.go @@ -396,8 +396,8 @@ func (f *ChanPolicyBounds) RuleName() string { // find the real values. This is a no-op for the ChanPolicyBounds rule. // // NOTE: this is part of the Values interface. -func (f *ChanPolicyBounds) PseudoToReal(_ firewalldb.PrivacyMapDB, - _ session.PrivacyFlags) (Values, error) { +func (f *ChanPolicyBounds) PseudoToReal(_ context.Context, + _ firewalldb.PrivacyMapDB, _ session.PrivacyFlags) (Values, error) { return f, nil } @@ -407,8 +407,9 @@ func (f *ChanPolicyBounds) PseudoToReal(_ firewalldb.PrivacyMapDB, // that should be persisted. This is a no-op for the ChanPolicyBounds rule. // // NOTE: this is part of the Values interface. -func (f *ChanPolicyBounds) RealToPseudo(_ firewalldb.PrivacyMapReader, - _ session.PrivacyFlags) (Values, map[string]string, error) { +func (f *ChanPolicyBounds) RealToPseudo(_ context.Context, + _ firewalldb.PrivacyMapReader, _ session.PrivacyFlags) (Values, + map[string]string, error) { return f, nil, nil } diff --git a/rules/channel_constraints.go b/rules/channel_constraints.go index e50e30df3..8e8524b20 100644 --- a/rules/channel_constraints.go +++ b/rules/channel_constraints.go @@ -333,8 +333,8 @@ func (v *ChannelConstraint) RuleName() string { // find the real values. This is a no-op for the ChannelConstraint rule. // // NOTE: this is part of the Values interface. -func (v *ChannelConstraint) PseudoToReal(_ firewalldb.PrivacyMapDB, - _ session.PrivacyFlags) (Values, error) { +func (v *ChannelConstraint) PseudoToReal(_ context.Context, + _ firewalldb.PrivacyMapDB, _ session.PrivacyFlags) (Values, error) { return v, nil } @@ -344,8 +344,9 @@ func (v *ChannelConstraint) PseudoToReal(_ firewalldb.PrivacyMapDB, // that should be persisted. This is a no-op for the ChannelConstraint rule. // // NOTE: this is part of the Values interface. -func (v *ChannelConstraint) RealToPseudo(_ firewalldb.PrivacyMapReader, - _ session.PrivacyFlags) (Values, map[string]string, error) { +func (v *ChannelConstraint) RealToPseudo(_ context.Context, + _ firewalldb.PrivacyMapReader, _ session.PrivacyFlags) (Values, + map[string]string, error) { return v, nil, nil } diff --git a/rules/channel_restrictions.go b/rules/channel_restrictions.go index 745ed85be..c6cb134d8 100644 --- a/rules/channel_restrictions.go +++ b/rules/channel_restrictions.go @@ -336,8 +336,9 @@ func (c *ChannelRestrict) ToProto() *litrpc.RuleValue { // It constructs a new ChannelRestrict instance with these real channel IDs. // // NOTE: this is part of the Values interface. -func (c *ChannelRestrict) PseudoToReal(db firewalldb.PrivacyMapDB, - flags session.PrivacyFlags) (Values, error) { +func (c *ChannelRestrict) PseudoToReal(ctx context.Context, + db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) (Values, + error) { restrictList := make([]uint64, len(c.DenyList)) @@ -348,9 +349,11 @@ func (c *ChannelRestrict) PseudoToReal(db firewalldb.PrivacyMapDB, return &ChannelRestrict{DenyList: restrictList}, nil } - err := db.View(func(tx firewalldb.PrivacyMapTx) error { + err := db.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, chanID := range c.DenyList { - real, err := firewalldb.RevealUint64(tx, chanID) + real, err := firewalldb.RevealUint64(ctx, tx, chanID) if err != nil { return err } @@ -372,7 +375,8 @@ func (c *ChannelRestrict) PseudoToReal(db firewalldb.PrivacyMapDB, // not find in the given PrivacyMapReader. // // NOTE: this is part of the Values interface. -func (c *ChannelRestrict) RealToPseudo(db firewalldb.PrivacyMapReader, +func (c *ChannelRestrict) RealToPseudo(_ context.Context, + db firewalldb.PrivacyMapReader, flags session.PrivacyFlags) (Values, map[string]string, error) { pseudoIDs := make([]uint64, len(c.DenyList)) diff --git a/rules/channel_restrictions_test.go b/rules/channel_restrictions_test.go index a12c80916..b2d6200c3 100644 --- a/rules/channel_restrictions_test.go +++ b/rules/channel_restrictions_test.go @@ -167,6 +167,9 @@ func (m *mockLndClient) ListChannels(_ context.Context, _, _ bool) ( // method correctly determines which real strings to generate pseudo pairs for // based on the privacy map db passed to it. func TestChannelRestrictRealToPseudo(t *testing.T) { + t.Parallel() + + ctx := context.Background() chanID1 := firewalldb.Uint64ToStr(1) chanID2 := firewalldb.Uint64ToStr(2) chanID3 := firewalldb.Uint64ToStr(3) @@ -249,7 +252,7 @@ func TestChannelRestrictRealToPseudo(t *testing.T) { // form along with any new privacy map pairs that should // be added to the DB. v, newPairs, err := cr.RealToPseudo( - privMapPairDB, test.privacyFlags, + ctx, privMapPairDB, test.privacyFlags, ) require.NoError(t, err) require.Len(t, newPairs, len(test.expectNewPairs)) diff --git a/rules/history_limit.go b/rules/history_limit.go index dccebef44..be2894f42 100644 --- a/rules/history_limit.go +++ b/rules/history_limit.go @@ -256,8 +256,8 @@ func (h *HistoryLimit) GetStartDate() time.Time { // find the real values. This is a no-op for the HistoryLimit rule. // // NOTE: this is part of the Values interface. -func (h *HistoryLimit) PseudoToReal(_ firewalldb.PrivacyMapDB, - _ session.PrivacyFlags) (Values, error) { +func (h *HistoryLimit) PseudoToReal(_ context.Context, + _ firewalldb.PrivacyMapDB, _ session.PrivacyFlags) (Values, error) { return h, nil } @@ -267,8 +267,9 @@ func (h *HistoryLimit) PseudoToReal(_ firewalldb.PrivacyMapDB, // that should be persisted. This is a no-op for the HistoryLimit rule. // // NOTE: this is part of the Values interface. -func (h *HistoryLimit) RealToPseudo(_ firewalldb.PrivacyMapReader, - _ session.PrivacyFlags) (Values, map[string]string, error) { +func (h *HistoryLimit) RealToPseudo(_ context.Context, + _ firewalldb.PrivacyMapReader, _ session.PrivacyFlags) (Values, + map[string]string, error) { return h, nil, nil } diff --git a/rules/interfaces.go b/rules/interfaces.go index a1683c4c5..e657a5c03 100644 --- a/rules/interfaces.go +++ b/rules/interfaces.go @@ -64,13 +64,13 @@ type Values interface { // keys, channel IDs, channel points etc. It returns a map of any new // real to pseudo strings that should be persisted that it did not find // in the given PrivacyMapReader. - RealToPseudo(db firewalldb.PrivacyMapReader, + RealToPseudo(ctx context.Context, db firewalldb.PrivacyMapReader, flags session.PrivacyFlags) (Values, map[string]string, error) // PseudoToReal attempts to convert any appropriate pseudo fields in // the rule Values to their corresponding real values. It uses the // passed PrivacyMapDB to find the real values. - PseudoToReal(db firewalldb.PrivacyMapDB, + PseudoToReal(ctx context.Context, db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) (Values, error) } diff --git a/rules/onchain_budget.go b/rules/onchain_budget.go index 248b2f699..783e3a664 100644 --- a/rules/onchain_budget.go +++ b/rules/onchain_budget.go @@ -363,8 +363,8 @@ func (o *OnChainBudget) ToProto() *litrpc.RuleValue { // find the real values. This is a no-op for the OnChainBudget rule. // // NOTE: this is part of the Values interface. -func (o *OnChainBudget) PseudoToReal(_ firewalldb.PrivacyMapDB, - _ session.PrivacyFlags) (Values, error) { +func (o *OnChainBudget) PseudoToReal(_ context.Context, + _ firewalldb.PrivacyMapDB, _ session.PrivacyFlags) (Values, error) { return o, nil } @@ -374,8 +374,9 @@ func (o *OnChainBudget) PseudoToReal(_ firewalldb.PrivacyMapDB, // that should be persisted. This is a no-op for the OnChainBudget rule. // // NOTE: this is part of the Values interface. -func (o *OnChainBudget) RealToPseudo(db firewalldb.PrivacyMapReader, - flags session.PrivacyFlags) (Values, map[string]string, error) { +func (o *OnChainBudget) RealToPseudo(_ context.Context, + _ firewalldb.PrivacyMapReader, _ session.PrivacyFlags) (Values, + map[string]string, error) { return o, nil, nil } diff --git a/rules/peer_restrictions.go b/rules/peer_restrictions.go index fbaefe94c..009ee8ab8 100644 --- a/rules/peer_restrictions.go +++ b/rules/peer_restrictions.go @@ -381,8 +381,9 @@ func (c *PeerRestrict) ToProto() *litrpc.RuleValue { // It constructs a new PeerRestrict instance with these real peer IDs. // // NOTE: this is part of the Values interface. -func (c *PeerRestrict) PseudoToReal(db firewalldb.PrivacyMapDB, - flags session.PrivacyFlags) (Values, error) { +func (c *PeerRestrict) PseudoToReal(ctx context.Context, + db firewalldb.PrivacyMapDB, flags session.PrivacyFlags) (Values, + error) { restrictList := make([]string, len(c.DenyList)) @@ -393,9 +394,13 @@ func (c *PeerRestrict) PseudoToReal(db firewalldb.PrivacyMapDB, return &PeerRestrict{DenyList: restrictList}, nil } - err := db.View(func(tx firewalldb.PrivacyMapTx) error { + err := db.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for i, peerPubKey := range c.DenyList { - real, err := firewalldb.RevealString(tx, peerPubKey) + real, err := firewalldb.RevealString( + ctx, tx, peerPubKey, + ) if err != nil { return err } @@ -418,7 +423,8 @@ func (c *PeerRestrict) PseudoToReal(db firewalldb.PrivacyMapDB, // find in the given PrivacyMapReader. // // NOTE: this is part of the Values interface. -func (c *PeerRestrict) RealToPseudo(db firewalldb.PrivacyMapReader, +func (c *PeerRestrict) RealToPseudo(_ context.Context, + db firewalldb.PrivacyMapReader, flags session.PrivacyFlags) (Values, map[string]string, error) { pseudoIDs := make([]string, len(c.DenyList)) diff --git a/rules/peer_restrictions_test.go b/rules/peer_restrictions_test.go index faa3c18d3..abfa30540 100644 --- a/rules/peer_restrictions_test.go +++ b/rules/peer_restrictions_test.go @@ -204,6 +204,9 @@ func TestPeerRestrictCheckRequest(t *testing.T) { // method correctly determines which real strings to generate pseudo pairs for // based on the privacy map db passed to it. func TestPeerRestrictRealToPseudo(t *testing.T) { + t.Parallel() + ctx := context.Background() + tests := []struct { name string privacyFlags session.PrivacyFlags @@ -276,7 +279,7 @@ func TestPeerRestrictRealToPseudo(t *testing.T) { // form along with any new privacy map pairs that should // be added to the DB. v, newPairs, err := pr.RealToPseudo( - privMapPairDB, test.privacyFlags, + ctx, privMapPairDB, test.privacyFlags, ) require.NoError(t, err) require.Len(t, newPairs, len(test.expectNewPairs)) diff --git a/rules/rate_limit.go b/rules/rate_limit.go index 4bff4bbe0..f324721a0 100644 --- a/rules/rate_limit.go +++ b/rules/rate_limit.go @@ -267,8 +267,8 @@ func (r *RateLimit) ToProto() *litrpc.RuleValue { // find the real values. This is a no-op for the RateLimit rule. // // NOTE: this is part of the Values interface. -func (r *RateLimit) PseudoToReal(_ firewalldb.PrivacyMapDB, - _ session.PrivacyFlags) (Values, error) { +func (r *RateLimit) PseudoToReal(_ context.Context, + _ firewalldb.PrivacyMapDB, _ session.PrivacyFlags) (Values, error) { return r, nil } @@ -278,8 +278,9 @@ func (r *RateLimit) PseudoToReal(_ firewalldb.PrivacyMapDB, // that should be persisted. This is a no-op for the RateLimit rule. // // NOTE: this is part of the Values interface. -func (r *RateLimit) RealToPseudo(_ firewalldb.PrivacyMapReader, - flags session.PrivacyFlags) (Values, map[string]string, error) { +func (r *RateLimit) RealToPseudo(_ context.Context, + _ firewalldb.PrivacyMapReader, flags session.PrivacyFlags) (Values, + map[string]string, error) { return r, nil, nil } diff --git a/session_rpcserver.go b/session_rpcserver.go index 092ca2e7e..652196f59 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -355,7 +355,7 @@ func (s *sessionRpcServer) AddSession(ctx context.Context, return nil, fmt.Errorf("error fetching session: %v", err) } - rpcSession, err := s.marshalRPCSession(sess) + rpcSession, err := s.marshalRPCSession(ctx, sess) if err != nil { return nil, fmt.Errorf("error marshaling session: %v", err) } @@ -557,7 +557,7 @@ func (s *sessionRpcServer) ListSessions(ctx context.Context, Sessions: make([]*litrpc.Session, len(sessions)), } for idx, sess := range sessions { - response.Sessions[idx], err = s.marshalRPCSession(sess) + response.Sessions[idx], err = s.marshalRPCSession(ctx, sess) if err != nil { return nil, fmt.Errorf("error marshaling session: %v", err) @@ -629,14 +629,16 @@ func (s *sessionRpcServer) PrivacyMapConversion(ctx context.Context, var res string privMap := s.cfg.privMap(groupID) - err = privMap.View(func(tx firewalldb.PrivacyMapTx) error { + err = privMap.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + var err error if req.RealToPseudo { - res, err = tx.RealToPseudo(req.Input) + res, err = tx.RealToPseudo(ctx, req.Input) return err } - res, err = tx.PseudoToReal(req.Input) + res, err = tx.PseudoToReal(ctx, req.Input) return err }) if err != nil { @@ -899,8 +901,10 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, linkedGroupSession = groupSess privDB := s.cfg.privMap(groupID) - err = privDB.View(func(tx firewalldb.PrivacyMapTx) error { - knownPrivMapPairs, err = tx.FetchAllPairs() + err = privDB.View(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + + knownPrivMapPairs, err = tx.FetchAllPairs(ctx) return err }) @@ -1002,7 +1006,8 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, if privacy { var privMapPairs map[string]string v, privMapPairs, err = v.RealToPseudo( - knownPrivMapPairs, privacyFlags, + ctx, knownPrivMapPairs, + privacyFlags, ) if err != nil { return nil, err @@ -1221,9 +1226,11 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, // Register all the privacy map pairs for this session ID. privDB := s.cfg.privMap(sess.GroupID) - err = privDB.Update(func(tx firewalldb.PrivacyMapTx) error { + err = privDB.Update(ctx, func(ctx context.Context, + tx firewalldb.PrivacyMapTx) error { + for r, p := range newPrivMapPairs { - err := tx.NewPair(r, p) + err := tx.NewPair(ctx, r, p) if err != nil { return err } @@ -1272,7 +1279,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, return nil, fmt.Errorf("error fetching session: %v", err) } - rpcSession, err := s.marshalRPCSession(sess) + rpcSession, err := s.marshalRPCSession(ctx, sess) if err != nil { return nil, fmt.Errorf("error marshaling session: %v", err) } @@ -1297,7 +1304,7 @@ func (s *sessionRpcServer) ListAutopilotSessions(ctx context.Context, Sessions: make([]*litrpc.Session, len(sessions)), } for idx, sess := range sessions { - response.Sessions[idx], err = s.marshalRPCSession(sess) + response.Sessions[idx], err = s.marshalRPCSession(ctx, sess) if err != nil { return nil, fmt.Errorf("error marshaling session: %v", err) @@ -1426,8 +1433,8 @@ func marshalPerms(perms map[string][]bakery.Op) []*litrpc.Permissions { } // marshalRPCSession converts a session into its RPC counterpart. -func (s *sessionRpcServer) marshalRPCSession(sess *session.Session) ( - *litrpc.Session, error) { +func (s *sessionRpcServer) marshalRPCSession(ctx context.Context, + sess *session.Session) (*litrpc.Session, error) { rpcState, err := marshalRPCState(sess.State) if err != nil { @@ -1484,7 +1491,8 @@ func (s *sessionRpcServer) marshalRPCSession(sess *session.Session) ( sess.GroupID, ) val, err = val.PseudoToReal( - db, sess.PrivacyFlags, + ctx, db, + sess.PrivacyFlags, ) if err != nil { return nil, err