Skip to content

Commit 4ca1afe

Browse files
committed
feat(BRIDGE-464): add the ability to limit the amount of IMAP connections via client
chore
1 parent d2a0b8c commit 4ca1afe

11 files changed

Lines changed: 581 additions & 37 deletions

File tree

builder.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/ProtonMail/gluon/db"
1212
"github.com/ProtonMail/gluon/imap"
1313
"github.com/ProtonMail/gluon/imap/connectioncounter"
14+
"github.com/ProtonMail/gluon/imap/connectionlimiter"
1415
"github.com/ProtonMail/gluon/internal/backend"
1516
"github.com/ProtonMail/gluon/internal/db_impl/sqlite3"
1617
"github.com/ProtonMail/gluon/internal/session"
@@ -46,6 +47,7 @@ type serverBuilder struct {
4647
observabilitySender observability.Sender
4748
featureFlagProvider unleash.FeatureFlagValueProvider
4849
connectionRollingCounter *connectioncounter.RollingCounter
50+
connectionLimiter connectionlimiter.ConnectionLimiter
4951
}
5052

5153
func newBuilder() (*serverBuilder, error) {
@@ -139,6 +141,7 @@ func (builder *serverBuilder) build() (*Server, error) {
139141
observabilitySender: builder.observabilitySender,
140142
connectionRollingCounter: builder.connectionRollingCounter,
141143
featureFlagProvider: builder.featureFlagProvider,
144+
clientLimiter: builder.connectionLimiter,
142145
}
143146

144147
return s, nil

imap/connectioncounter/connectioncounter.go

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ type RollingCounter struct {
2020

2121
log *logrus.Entry
2222

23-
newConnectionThreshold int
23+
observabilityConnectionThreshold int
24+
connectionLimitThreshold int
2425

2526
numberOfBuckets int
2627
buckets []int
@@ -35,17 +36,19 @@ type RollingCounter struct {
3536
connProvider openConnectionProvider
3637
}
3738

38-
func NewRollingCounter(newConnectionTreshold, numberOfBuckets int, bucketRotationInterval time.Duration) *RollingCounter {
39+
func NewRollingCounter(connectionLimitThreshold, observabilityConnectionThreshold, numberOfBuckets int, bucketRotationInterval time.Duration) *RollingCounter {
3940
log := logrus.WithFields(logrus.Fields{
40-
"pkg": "gluon/rollingcounter",
41-
"threshold": newConnectionTreshold,
41+
"pkg": "gluon/rollingcounter",
42+
"connectionLimitThreshold": connectionLimitThreshold,
43+
"observabilityConnectionThreshold": observabilityConnectionThreshold,
4244
})
4345

4446
rc := &RollingCounter{
45-
newConnectionThreshold: newConnectionTreshold,
46-
numberOfBuckets: numberOfBuckets,
47-
bucketRotationInterval: bucketRotationInterval,
48-
log: log,
47+
observabilityConnectionThreshold: observabilityConnectionThreshold,
48+
connectionLimitThreshold: connectionLimitThreshold,
49+
numberOfBuckets: numberOfBuckets,
50+
bucketRotationInterval: bucketRotationInterval,
51+
log: log,
4952
}
5053

5154
return rc
@@ -74,20 +77,18 @@ func (rc *RollingCounter) Start(ctx context.Context, obsSender observability.Sen
7477
}
7578

7679
func (rc *RollingCounter) run() {
77-
rc.wg.Add(1)
78-
go func() {
79-
defer rc.wg.Done()
80+
rc.wg.Go(func() {
8081
for {
8182
select {
8283
case <-rc.ctx.Done():
8384
return
8485

8586
case <-rc.bucketRotationTicker.C:
86-
rc.thresholdCheck()
87+
rc.observabilityThresholdCheck()
8788
rc.onBucketRotationTick()
8889
}
8990
}
90-
}()
91+
})
9192
}
9293

9394
func (rc *RollingCounter) Stop() {
@@ -105,9 +106,10 @@ func (rc *RollingCounter) withBucketLock(fn func()) {
105106
fn()
106107
}
107108

108-
func (rc *RollingCounter) thresholdCheck() {
109+
func (rc *RollingCounter) observabilityThresholdCheck() {
109110
rollingCount := rc.GetRollingCount()
110-
if rollingCount < rc.newConnectionThreshold {
111+
112+
if rollingCount < rc.observabilityConnectionThreshold {
111113
return
112114
}
113115

@@ -135,9 +137,24 @@ func (rc *RollingCounter) NewConnection() {
135137
}
136138

137139
func (rc *RollingCounter) GetRollingCount() int {
140+
return rc.getRollingCounterSafe()
141+
}
142+
143+
func (rc *RollingCounter) OverConnectionLimitThreshold() bool {
138144
rc.bucketLock.Lock()
139145
defer rc.bucketLock.Unlock()
140146

147+
return rc.getRollingCountUnsafe() >= rc.connectionLimitThreshold
148+
}
149+
150+
func (rc *RollingCounter) getRollingCounterSafe() int {
151+
rc.bucketLock.Lock()
152+
defer rc.bucketLock.Unlock()
153+
154+
return rc.getRollingCountUnsafe()
155+
}
156+
157+
func (rc *RollingCounter) getRollingCountUnsafe() int {
141158
var rollingCount int
142159
for _, count := range rc.buckets {
143160
rollingCount += count

imap/connectioncounter/connectioncounter_test.go

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package connectioncounter_test
22

33
import (
4-
"context"
54
"sync"
65
"testing"
76
"time"
@@ -24,9 +23,9 @@ type mockObsSender struct {
2423
lastOpenConns, lastNewlyOpenedConns int
2524
}
2625

27-
func (m *mockObsSender) AddMetrics(_ ...map[string]interface{}) {}
26+
func (m *mockObsSender) AddMetrics(_ ...map[string]any) {}
2827

29-
func (m *mockObsSender) AddDistinctMetrics(_ interface{}, _ ...map[string]interface{}) {}
28+
func (m *mockObsSender) AddDistinctMetrics(_ any, _ ...map[string]any) {}
3029

3130
func (m *mockObsSender) AddIMAPConnectionsExceededThresholdMetric(openConns, newlyOpenedConns int) {
3231
m.mu.Lock()
@@ -48,18 +47,20 @@ func (m *mockObsSender) LastValues() (int, int) {
4847
return m.lastOpenConns, m.lastNewlyOpenedConns
4948
}
5049

51-
func TestRollingCounter_ThresholdNotExceeded(t *testing.T) {
50+
func TestRollingCounter_ObservabilityThresholdNotExceeded(t *testing.T) {
51+
observabilityThreshold := 5
52+
connectionLimitThreshold := 5
5253
rc := connectioncounter.NewRollingCounter(
53-
5,
54+
connectionLimitThreshold,
55+
observabilityThreshold,
5456
3,
5557
100*time.Millisecond,
5658
)
5759

5860
mockSender := &mockObsSender{}
5961
mockProvider := &mockConnProvider{openSessions: 10}
6062

61-
ctx, cancel := context.WithCancel(context.Background())
62-
defer cancel()
63+
ctx := t.Context()
6364

6465
rc.Start(ctx, mockSender, mockProvider)
6566

@@ -76,24 +77,25 @@ func TestRollingCounter_ThresholdNotExceeded(t *testing.T) {
7677
rc.Stop()
7778
}
7879

79-
func TestRollingCounter_ThresholdExceeded(t *testing.T) {
80-
newConnThreshold := 3
80+
func TestRollingCounter_ObservabilityThresholdExceeded(t *testing.T) {
81+
observabilityThreshold := 3
82+
connectionLimitThreshold := 5
8183
rc := connectioncounter.NewRollingCounter(
82-
newConnThreshold,
84+
connectionLimitThreshold,
85+
observabilityThreshold,
8386
3,
8487
100*time.Millisecond,
8588
)
8689

8790
mockSender := &mockObsSender{}
8891
mockProvider := &mockConnProvider{openSessions: 7}
8992

90-
ctx, cancel := context.WithCancel(context.Background())
91-
defer cancel()
93+
ctx := t.Context()
9294

9395
rc.Start(ctx, mockSender, mockProvider)
9496

95-
newConnsOpened := newConnThreshold * 5
96-
for i := 0; i < newConnsOpened; i++ {
97+
newConnsOpened := observabilityThreshold * 5
98+
for range newConnsOpened {
9799
rc.NewConnection()
98100
}
99101

@@ -108,23 +110,50 @@ func TestRollingCounter_ThresholdExceeded(t *testing.T) {
108110
rc.Stop()
109111
}
110112

113+
func TestRollingCounter_ConnectionLimitThresholdExceeded(t *testing.T) {
114+
observabilityThreshold := 3
115+
connectionLimitThreshold := 3
116+
rc := connectioncounter.NewRollingCounter(
117+
connectionLimitThreshold,
118+
observabilityThreshold,
119+
3,
120+
100*time.Millisecond,
121+
)
122+
123+
mockSender := &mockObsSender{}
124+
mockProvider := &mockConnProvider{openSessions: 7}
125+
126+
ctx := t.Context()
127+
128+
rc.Start(ctx, mockSender, mockProvider)
129+
130+
newConnsOpened := observabilityThreshold * 5
131+
for range newConnsOpened {
132+
rc.NewConnection()
133+
}
134+
assert.Equal(t, rc.GetRollingCount(), newConnsOpened)
135+
assert.True(t, rc.OverConnectionLimitThreshold())
136+
137+
rc.Stop()
138+
}
139+
111140
func TestRollingCounter_BucketRotation(t *testing.T) {
112141
jitterPeriod := 50 * time.Millisecond
113142
bucketRotationInterval := 500 * time.Millisecond
114143
rc := connectioncounter.NewRollingCounter(
144+
100,
115145
100,
116146
3,
117147
bucketRotationInterval,
118148
)
119149

120150
post10Connections := func() {
121-
for i := 0; i < 10; i++ {
151+
for range 10 {
122152
rc.NewConnection()
123153
}
124154
}
125155

126-
ctx, cancel := context.WithCancel(context.Background())
127-
defer cancel()
156+
ctx := t.Context()
128157

129158
mockSender := &mockObsSender{}
130159
mockProvider := &mockConnProvider{}

imap/connectionlimiter/client.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package connectionlimiter
2+
3+
import (
4+
"strings"
5+
6+
"github.com/ProtonMail/gluon/imap"
7+
)
8+
9+
type Client string
10+
11+
const (
12+
ClientAppleMail Client = "apple-mail"
13+
ClientOutlook Client = "outlook"
14+
ClientThunderbird Client = "thunderbird"
15+
ClientUnknown Client = "unknown"
16+
)
17+
18+
func normalizeClientKey(id imap.IMAPID) Client {
19+
name := strings.TrimSpace(strings.ToLower(id.Name))
20+
switch {
21+
case strings.Contains(name, "outlook"):
22+
return ClientOutlook
23+
case strings.Contains(name, "thunderbird"):
24+
return ClientThunderbird
25+
case strings.Contains(name, "mac") && strings.Contains(name, "mail"):
26+
return ClientAppleMail
27+
case strings.Contains(name, "mac") && strings.Contains(name, "notes"):
28+
return ClientUnknown
29+
default:
30+
return ClientUnknown
31+
}
32+
}

0 commit comments

Comments
 (0)