Skip to content

Commit 38afe62

Browse files
committed
session: remove the filter fn in ListSessions
And instead let the caller pass in a list of States they are interested in. This will make SQL queries much more efficient since we can index by state.
1 parent 7a91212 commit 38afe62

File tree

4 files changed

+55
-18
lines changed

4 files changed

+55
-18
lines changed

session/interface.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,10 @@ type Store interface {
161161
// GetSession fetches the session with the given key.
162162
GetSession(key *btcec.PublicKey) (*Session, error)
163163

164-
// ListSessions returns all sessions currently known to the store.
165-
ListSessions(filterFn func(s *Session) bool) ([]*Session, error)
164+
// ListSessions returns all sessions currently known to the store that
165+
// are in the given states. If no states are provided, all sessions are
166+
// returned.
167+
ListSessions(states ...State) ([]*Session, error)
166168

167169
// ListSessionsByType returns all sessions of the given type.
168170
ListSessionsByType(t Type) ([]*Session, error)

session/kvdb_store.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,11 +364,24 @@ func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) {
364364
return session, nil
365365
}
366366

367-
// ListSessions returns all sessions currently known to the store.
367+
// ListSessions returns all sessions currently known to the store that are in
368+
// the given states. If no states are provided, all sessions are returned.
368369
//
369370
// NOTE: this is part of the Store interface.
370-
func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, error) {
371-
return db.listSessions(filterFn)
371+
func (db *BoltStore) ListSessions(states ...State) ([]*Session, error) {
372+
return db.listSessions(func(s *Session) bool {
373+
if len(states) == 0 {
374+
return true
375+
}
376+
377+
for _, state := range states {
378+
if s.State == state {
379+
return true
380+
}
381+
}
382+
383+
return false
384+
})
372385
}
373386

374387
// ListSessionsByType returns all sessions currently known to the store that

session/store_test.go

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,8 @@ func TestBasicSessionStore(t *testing.T) {
5656
require.NoError(t, db.CreateSession(s2))
5757
require.NoError(t, db.CreateSession(s3))
5858

59-
// Check that all sessions are returned in ListSessions.
60-
sessions, err := db.ListSessions(nil)
61-
require.NoError(t, err)
62-
require.Equal(t, 3, len(sessions))
63-
assertEqualSessions(t, s1, sessions[0])
64-
assertEqualSessions(t, s2, sessions[1])
65-
assertEqualSessions(t, s3, sessions[2])
66-
6759
// Test the ListSessionsByType method.
68-
sessions, err = db.ListSessionsByType(TypeMacaroonAdmin)
60+
sessions, err := db.ListSessionsByType(TypeMacaroonAdmin)
6961
require.NoError(t, err)
7062
require.Equal(t, 2, len(sessions))
7163
assertEqualSessions(t, s1, sessions[0])
@@ -115,9 +107,39 @@ func TestBasicSessionStore(t *testing.T) {
115107

116108
// Now revoke the session and assert that the state is revoked.
117109
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
118-
session1, err = db.GetSession(s1.LocalPublicKey)
110+
s1, err = db.GetSession(s1.LocalPublicKey)
119111
require.NoError(t, err)
120-
require.Equal(t, session1.State, StateRevoked)
112+
require.Equal(t, s1.State, StateRevoked)
113+
114+
// Test that ListSessions by certain states works.
115+
sessions, err = db.ListSessions(StateRevoked)
116+
require.NoError(t, err)
117+
require.Equal(t, 1, len(sessions))
118+
assertEqualSessions(t, s1, sessions[0])
119+
120+
sessions, err = db.ListSessions(StateCreated)
121+
require.NoError(t, err)
122+
require.Equal(t, 2, len(sessions))
123+
assertEqualSessions(t, s2, sessions[0])
124+
assertEqualSessions(t, s3, sessions[1])
125+
126+
sessions, err = db.ListSessions(StateCreated, StateRevoked)
127+
require.NoError(t, err)
128+
require.Equal(t, 3, len(sessions))
129+
assertEqualSessions(t, s1, sessions[0])
130+
assertEqualSessions(t, s2, sessions[1])
131+
assertEqualSessions(t, s3, sessions[2])
132+
133+
sessions, err = db.ListSessions()
134+
require.NoError(t, err)
135+
require.Equal(t, 3, len(sessions))
136+
assertEqualSessions(t, s1, sessions[0])
137+
assertEqualSessions(t, s2, sessions[1])
138+
assertEqualSessions(t, s3, sessions[2])
139+
140+
sessions, err = db.ListSessions(StateInUse)
141+
require.NoError(t, err)
142+
require.Empty(t, sessions)
121143
}
122144

123145
// TestLinkingSessions tests that session linking works as expected.

session_rpcserver.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer,
101101
// requests. This includes resuming all non-revoked sessions.
102102
func (s *sessionRpcServer) start(ctx context.Context) error {
103103
// Start up all previously created sessions.
104-
sessions, err := s.cfg.db.ListSessions(nil)
104+
sessions, err := s.cfg.db.ListSessions()
105105
if err != nil {
106106
return fmt.Errorf("error listing sessions: %v", err)
107107
}
@@ -536,7 +536,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
536536
func (s *sessionRpcServer) ListSessions(_ context.Context,
537537
_ *litrpc.ListSessionsRequest) (*litrpc.ListSessionsResponse, error) {
538538

539-
sessions, err := s.cfg.db.ListSessions(nil)
539+
sessions, err := s.cfg.db.ListSessions()
540540
if err != nil {
541541
return nil, fmt.Errorf("error fetching sessions: %v", err)
542542
}

0 commit comments

Comments
 (0)