Skip to content

Commit 9f950d6

Browse files
feat(GODT-1724): Guard database access with an RWLock
Ensure all the database queries are accessed through the scope of RWLock to avoid the infamous "database locked" error and to allow more parallel read only requests from clients. The RWLock guarantees that only one goroutine/thread can write to the datbase, therefore ensuring we never trigger the limitations of SQLite DB which causes the error. The removal of the global TX Lock now enables clients to perform more read only request in parallel. Going forward it is expected all database access to be accessed via the `DB` type.
1 parent 347513f commit 9f950d6

25 files changed

Lines changed: 559 additions & 383 deletions

benchmarks/gluon_bench/store_benchmarks/create.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ package store_benchmarks
22

33
import (
44
"context"
5-
"github.com/ProtonMail/gluon/imap"
65

76
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/benchmark"
87
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/flags"
98
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/reporter"
109
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/timing"
10+
"github.com/ProtonMail/gluon/imap"
1111
"github.com/ProtonMail/gluon/store"
1212
"github.com/google/uuid"
1313
)

benchmarks/gluon_bench/store_benchmarks/delete.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ package store_benchmarks
22

33
import (
44
"context"
5-
"github.com/ProtonMail/gluon/imap"
65

76
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/benchmark"
87
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/flags"
98
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/reporter"
109
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/timing"
10+
"github.com/ProtonMail/gluon/imap"
1111
"github.com/ProtonMail/gluon/store"
1212
)
1313

benchmarks/gluon_bench/store_benchmarks/get.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ package store_benchmarks
22

33
import (
44
"context"
5-
"github.com/ProtonMail/gluon/imap"
65
"math/rand"
76

87
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/benchmark"
98
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/flags"
109
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/reporter"
1110
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/timing"
11+
"github.com/ProtonMail/gluon/imap"
1212
"github.com/ProtonMail/gluon/store"
1313
)
1414

benchmarks/gluon_bench/store_benchmarks/utils.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ package store_benchmarks
22

33
import (
44
"context"
5-
"github.com/ProtonMail/gluon/imap"
65
"sync"
76
"time"
87

98
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/flags"
109
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/reporter"
1110
"github.com/ProtonMail/gluon/benchmarks/gluon_bench/timing"
11+
"github.com/ProtonMail/gluon/imap"
1212
"github.com/ProtonMail/gluon/store"
1313
"github.com/google/uuid"
1414
)

internal/backend/db.go

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,41 @@ func (d *DB) Init(ctx context.Context) error {
2424
}
2525

2626
func (d *DB) Read(ctx context.Context, fn func(context.Context, *ent.Client) error) error {
27-
d.lock.RLock()
28-
defer d.lock.Unlock()
27+
_, err := DBReadResult(ctx, d, func(ctx context.Context, client *ent.Client) (struct{}, error) {
28+
return struct{}{}, fn(ctx, client)
29+
})
2930

30-
return fn(ctx, d.db)
31+
return err
3132
}
3233

3334
func (d *DB) Write(ctx context.Context, fn func(context.Context, *ent.Tx) error) error {
34-
d.lock.Lock()
35-
defer d.lock.Unlock()
35+
_, err := DBWriteResult(ctx, d, func(ctx context.Context, tx *ent.Tx) (struct{}, error) {
36+
return struct{}{}, fn(ctx, tx)
37+
})
38+
39+
return err
40+
}
41+
42+
func (d *DB) Close() error {
43+
return d.db.Close()
44+
}
45+
46+
func DBReadResult[T any](ctx context.Context, db *DB, fn func(context.Context, *ent.Client) (T, error)) (T, error) {
47+
db.lock.RLock()
48+
defer db.lock.RUnlock()
3649

37-
tx, err := d.db.Tx(ctx)
50+
return fn(ctx, db.db)
51+
}
52+
53+
func DBWriteResult[T any](ctx context.Context, db *DB, fn func(context.Context, *ent.Tx) (T, error)) (T, error) {
54+
db.lock.Lock()
55+
defer db.lock.Unlock()
56+
57+
var failResult T
58+
59+
tx, err := db.db.Tx(ctx)
3860
if err != nil {
39-
return err
61+
return failResult, err
4062
}
4163

4264
defer func() {
@@ -49,23 +71,20 @@ func (d *DB) Write(ctx context.Context, fn func(context.Context, *ent.Tx) error)
4971
}
5072
}()
5173

