Skip to content

Commit d2b077b

Browse files
committed
session: add ListSessionsByState method
1 parent 01410f7 commit d2b077b

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

session/interface.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,10 @@ type Store interface {
167167
// ListSessionsByType returns all sessions of the given type.
168168
ListSessionsByType(t Type) ([]*Session, error)
169169

170+
// ListSessionsByState returns all sessions currently known to the store
171+
// that are in the given states.
172+
ListSessionsByState(...State) ([]*Session, error)
173+
170174
// RevokeSession updates the state of the session with the given local
171175
// public key to be revoked.
172176
RevokeSession(*btcec.PublicKey) error

session/kvdb_store.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,22 @@ func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) {
383383
})
384384
}
385385

386+
// ListSessionsByState returns all sessions currently known to the store that
387+
// are in the given states.
388+
//
389+
// NOTE: this is part of the Store interface.
390+
func (db *BoltStore) ListSessionsByState(states ...State) ([]*Session, error) {
391+
return db.listSessions(func(s *Session) bool {
392+
for _, state := range states {
393+
if s.State == state {
394+
return true
395+
}
396+
}
397+
398+
return false
399+
})
400+
}
401+
386402
// listSessions returns all sessions currently known to the store that pass the
387403
// given filter function.
388404
func (db *BoltStore) listSessions(filterFn func(s *Session) bool) ([]*Session,

session/store_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,33 @@ func TestBasicSessionStore(t *testing.T) {
118118
assertEqualSessions(t, s1, sessions[0])
119119
assertEqualSessions(t, s2, sessions[1])
120120
assertEqualSessions(t, s3, sessions[2])
121+
122+
// Test that ListSessionsByState works.
123+
sessions, err = db.ListSessionsByState(StateRevoked)
124+
require.NoError(t, err)
125+
require.Equal(t, 1, len(sessions))
126+
assertEqualSessions(t, s1, sessions[0])
127+
128+
sessions, err = db.ListSessionsByState(StateCreated)
129+
require.NoError(t, err)
130+
require.Equal(t, 2, len(sessions))
131+
assertEqualSessions(t, s2, sessions[0])
132+
assertEqualSessions(t, s3, sessions[1])
133+
134+
sessions, err = db.ListSessionsByState(StateCreated, StateRevoked)
135+
require.NoError(t, err)
136+
require.Equal(t, 3, len(sessions))
137+
assertEqualSessions(t, s1, sessions[0])
138+
assertEqualSessions(t, s2, sessions[1])
139+
assertEqualSessions(t, s3, sessions[2])
140+
141+
sessions, err = db.ListSessionsByState()
142+
require.NoError(t, err)
143+
require.Empty(t, sessions)
144+
145+
sessions, err = db.ListSessionsByState(StateInUse)
146+
require.NoError(t, err)
147+
require.Empty(t, sessions)
121148
}
122149

123150
// TestLinkingSessions tests that session linking works as expected.

0 commit comments

Comments
 (0)