From d4155e7bd6860f150be3e6b044eeb804459d01b5 Mon Sep 17 00:00:00 2001 From: Sebastijan Zindl Date: Tue, 10 Mar 2026 14:58:27 +0100 Subject: [PATCH] feat(BRIDGE-464): add the ability to limit the amount of IMAP connections via client --- builder.go | 3 + imap/connectioncounter/connectioncounter.go | 47 +++-- .../connectioncounter_test.go | 63 +++++-- imap/connectionlimiter/client.go | 32 ++++ imap/connectionlimiter/limiter.go | 166 ++++++++++++++++++ imap/connectionlimiter/limiter_test.go | 157 +++++++++++++++++ imap/connectionlimiter/limits.go | 45 +++++ internal/session/session.go | 17 ++ internal/unleash/featureflags/flags.go | 7 +- option.go | 20 ++- server.go | 69 +++++++- 11 files changed, 589 insertions(+), 37 deletions(-) create mode 100644 imap/connectionlimiter/client.go create mode 100644 imap/connectionlimiter/limiter.go create mode 100644 imap/connectionlimiter/limiter_test.go create mode 100644 imap/connectionlimiter/limits.go diff --git a/builder.go b/builder.go index 4ba0d139..4e02f1c8 100644 --- a/builder.go +++ b/builder.go @@ -11,6 +11,7 @@ import ( "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/connectioncounter" + "github.com/ProtonMail/gluon/imap/connectionlimiter" "github.com/ProtonMail/gluon/internal/backend" "github.com/ProtonMail/gluon/internal/db_impl/sqlite3" "github.com/ProtonMail/gluon/internal/session" @@ -46,6 +47,7 @@ type serverBuilder struct { observabilitySender observability.Sender featureFlagProvider unleash.FeatureFlagValueProvider connectionRollingCounter *connectioncounter.RollingCounter + connectionLimiter connectionlimiter.ConnectionLimiter } func newBuilder() (*serverBuilder, error) { @@ -139,6 +141,7 @@ func (builder *serverBuilder) build() (*Server, error) { observabilitySender: builder.observabilitySender, connectionRollingCounter: builder.connectionRollingCounter, featureFlagProvider: builder.featureFlagProvider, + connectionLimiter: builder.connectionLimiter, } return s, nil diff --git a/imap/connectioncounter/connectioncounter.go b/imap/connectioncounter/connectioncounter.go index 34968b1c..6ebab7cb 100644 --- a/imap/connectioncounter/connectioncounter.go +++ b/imap/connectioncounter/connectioncounter.go @@ -20,7 +20,8 @@ type RollingCounter struct { log *logrus.Entry - newConnectionThreshold int + observabilityConnectionThreshold int + connectionLimitThreshold int numberOfBuckets int buckets []int @@ -35,17 +36,19 @@ type RollingCounter struct { connProvider openConnectionProvider } -func NewRollingCounter(newConnectionTreshold, numberOfBuckets int, bucketRotationInterval time.Duration) *RollingCounter { +func NewRollingCounter(connectionLimitThreshold, observabilityConnectionThreshold, numberOfBuckets int, bucketRotationInterval time.Duration) *RollingCounter { log := logrus.WithFields(logrus.Fields{ - "pkg": "gluon/rollingcounter", - "threshold": newConnectionTreshold, + "pkg": "gluon/rollingcounter", + "connectionLimitThreshold": connectionLimitThreshold, + "observabilityConnectionThreshold": observabilityConnectionThreshold, }) rc := &RollingCounter{ - newConnectionThreshold: newConnectionTreshold, - numberOfBuckets: numberOfBuckets, - bucketRotationInterval: bucketRotationInterval, - log: log, + observabilityConnectionThreshold: observabilityConnectionThreshold, + connectionLimitThreshold: connectionLimitThreshold, + numberOfBuckets: numberOfBuckets, + bucketRotationInterval: bucketRotationInterval, + log: log, } return rc @@ -74,20 +77,18 @@ func (rc *RollingCounter) Start(ctx context.Context, obsSender observability.Sen } func (rc *RollingCounter) run() { - rc.wg.Add(1) - go func() { - defer rc.wg.Done() + rc.wg.Go(func() { for { select { case <-rc.ctx.Done(): return case <-rc.bucketRotationTicker.C: - rc.thresholdCheck() + rc.observabilityThresholdCheck() rc.onBucketRotationTick() } } - }() + }) } func (rc *RollingCounter) Stop() { @@ -105,9 +106,10 @@ func (rc *RollingCounter) withBucketLock(fn func()) { fn() } -func (rc *RollingCounter) thresholdCheck() { +func (rc *RollingCounter) observabilityThresholdCheck() { rollingCount := rc.GetRollingCount() - if rollingCount < rc.newConnectionThreshold { + + if rollingCount < rc.observabilityConnectionThreshold { return } @@ -135,9 +137,24 @@ func (rc *RollingCounter) NewConnection() { } func (rc *RollingCounter) GetRollingCount() int { + return rc.getRollingCounterSafe() +} + +func (rc *RollingCounter) OverConnectionLimitThreshold() bool { rc.bucketLock.Lock() defer rc.bucketLock.Unlock() + return rc.getRollingCountUnsafe() >= rc.connectionLimitThreshold +} + +func (rc *RollingCounter) getRollingCounterSafe() int { + rc.bucketLock.Lock() + defer rc.bucketLock.Unlock() + + return rc.getRollingCountUnsafe() +} + +func (rc *RollingCounter) getRollingCountUnsafe() int { var rollingCount int for _, count := range rc.buckets { rollingCount += count diff --git a/imap/connectioncounter/connectioncounter_test.go b/imap/connectioncounter/connectioncounter_test.go index 6e139bf6..3276e9ce 100644 --- a/imap/connectioncounter/connectioncounter_test.go +++ b/imap/connectioncounter/connectioncounter_test.go @@ -1,7 +1,6 @@ package connectioncounter_test import ( - "context" "sync" "testing" "time" @@ -24,9 +23,9 @@ type mockObsSender struct { lastOpenConns, lastNewlyOpenedConns int } -func (m *mockObsSender) AddMetrics(_ ...map[string]interface{}) {} +func (m *mockObsSender) AddMetrics(_ ...map[string]any) {} -func (m *mockObsSender) AddDistinctMetrics(_ interface{}, _ ...map[string]interface{}) {} +func (m *mockObsSender) AddDistinctMetrics(_ any, _ ...map[string]any) {} func (m *mockObsSender) AddIMAPConnectionsExceededThresholdMetric(openConns, newlyOpenedConns int) { m.mu.Lock() @@ -48,9 +47,12 @@ func (m *mockObsSender) LastValues() (int, int) { return m.lastOpenConns, m.lastNewlyOpenedConns } -func TestRollingCounter_ThresholdNotExceeded(t *testing.T) { +func TestRollingCounter_ObservabilityThresholdNotExceeded(t *testing.T) { + observabilityThreshold := 5 + connectionLimitThreshold := 5 rc := connectioncounter.NewRollingCounter( - 5, + connectionLimitThreshold, + observabilityThreshold, 3, 100*time.Millisecond, ) @@ -58,8 +60,7 @@ func TestRollingCounter_ThresholdNotExceeded(t *testing.T) { mockSender := &mockObsSender{} mockProvider := &mockConnProvider{openSessions: 10} - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() rc.Start(ctx, mockSender, mockProvider) @@ -76,10 +77,12 @@ func TestRollingCounter_ThresholdNotExceeded(t *testing.T) { rc.Stop() } -func TestRollingCounter_ThresholdExceeded(t *testing.T) { - newConnThreshold := 3 +func TestRollingCounter_ObservabilityThresholdExceeded(t *testing.T) { + observabilityThreshold := 3 + connectionLimitThreshold := 5 rc := connectioncounter.NewRollingCounter( - newConnThreshold, + connectionLimitThreshold, + observabilityThreshold, 3, 100*time.Millisecond, ) @@ -87,13 +90,12 @@ func TestRollingCounter_ThresholdExceeded(t *testing.T) { mockSender := &mockObsSender{} mockProvider := &mockConnProvider{openSessions: 7} - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() rc.Start(ctx, mockSender, mockProvider) - newConnsOpened := newConnThreshold * 5 - for i := 0; i < newConnsOpened; i++ { + newConnsOpened := observabilityThreshold * 5 + for range newConnsOpened { rc.NewConnection() } @@ -108,23 +110,50 @@ func TestRollingCounter_ThresholdExceeded(t *testing.T) { rc.Stop() } +func TestRollingCounter_ConnectionLimitThresholdExceeded(t *testing.T) { + observabilityThreshold := 3 + connectionLimitThreshold := 3 + rc := connectioncounter.NewRollingCounter( + connectionLimitThreshold, + observabilityThreshold, + 3, + 100*time.Millisecond, + ) + + mockSender := &mockObsSender{} + mockProvider := &mockConnProvider{openSessions: 7} + + ctx := t.Context() + + rc.Start(ctx, mockSender, mockProvider) + + newConnsOpened := observabilityThreshold * 5 + for range newConnsOpened { + rc.NewConnection() + } + assert.Equal(t, rc.GetRollingCount(), newConnsOpened) + assert.True(t, rc.OverConnectionLimitThreshold()) + + rc.Stop() +} + func TestRollingCounter_BucketRotation(t *testing.T) { jitterPeriod := 50 * time.Millisecond bucketRotationInterval := 500 * time.Millisecond rc := connectioncounter.NewRollingCounter( + 100, 100, 3, bucketRotationInterval, ) post10Connections := func() { - for i := 0; i < 10; i++ { + for range 10 { rc.NewConnection() } } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() mockSender := &mockObsSender{} mockProvider := &mockConnProvider{} diff --git a/imap/connectionlimiter/client.go b/imap/connectionlimiter/client.go new file mode 100644 index 00000000..8581d2f1 --- /dev/null +++ b/imap/connectionlimiter/client.go @@ -0,0 +1,32 @@ +package connectionlimiter + +import ( + "strings" + + "github.com/ProtonMail/gluon/imap" +) + +type Client string + +const ( + ClientAppleMail Client = "apple-mail" + ClientOutlook Client = "outlook" + ClientThunderbird Client = "thunderbird" + ClientUnknown Client = "unknown" +) + +func normalizeClientKey(id imap.IMAPID) Client { + name := strings.TrimSpace(strings.ToLower(id.Name)) + switch { + case strings.Contains(name, "outlook"): + return ClientOutlook + case strings.Contains(name, "thunderbird"): + return ClientThunderbird + case strings.Contains(name, "mac") && strings.Contains(name, "mail"): + return ClientAppleMail + case strings.Contains(name, "mac") && strings.Contains(name, "notes"): + return ClientUnknown + default: + return ClientUnknown + } +} diff --git a/imap/connectionlimiter/limiter.go b/imap/connectionlimiter/limiter.go new file mode 100644 index 00000000..9ff247f1 --- /dev/null +++ b/imap/connectionlimiter/limiter.go @@ -0,0 +1,166 @@ +package connectionlimiter + +import ( + "sync" + + "github.com/ProtonMail/gluon/imap" + "github.com/sirupsen/logrus" +) + +type ConnectionLimiter interface { + //TryBind tries to bind a given sessionID to a client. + TryBind(sessionID int, id imap.IMAPID, useFallback bool) (allowed bool, key Client, current int, max int) + + //Unbind unbinds a given sessionID from a client. + Unbind(sessionID int) +} + +type limiter struct { + mu sync.Mutex + limits Limits + fallbackLimits Limits + + //sessionID mapping to normalized client key + sessionClient map[int]Client + + //normalized key mapping to current open sessions + clientCount map[Client]int + + log *logrus.Entry +} + +func NewConnectionLimiter(limits, fallbackLimits Limits) ConnectionLimiter { + return newLimiter(limits, fallbackLimits) +} + +func newLimiter(limits, fallbackLimits Limits) *limiter { + log := logrus.WithFields(logrus.Fields{ + "pkg": "gluon/connectionlimiter", + "limits": limits, + }) + + return &limiter{ + limits: limits, + fallbackLimits: fallbackLimits, + sessionClient: make(map[int]Client), + clientCount: make(map[Client]int), + log: log, + } +} + +func (l *limiter) TryBind(sessionID int, id imap.IMAPID, useFallback bool) (allowed bool, key Client, current int, max int) { + l.mu.Lock() + defer l.mu.Unlock() + + key = normalizeClientKey(id) + if useFallback { + l.log.WithField("fallbackLimits", l.fallbackLimits).Debug("Using fallback limits") + } + + // already bound to this client, no-op allow + if prev, ok := l.sessionClient[sessionID]; ok && prev == key { + maxUsages := l.maxForKey(key, useFallback) + + l.log.WithFields(logrus.Fields{ + "sessionID": sessionID, + "client": key, + "current": l.clientCount[key], + "max": maxUsages, + }).Info("Already bound to this client, no-op allow") + + return true, key, l.clientCount[key], maxUsages + } + + // if rebind, release the old key first + if prev, ok := l.sessionClient[sessionID]; ok { + if c := l.clientCount[prev]; c > 0 { + l.log.WithFields(logrus.Fields{ + "sessionID": sessionID, + "client": prev, + "current": c, + }).Info("Releasing old client") + + l.clientCount[prev] = c - 1 + } + + } + + max = l.maxForKey(key, useFallback) + cur := l.clientCount[key] + + if max > 0 && cur >= max { + delete(l.sessionClient, sessionID) + + return false, key, cur, max + } + + l.clientCount[key] = cur + 1 + l.sessionClient[sessionID] = key + + l.log.WithFields(logrus.Fields{ + "sessionID": sessionID, + "client": key, + "current": l.clientCount[key], + "max": max, + }).Debug("Binding session to client") + + return true, key, l.clientCount[key], max +} + +func (l *limiter) Unbind(sessionID int) { + l.mu.Lock() + defer l.mu.Unlock() + + key, ok := l.sessionClient[sessionID] + if !ok { + return + } + delete(l.sessionClient, sessionID) + + if c := l.clientCount[key]; c > 1 { + l.clientCount[key] = c - 1 + + l.log.WithFields(logrus.Fields{ + "sessionID": sessionID, + "client": key, + "current": l.clientCount[key], + }).Debug("Unbinding session from client") + + } else { + delete(l.clientCount, key) + + l.log.WithFields(logrus.Fields{ + "sessionID": sessionID, + "client": key, + }).Debug("Unbinding session from client") + } +} + +// maxForKey returns the maximum allowed current connections for the client +// If the key is not found in the limits or the fallbackLimits, it will use the unknownClientLimit value, +// which is set separately. +func (l *limiter) maxForKey(key Client, useFallback bool) int { + if useFallback { + if limit, ok := l.fallbackLimits.PerClient[key]; ok { + return limit + } + + l.log.WithFields(logrus.Fields{ + "client": key, + "limit": l.fallbackLimits.UnknownLimit, + }).Debug("Client key not found in fallbackLimits, using unknown client limit") + + return l.fallbackLimits.UnknownLimit + } else { + if limit, ok := l.limits.PerClient[key]; ok { + return limit + } + l.log.WithFields(logrus.Fields{ + "client": key, + "limit": l.limits.UnknownLimit, + }).Debug("Client key not found in limits, using unknown client limit") + + return l.limits.UnknownLimit + } + +} diff --git a/imap/connectionlimiter/limiter_test.go b/imap/connectionlimiter/limiter_test.go new file mode 100644 index 00000000..640c2443 --- /dev/null +++ b/imap/connectionlimiter/limiter_test.go @@ -0,0 +1,157 @@ +package connectionlimiter + +import ( + "testing" + + "github.com/ProtonMail/gluon/imap" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func imapID(name string) imap.IMAPID { + return imap.IMAPID{Name: name} +} + +func TestTryBind_FirstConnectionAllowed(t *testing.T) { + limits := NewLimits(map[Client]int{ClientAppleMail: 3}, 1) + fallbackLimits := NewDefaultLimits() + useFallback := false + + l := NewConnectionLimiter(limits, fallbackLimits) + allowed, key, current, max := l.TryBind(1, imapID("MacOS X Mail"), useFallback) + require.True(t, allowed) + assert.Equal(t, ClientAppleMail, key) + assert.Equal(t, 1, current) + assert.Equal(t, 3, max) +} + +func TestTryBind_OverLimitDenied(t *testing.T) { + limits := NewLimits(map[Client]int{ClientAppleMail: 2}, 1) + fallbackLimits := NewDefaultLimits() + useFallback := false + + l := NewConnectionLimiter(limits, fallbackLimits) + l.TryBind(1, imapID("MacOS X Mail"), useFallback) + l.TryBind(2, imapID("MacOS X Mail"), useFallback) + allowed, key, current, max := l.TryBind(3, imapID("MacOS X Mail"), useFallback) + require.False(t, allowed) + assert.Equal(t, ClientAppleMail, key) + assert.Equal(t, 2, current) + assert.Equal(t, 2, max) +} + +func TestUnbind_FreesSlot(t *testing.T) { + limits := NewLimits(map[Client]int{ClientAppleMail: 2}, 1) + fallbackLimits := NewDefaultLimits() + useFallback := false + + l := NewConnectionLimiter(limits, fallbackLimits) + l.TryBind(1, imapID("MacOS X Mail"), useFallback) + l.TryBind(2, imapID("MacOS X Mail"), useFallback) + l.Unbind(1) + allowed, _, current, _ := l.TryBind(3, imapID("MacOS X Mail"), useFallback) + require.True(t, allowed) + assert.Equal(t, 2, current) +} + +func TestUnbind_UnknownSessionNoop(t *testing.T) { + limits := NewLimits(map[Client]int{ClientAppleMail: 2}, 1) + fallbackLimits := NewDefaultLimits() + useFallback := false + + l := NewConnectionLimiter(limits, fallbackLimits) + l.TryBind(1, imapID("MacOS X Mail"), useFallback) + l.Unbind(999) // never bound + _, _, current, _ := l.TryBind(2, imapID("MacOS X Mail"), useFallback) + assert.Equal(t, 2, current) +} + +func TestTryBind_SameSessionSameClientNoop(t *testing.T) { + limits := NewLimits(map[Client]int{ClientAppleMail: 2}, 1) + fallbackLimits := NewDefaultLimits() + useFallback := false + + l := NewConnectionLimiter(limits, fallbackLimits) + l.TryBind(1, imapID("MacOS X Mail"), useFallback) + allowed, key, current, max := l.TryBind(1, imapID("MacOS X Mail"), useFallback) + require.True(t, allowed) + assert.Equal(t, ClientAppleMail, key) + assert.Equal(t, 1, current) + assert.Equal(t, 2, max) +} + +func TestTryBind_RebindToDifferentClient(t *testing.T) { + limits := NewLimits(map[Client]int{ + ClientAppleMail: 2, + ClientOutlook: 2, + }, 1) + fallbackLimits := NewDefaultLimits() + useFallback := false + + l := NewConnectionLimiter(limits, fallbackLimits) + l.TryBind(1, imapID("MacOS X Mail"), useFallback) + allowed, key, cur, _ := l.TryBind(1, imapID("Microsoft Outlook"), useFallback) + require.True(t, allowed) + assert.Equal(t, ClientOutlook, key) + assert.Equal(t, 1, cur) + _, _, appleCur, _ := l.TryBind(2, imapID("MacOS X Mail"), useFallback) + assert.Equal(t, 1, appleCur) +} + +func TestTryBind_UnknownClientUsesUnknownLimit(t *testing.T) { + limits := NewLimits(map[Client]int{ClientAppleMail: 10}, 2) + fallbackLimits := NewDefaultLimits() + useFallback := false + + l := NewConnectionLimiter(limits, fallbackLimits) + l.TryBind(1, imapID("SomeOtherClient"), useFallback) + l.TryBind(2, imapID("Unknown"), useFallback) + allowed, key, current, max := l.TryBind(3, imapID("Custom"), useFallback) + require.False(t, allowed) + assert.Equal(t, ClientUnknown, key) + assert.Equal(t, 2, current) + assert.Equal(t, 2, max) +} + +func TestTryBind_UnlimitedLimit(t *testing.T) { + limits := NewLimits(map[Client]int{ClientAppleMail: 0}, 1) + fallbackLimits := NewDefaultLimits() + useFallback := false + + l := NewConnectionLimiter(limits, fallbackLimits) + for i := 1; i <= 5; i++ { + allowed, key, current, max := l.TryBind(i, imapID("MacOS X Mail"), useFallback) + require.True(t, allowed) + assert.Equal(t, ClientAppleMail, key) + assert.Equal(t, i, current) + assert.Equal(t, 0, max) + } +} + +func TestTryBind_UseFallbackLimits(t *testing.T) { + limits := NewLimits(map[Client]int{ClientAppleMail: 1}, 1) + fallbackLimits := NewDefaultLimits() + + useFallback := true + + l := NewConnectionLimiter(limits, fallbackLimits) + for i := 1; i <= 5; i++ { + allowed, key, current, max := l.TryBind(i, imapID("Mac OS X Mail"), useFallback) + require.True(t, allowed) + assert.Equal(t, ClientAppleMail, key) + assert.Equal(t, i, current) + assert.Equal(t, fallbackLimits.PerClient[ClientAppleMail], max) + } +} + +func TestNormalizeClientKey_AppleMail(t *testing.T) { + assert.Equal(t, ClientAppleMail, normalizeClientKey(imapID("MacOS X Mail"))) +} +func TestNormalizeClientKey_OutlookAndThunderbird(t *testing.T) { + assert.Equal(t, ClientOutlook, normalizeClientKey(imapID("Microsoft Outlook"))) + assert.Equal(t, ClientThunderbird, normalizeClientKey(imapID("Thunderbird"))) +} +func TestNormalizeClientKey_Unknown(t *testing.T) { + assert.Equal(t, ClientUnknown, normalizeClientKey(imapID(""))) + assert.Equal(t, ClientUnknown, normalizeClientKey(imapID("Custom Client"))) +} diff --git a/imap/connectionlimiter/limits.go b/imap/connectionlimiter/limits.go new file mode 100644 index 00000000..4d757fc9 --- /dev/null +++ b/imap/connectionlimiter/limits.go @@ -0,0 +1,45 @@ +package connectionlimiter + +const ( + defaultClientLimit = 25 + unlimitedClientLimit = 0 +) + +type Limits struct { + // Normalized client name with max open sessions. + // If we want unlimited connections for a client set the limit to 0. + PerClient map[Client]int `json:"per_client"` + + // Max open sessions for unknown clients. + // If we want unlimited connections for unknown clients set the limit to 0. + UnknownLimit int `json:"unknown_limit"` +} + +func NewDefaultLimits() Limits { + return Limits{ + PerClient: map[Client]int{ + ClientAppleMail: defaultClientLimit, + ClientOutlook: defaultClientLimit, + ClientThunderbird: defaultClientLimit, + }, + UnknownLimit: defaultClientLimit, + } +} + +func NewDefaultFallbackValues() Limits { + return Limits{ + PerClient: map[Client]int{ + ClientAppleMail: unlimitedClientLimit, + ClientOutlook: unlimitedClientLimit, + ClientThunderbird: unlimitedClientLimit, + }, + UnknownLimit: unlimitedClientLimit, + } +} + +func NewLimits(perClient map[Client]int, unknownLimit int) Limits { + return Limits{ + PerClient: perClient, + UnknownLimit: unknownLimit, + } +} diff --git a/internal/session/session.go b/internal/session/session.go index bae31b63..25e7afb3 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -98,6 +98,8 @@ type Session struct { log *logrus.Entry featureFlagProvider unleash.FeatureFlagValueProvider + + closeOnce sync.Once } func New( @@ -355,3 +357,18 @@ func (s *Session) decodeMailboxName(name string) (string, error) { return utf7.Encoding.NewDecoder().String(fmt.Sprintf("INBOX%v%v", delimiter, split[1])) } + +func (s *Session) CloseWithBye(reason string) error { + var retErr error + s.closeOnce.Do(func() { + if reason == "" { + reason = "Connection closed by server" + } + + _ = response.Bye().WithMessage(reason).Send(s) + if err := s.conn.Close(); err != nil { + retErr = err + } + }) + return retErr +} diff --git a/internal/unleash/featureflags/flags.go b/internal/unleash/featureflags/flags.go index a2632eaf..0e723acf 100644 --- a/internal/unleash/featureflags/flags.go +++ b/internal/unleash/featureflags/flags.go @@ -1,6 +1,9 @@ package featureflags const ( - CommandWatcherGlobalDisabled = "InboxBridgeGenericImapOkHeartbeatDisabled" - CommandWatcherNonThunderbirdDisabled = "InboxBridgeGenericImapOkHeartbeatNonThunderbirdDisabled" + CommandWatcherGlobalDisabled = "InboxBridgeGenericImapOkHeartbeatDisabled" + CommandWatcherNonThunderbirdDisabled = "InboxBridgeGenericImapOkHeartbeatNonThunderbirdDisabled" + ConnectionLimiterDisabled = "InboxBridgeGluonConnectionLimiterDisabled" + ConnectionLimiterDefaultLimitsDisabled = "InboxBridgeGluonConnectionLimiterDefaultLimitsDisabled" + ConnectionCounterConnectionsLimitDisabled = "InboxBridgeGluonRollingCounterConnectionLimitDisabled" ) diff --git a/option.go b/option.go index e048b07a..0b667b71 100644 --- a/option.go +++ b/option.go @@ -9,6 +9,7 @@ import ( "github.com/ProtonMail/gluon/db" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/connectioncounter" + "github.com/ProtonMail/gluon/imap/connectionlimiter" "github.com/ProtonMail/gluon/internal/unleash" limits2 "github.com/ProtonMail/gluon/limits" "github.com/ProtonMail/gluon/observability" @@ -278,10 +279,11 @@ func (w withConnectionRollingCounter) config(builder *serverBuilder) { builder.connectionRollingCounter = w.rollingConnectionCounter } -func WithConnectionRollingCounter(newConnectionTreshold, numberOfBuckets int, thresholdCheckInterval time.Duration) Option { +func WithConnectionRollingCounter(connectionLimitThreshold, observabilityThreshold, numberOfBuckets int, thresholdCheckInterval time.Duration) Option { return &withConnectionRollingCounter{ rollingConnectionCounter: connectioncounter.NewRollingCounter( - newConnectionTreshold, + connectionLimitThreshold, + observabilityThreshold, numberOfBuckets, thresholdCheckInterval, )} @@ -300,3 +302,17 @@ func WithFeatureFlagProvider(featureFlagProvider unleash.FeatureFlagValueProvide featureFlagProvider: featureFlagProvider, } } + +type withConnectionLimiter struct { + limiter connectionlimiter.ConnectionLimiter +} + +func (w withConnectionLimiter) config(builder *serverBuilder) { + builder.connectionLimiter = w.limiter +} + +func WithConnectionLimiter(limits, fallbackLimits connectionlimiter.Limits) Option { + return &withConnectionLimiter{ + limiter: connectionlimiter.NewConnectionLimiter(limits, fallbackLimits), + } +} diff --git a/server.go b/server.go index d1c1df4c..a742f163 100644 --- a/server.go +++ b/server.go @@ -16,10 +16,12 @@ import ( "github.com/ProtonMail/gluon/events" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/imap/connectioncounter" + "github.com/ProtonMail/gluon/imap/connectionlimiter" "github.com/ProtonMail/gluon/internal/backend" "github.com/ProtonMail/gluon/internal/contexts" "github.com/ProtonMail/gluon/internal/session" "github.com/ProtonMail/gluon/internal/unleash" + "github.com/ProtonMail/gluon/internal/unleash/featureflags" "github.com/ProtonMail/gluon/logging" "github.com/ProtonMail/gluon/observability" "github.com/ProtonMail/gluon/profiling" @@ -100,6 +102,8 @@ type Server struct { connectionRollingCounter *connectioncounter.RollingCounter featureFlagProvider unleash.FeatureFlagValueProvider + + connectionLimiter connectionlimiter.ConnectionLimiter } // New creates a new server with the given options. @@ -224,6 +228,10 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) error { func (s *Server) serve(ctx context.Context, connCh <-chan net.Conn) { connWG := async.MakeWaitGroup(s.panicHandler) + if s.connectionLimiter != nil { + s.useConnectionLimiter(ctx, &connWG) + } + for { select { case <-ctx.Done(): @@ -239,9 +247,17 @@ func (s *Server) serve(ctx context.Context, connCh <-chan net.Conn) { logrus.Debug("Stopping serve, listener closed") return } - defer conn.Close() + disableConnectionCounterConnectionLimit := s.featureFlagProvider.GetFlagValue(featureflags.ConnectionCounterConnectionsLimitDisabled) + if !disableConnectionCounterConnectionLimit { + if s.connectionRollingCounter != nil && s.connectionRollingCounter.OverConnectionLimitThreshold() { + logrus.Debug("Rejecting IMAP session due to rolling connection count threshold") + conn.Close() + continue + } + } + connWG.Go(func() { session, sessionID := s.addSession(ctx, conn) defer s.removeSession(sessionID) @@ -265,6 +281,45 @@ func (s *Server) serve(ctx context.Context, connCh <-chan net.Conn) { } } +func (s *Server) useConnectionLimiter(ctx context.Context, connWG *async.WaitGroup) { + eventCh := s.AddWatcher(events.IMAPID{}, events.SessionRemoved{}) + connWG.Go(func() { + for { + select { + case <-ctx.Done(): + return + case <-s.serveDoneCh: + return + case ev, ok := <-eventCh: + if !ok { + return + } + switch e := ev.(type) { + case events.IMAPID: + connectionLimiterDisabled := s.featureFlagProvider.GetFlagValue(featureflags.ConnectionLimiterDisabled) + + if !connectionLimiterDisabled { + useFallback := s.featureFlagProvider.GetFlagValue(featureflags.ConnectionLimiterDefaultLimitsDisabled) + allowed, key, cur, max := s.connectionLimiter.TryBind(e.SessionID, e.IMAPID, useFallback) + if !allowed { + logrus.WithFields(logrus.Fields{ + "sessionID": e.SessionID, + "client": key, + "current": cur, + "max": max, + }).Warn("Rejecting IMAP session due to client connection limit") + + _ = s.CloseSessionByID(e.SessionID, "Too many connections for this IMAP client") + } + } + case events.SessionRemoved: + s.connectionLimiter.Unbind(e.SessionID) + } + } + } + }) +} + // GetErrorCh returns the error channel. func (s *Server) GetErrorCh() <-chan error { return s.serveErrCh.GetChannel() @@ -440,3 +495,15 @@ func (s *Server) GetRollingIMAPConnectionCount() int { return 0 } + +func (s *Server) CloseSessionByID(sessionID int, reason string) error { + s.sessionsLock.RLock() + sess, ok := s.sessions[sessionID] + s.sessionsLock.RUnlock() + + if !ok { + return fmt.Errorf("no such session: %d", sessionID) + } + + return sess.CloseWithBye(reason) +}