52-
if err := fn(ctx, tx); err != nil {
74+
result, err := fn(ctx, tx)
75+
if err != nil {
5376
if rerr := tx.Rollback(); rerr != nil {
54-
return fmt.Errorf("rolling back transaction: %w", rerr)
77+
return failResult, fmt.Errorf("rolling back transaction: %w", rerr)
5578
}
5679

57-
return err
80+
return failResult, err
5881
}
5982

6083
if err := tx.Commit(); err != nil {
61-
return fmt.Errorf("committing transaction: %w", err)
84+
return failResult, fmt.Errorf("committing transaction: %w", err)
6285
}
6386

64-
return nil
65-
}
66-
67-
func (d *DB) Close() error {
68-
return d.db.Close()
87+
return result, nil
6988
}
7089

7190
func (b *Backend) newDB(userID string) (*DB, error) {

internal/backend/db_mailbox.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func DBGetMailboxMessageIDPairs(ctx context.Context, client *ent.Client, mailbox
166166
}), nil
167167
}
168168

169-
func DBGetUIDInterval(ctx context.Context, mbox *ent.Mailbox, begin, end int) ([]*ent.UID, error) {
169+
func DBGetUIDInterval(ctx context.Context, client *ent.Client, mbox *ent.Mailbox, begin, end int) ([]*ent.UID, error) {
170170
return mbox.QueryUIDs().
171171
Where(uid.UIDGTE(begin), uid.UIDLTE(end)).
172172
WithMessage().
@@ -211,15 +211,15 @@ func DBGetMailboxByID(ctx context.Context, client *ent.Client, id imap.InternalM
211211
return client.Mailbox.Query().Where(mailbox.MailboxID(id)).Only(ctx)
212212
}
213213

214-
func DBGetMailboxMessages(ctx context.Context, mbox *ent.Mailbox) ([]*ent.UID, error) {
214+
func DBGetMailboxMessages(ctx context.Context, client *ent.Client, mbox *ent.Mailbox) ([]*ent.UID, error) {
215215
return mbox.QueryUIDs().WithMessage().All(ctx)
216216
}
217217

218-
func DBGetMailboxRecentCount(ctx context.Context, mbox *ent.Mailbox) (int, error) {
218+
func DBGetMailboxRecentCount(ctx context.Context, client *ent.Client, mbox *ent.Mailbox) (int, error) {
219219
return mbox.QueryUIDs().Where(uid.Recent(true)).Count(ctx)
220220
}
221221

222-
func DBGetMailboxMessagesForNewSnapshot(ctx context.Context, mbox *ent.Mailbox) ([]*ent.UID, error) {
222+
func DBGetMailboxMessagesForNewSnapshot(ctx context.Context, client *ent.Client, mbox *ent.Mailbox) ([]*ent.UID, error) {
223223
var msgUIDs []*ent.UID
224224

225225
const QueryLimit = 16000

internal/backend/db_message.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,8 @@ func DBDeleteMessages(ctx context.Context, tx *ent.Tx, messageIDs ...imap.Intern
440440
return nil
441441
}
442442

443-
func DBGetMessageIDsMarkedDeleted(ctx context.Context, tx *ent.Tx) ([]imap.InternalMessageID, error) {
444-
messages, err := tx.Message.Query().Where(message.Deleted(true)).Select(message.FieldMessageID).All(ctx)
443+
func DBGetMessageIDsMarkedDeleted(ctx context.Context, client *ent.Client) ([]imap.InternalMessageID, error) {
444+
messages, err := client.Message.Query().Where(message.Deleted(true)).Select(message.FieldMessageID).All(ctx)
445445
if err != nil {
446446
return nil, err
447447
}

internal/backend/mailbox.go

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66

77
"github.com/ProtonMail/gluon/imap"
88
"github.com/ProtonMail/gluon/internal/backend/ent"
9-
"github.com/ProtonMail/gluon/internal/backend/ent/mailbox"
109
"github.com/ProtonMail/gluon/internal/parser/proto"
1110
"github.com/ProtonMail/gluon/internal/response"
1211
"github.com/ProtonMail/gluon/rfc822"
@@ -16,7 +15,6 @@ import (
1615
)
1716

1817
type Mailbox struct {
19-
tx *ent.Tx
2018
mbox *ent.Mailbox
2119

2220
state *State
@@ -26,9 +24,8 @@ type Mailbox struct {
2624
readOnly bool
2725
}
2826

29-
func newMailbox(tx *ent.Tx, mbox *ent.Mailbox, state *State, snap *snapshot) *Mailbox {
27+
func newMailbox(mbox *ent.Mailbox, state *State, snap *snapshot) *Mailbox {
3028
return &Mailbox{
31-
tx: tx,
3229
mbox: mbox,
3330

3431
state: state,
@@ -133,23 +130,32 @@ func (m *Mailbox) Append(ctx context.Context, literal []byte, flags imap.FlagSet
133130

134131
if len(internalID) > 0 {
135132
msgID := imap.InternalMessageID(internalID)
136-
if exists, err := DBHasMessageWithID(ctx, m.tx.Client(), msgID); err != nil || !exists {
133+
134+
if exists, err := DBReadResult(ctx, m.state.db, func(ctx context.Context, client *ent.Client) (bool, error) {
135+
return DBHasMessageWithID(ctx, client, msgID)
136+
}); err != nil || !exists {
137137
logrus.WithError(err).Warn("The message has an unknown internal ID")
138-
} else if res, err := m.state.actionAddMessagesToMailbox(ctx, m.tx, []MessageIDPair{NewMessageIDPairWithoutRemote(msgID)}, NewMailboxIDPair(m.mbox)); err != nil {
138+
} else if res, err := DBWriteResult(ctx, m.state.db, func(ctx context.Context, tx *ent.Tx) (map[imap.InternalMessageID]int, error) {
139+
return m.state.actionAddMessagesToMailbox(ctx, tx, []MessageIDPair{NewMessageIDPairWithoutRemote(msgID)}, NewMailboxIDPair(m.mbox))
140+
}); err != nil {
139141
return 0, err
140142
} else {
141143
return res[msgID], nil
142144
}
143145
}
144146

145-
return m.state.actionCreateMessage(ctx, m.tx, m.snap.mboxID, literal, flags, date)
147+
return DBWriteResult(ctx, m.state.db, func(ctx context.Context, tx *ent.Tx) (int, error) {
148+
return m.state.actionCreateMessage(ctx, tx, m.snap.mboxID, literal, flags, date)
149+
})
146150
}
147151

148152
// Copy copies the messages represented by the given sequence set into the mailbox with the given name.
149153
// If the context is a UID context, the sequence set refers to message UIDs.
150154
// If no items are copied the response object will be nil.
151155
func (m *Mailbox) Copy(ctx context.Context, seq *proto.SequenceSet, name string) (response.Item, error) {
152-
mbox, err := m.tx.Mailbox.Query().Where(mailbox.Name(name)).Only(ctx)
156+
mbox, err := DBReadResult(ctx, m.state.db, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) {
157+
return DBGetMailboxByName(ctx, client, name)
158+
})
153159
if err != nil {
154160
return nil, ErrNoSuchMailbox
155161
}
@@ -167,7 +173,9 @@ func (m *Mailbox) Copy(ctx context.Context, seq *proto.SequenceSet, name string)
167173
return msg.UID
168174
})
169175

170-
destUIDs, err := m.state.actionAddMessagesToMailbox(ctx, m.tx, msgIDs, NewMailboxIDPair(mbox))
176+
destUIDs, err := DBWriteResult(ctx, m.state.db, func(ctx context.Context, tx *ent.Tx) (map[imap.InternalMessageID]int, error) {
177+
return m.state.actionAddMessagesToMailbox(ctx, tx, msgIDs, NewMailboxIDPair(mbox))
178+
})
171179
if err != nil {
172180
return nil, err
173181
}
@@ -187,7 +195,9 @@ func (m *Mailbox) Copy(ctx context.Context, seq *proto.SequenceSet, name string)
187195
// If the context is a UID context, the sequence set refers to message UIDs.
188196
// If no items are moved the response object will be nil.
189197
func (m *Mailbox) Move(ctx context.Context, seq *proto.SequenceSet, name string) (response.Item, error) {
190-
mbox, err := m.tx.Mailbox.Query().Where(mailbox.Name(name)).Only(ctx)
198+
mbox, err := DBReadResult(ctx, m.state.db, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) {
199+
return DBGetMailboxByName(ctx, client, name)
200+
})
191201
if err != nil {
192202
return nil, ErrNoSuchMailbox
193203
}
@@ -205,7 +215,9 @@ func (m *Mailbox) Move(ctx context.Context, seq *proto.SequenceSet, name string)
205215
return msg.UID
206216
})
207217

208-
destUIDs, err := m.state.actionMoveMessages(ctx, m.tx, msgIDs, m.snap.mboxID, NewMailboxIDPair(mbox))
218+
destUIDs, err := DBWriteResult(ctx, m.state.db, func(ctx context.Context, tx *ent.Tx) (map[imap.InternalMessageID]int, error) {
219+
return m.state.actionMoveMessages(ctx, tx, msgIDs, m.snap.mboxID, NewMailboxIDPair(mbox))
220+
})
209221
if err != nil {
210222
return nil, err
211223
}
@@ -231,24 +243,26 @@ func (m *Mailbox) Store(ctx context.Context, seq *proto.SequenceSet, operation p
231243
return msg.ID
232244
})
233245

234-
switch operation {
235-
case proto.Operation_Add:
236-
if _, err := m.state.actionAddMessageFlags(ctx, m.tx, msgIDs, flags); err != nil {
237-
return err
246+
return m.state.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error {
247+
switch operation {
248+
case proto.Operation_Add:
249+
if _, err := m.state.actionAddMessageFlags(ctx, tx, msgIDs, flags); err != nil {
250+
return err
251+
}
252+
253+
case proto.Operation_Remove:
254+
if _, err := m.state.actionRemoveMessageFlags(ctx, tx, msgIDs, flags); err != nil {
255+
return err
256+
}
257+
258+
case proto.Operation_Replace:
259+
if err := m.state.actionSetMessageFlags(ctx, tx, msgIDs, flags); err != nil {
260+
return err
261+
}
238262
}
239263

240-
case proto.Operation_Remove:
241-
if _, err := m.state.actionRemoveMessageFlags(ctx, m.tx, msgIDs, flags); err != nil {
242-
return err
243-
}
244-
245-
case proto.Operation_Replace:
246-
if err := m.state.actionSetMessageFlags(ctx, m.tx, msgIDs, flags); err != nil {
247-
return err
248-
}
249-
}
250-
251-
return nil
264+
return nil
265+
})
252266
}
253267

254268
func (m *Mailbox) Expunge(ctx context.Context, seq *proto.SequenceSet) error {
@@ -276,17 +290,21 @@ func (m *Mailbox) expunge(ctx context.Context, messages []*snapMsg) error {
276290
return msg.ID
277291
})
278292

279-
return m.state.actionRemoveMessagesFromMailbox(ctx, m.tx, msgIDs, m.snap.mboxID)
293+
return m.state.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error {
294+
return m.state.actionRemoveMessagesFromMailbox(ctx, tx, msgIDs, m.snap.mboxID)
295+
})
280296
}
281297

282298
func (m *Mailbox) Flush(ctx context.Context, permitExpunge bool) ([]response.Response, error) {
283-
return m.state.flushResponses(ctx, m.tx, permitExpunge)
299+
return DBWriteResult(ctx, m.state.db, func(ctx context.Context, tx *ent.Tx) ([]response.Response, error) {
300+
return m.state.flushResponses(ctx, tx, permitExpunge)
301+
})
284302
}
285303

286304
func (m *Mailbox) Close(ctx context.Context) error {
287305
if err := m.state.deleteConnMetadata(); err != nil {
288306
return err
289307
}
290308

291-
return m.state.close(ctx, m.tx)
309+
return m.state.close()
292310
}

internal/backend/mailbox_fetch.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ func (m *Mailbox) fetchItems(ctx context.Context, msg *snapMsg, attributes []*pr
4343
setSeen bool
4444
)
4545

46-
message, err := DBGetMessage(ctx, m.tx.Client(), msg.ID.InternalID)
46+
message, err := DBReadResult(ctx, m.state.db, func(ctx context.Context, client *ent.Client) (*ent.Message, error) {
47+
return DBGetMessage(ctx, client, msg.ID.InternalID)
48+
})
4749
if err != nil {
4850
return 0, nil, err
4951
}
@@ -90,7 +92,9 @@ func (m *Mailbox) fetchItems(ctx context.Context, msg *snapMsg, attributes []*pr
9092
}
9193

9294
if setSeen {
93-
newFlags, err := m.state.actionAddMessageFlags(ctx, m.tx, []MessageIDPair{msg.ID}, imap.NewFlagSet(imap.FlagSeen))
95+
newFlags, err := DBWriteResult(ctx, m.state.db, func(ctx context.Context, tx *ent.Tx) (map[imap.InternalMessageID]imap.FlagSet, error) {
96+
return m.state.actionAddMessageFlags(ctx, tx, []MessageIDPair{msg.ID}, imap.NewFlagSet(imap.FlagSeen))
97+
})
9498
if err != nil {
9599
return 0, nil, err
96100
}

0 commit comments

Comments
 (0)