Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/ProtonMail/gluon/db"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/imap/connectioncounter"
"github.com/ProtonMail/gluon/imap/connectionlimiter"
"github.com/ProtonMail/gluon/internal/backend"
"github.com/ProtonMail/gluon/internal/db_impl/sqlite3"
"github.com/ProtonMail/gluon/internal/session"
Expand Down Expand Up @@ -46,6 +47,7 @@ type serverBuilder struct {
observabilitySender observability.Sender
featureFlagProvider unleash.FeatureFlagValueProvider
connectionRollingCounter *connectioncounter.RollingCounter
connectionLimiter connectionlimiter.ConnectionLimiter
}

func newBuilder() (*serverBuilder, error) {
Expand Down Expand Up @@ -139,6 +141,7 @@ func (builder *serverBuilder) build() (*Server, error) {
observabilitySender: builder.observabilitySender,
connectionRollingCounter: builder.connectionRollingCounter,
featureFlagProvider: builder.featureFlagProvider,
connectionLimiter: builder.connectionLimiter,
}

return s, nil
Expand Down
47 changes: 32 additions & 15 deletions imap/connectioncounter/connectioncounter.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ type RollingCounter struct {

log *logrus.Entry

newConnectionThreshold int
observabilityConnectionThreshold int
connectionLimitThreshold int

numberOfBuckets int
buckets []int
Expand All @@ -35,17 +36,19 @@ type RollingCounter struct {
connProvider openConnectionProvider
}

func NewRollingCounter(newConnectionTreshold, numberOfBuckets int, bucketRotationInterval time.Duration) *RollingCounter {
func NewRollingCounter(connectionLimitThreshold, observabilityConnectionThreshold, numberOfBuckets int, bucketRotationInterval time.Duration) *RollingCounter {
log := logrus.WithFields(logrus.Fields{
"pkg": "gluon/rollingcounter",
"threshold": newConnectionTreshold,
"pkg": "gluon/rollingcounter",
"connectionLimitThreshold": connectionLimitThreshold,
"observabilityConnectionThreshold": observabilityConnectionThreshold,
})

rc := &RollingCounter{
newConnectionThreshold: newConnectionTreshold,
numberOfBuckets: numberOfBuckets,
bucketRotationInterval: bucketRotationInterval,
log: log,
observabilityConnectionThreshold: observabilityConnectionThreshold,
connectionLimitThreshold: connectionLimitThreshold,
numberOfBuckets: numberOfBuckets,
bucketRotationInterval: bucketRotationInterval,
log: log,
}

return rc
Expand Down Expand Up @@ -74,20 +77,18 @@ func (rc *RollingCounter) Start(ctx context.Context, obsSender observability.Sen
}

func (rc *RollingCounter) run() {
rc.wg.Add(1)
go func() {
defer rc.wg.Done()
rc.wg.Go(func() {
for {
select {
case <-rc.ctx.Done():
return

case <-rc.bucketRotationTicker.C:
rc.thresholdCheck()
rc.observabilityThresholdCheck()
rc.onBucketRotationTick()
}
}
}()
})
}

func (rc *RollingCounter) Stop() {
Expand All @@ -105,9 +106,10 @@ func (rc *RollingCounter) withBucketLock(fn func()) {
fn()
}

func (rc *RollingCounter) thresholdCheck() {
func (rc *RollingCounter) observabilityThresholdCheck() {
rollingCount := rc.GetRollingCount()
if rollingCount < rc.newConnectionThreshold {

if rollingCount < rc.observabilityConnectionThreshold {
return
}

Expand Down Expand Up @@ -135,9 +137,24 @@ func (rc *RollingCounter) NewConnection() {
}

func (rc *RollingCounter) GetRollingCount() int {
return rc.getRollingCounterSafe()
}

func (rc *RollingCounter) OverConnectionLimitThreshold() bool {
rc.bucketLock.Lock()
defer rc.bucketLock.Unlock()

return rc.getRollingCountUnsafe() >= rc.connectionLimitThreshold
}

func (rc *RollingCounter) getRollingCounterSafe() int {
rc.bucketLock.Lock()
defer rc.bucketLock.Unlock()

return rc.getRollingCountUnsafe()
}

func (rc *RollingCounter) getRollingCountUnsafe() int {
var rollingCount int
for _, count := range rc.buckets {
rollingCount += count
Expand Down
63 changes: 46 additions & 17 deletions imap/connectioncounter/connectioncounter_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package connectioncounter_test

import (
"context"
"sync"
"testing"
"time"
Expand All @@ -24,9 +23,9 @@ type mockObsSender struct {
lastOpenConns, lastNewlyOpenedConns int
}

func (m *mockObsSender) AddMetrics(_ ...map[string]interface{}) {}
func (m *mockObsSender) AddMetrics(_ ...map[string]any) {}

func (m *mockObsSender) AddDistinctMetrics(_ interface{}, _ ...map[string]interface{}) {}
func (m *mockObsSender) AddDistinctMetrics(_ any, _ ...map[string]any) {}

func (m *mockObsSender) AddIMAPConnectionsExceededThresholdMetric(openConns, newlyOpenedConns int) {
m.mu.Lock()
Expand All @@ -48,18 +47,20 @@ func (m *mockObsSender) LastValues() (int, int) {
return m.lastOpenConns, m.lastNewlyOpenedConns
}

Comment thread
ElectroNafta marked this conversation as resolved.
func TestRollingCounter_ThresholdNotExceeded(t *testing.T) {
func TestRollingCounter_ObservabilityThresholdNotExceeded(t *testing.T) {
observabilityThreshold := 5
connectionLimitThreshold := 5
rc := connectioncounter.NewRollingCounter(
5,
connectionLimitThreshold,
observabilityThreshold,
3,
100*time.Millisecond,
)

mockSender := &mockObsSender{}
mockProvider := &mockConnProvider{openSessions: 10}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()

rc.Start(ctx, mockSender, mockProvider)

Expand All @@ -76,24 +77,25 @@ func TestRollingCounter_ThresholdNotExceeded(t *testing.T) {
rc.Stop()
}

func TestRollingCounter_ThresholdExceeded(t *testing.T) {
newConnThreshold := 3
func TestRollingCounter_ObservabilityThresholdExceeded(t *testing.T) {
observabilityThreshold := 3
connectionLimitThreshold := 5
rc := connectioncounter.NewRollingCounter(
newConnThreshold,
connectionLimitThreshold,
observabilityThreshold,
3,
100*time.Millisecond,
)

mockSender := &mockObsSender{}
mockProvider := &mockConnProvider{openSessions: 7}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()

rc.Start(ctx, mockSender, mockProvider)

newConnsOpened := newConnThreshold * 5
for i := 0; i < newConnsOpened; i++ {
newConnsOpened := observabilityThreshold * 5
for range newConnsOpened {
rc.NewConnection()
}

Expand All @@ -108,23 +110,50 @@ func TestRollingCounter_ThresholdExceeded(t *testing.T) {
rc.Stop()
}

func TestRollingCounter_ConnectionLimitThresholdExceeded(t *testing.T) {
observabilityThreshold := 3
connectionLimitThreshold := 3
rc := connectioncounter.NewRollingCounter(
connectionLimitThreshold,
observabilityThreshold,
3,
100*time.Millisecond,
)

mockSender := &mockObsSender{}
mockProvider := &mockConnProvider{openSessions: 7}

ctx := t.Context()

rc.Start(ctx, mockSender, mockProvider)

newConnsOpened := observabilityThreshold * 5
for range newConnsOpened {
rc.NewConnection()
}
assert.Equal(t, rc.GetRollingCount(), newConnsOpened)
assert.True(t, rc.OverConnectionLimitThreshold())

rc.Stop()
}

func TestRollingCounter_BucketRotation(t *testing.T) {
jitterPeriod := 50 * time.Millisecond
bucketRotationInterval := 500 * time.Millisecond
rc := connectioncounter.NewRollingCounter(
100,
100,
3,
bucketRotationInterval,
)

post10Connections := func() {
for i := 0; i < 10; i++ {
for range 10 {
rc.NewConnection()
}
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()

mockSender := &mockObsSender{}
mockProvider := &mockConnProvider{}
Expand Down
32 changes: 32 additions & 0 deletions imap/connectionlimiter/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package connectionlimiter

import (
"strings"

"github.com/ProtonMail/gluon/imap"
)

type Client string

const (
ClientAppleMail Client = "apple-mail"
ClientOutlook Client = "outlook"
ClientThunderbird Client = "thunderbird"
ClientUnknown Client = "unknown"
)

func normalizeClientKey(id imap.IMAPID) Client {
name := strings.TrimSpace(strings.ToLower(id.Name))
switch {
case strings.Contains(name, "outlook"):
return ClientOutlook
case strings.Contains(name, "thunderbird"):
return ClientThunderbird
case strings.Contains(name, "mac") && strings.Contains(name, "mail"):
return ClientAppleMail
case strings.Contains(name, "mac") && strings.Contains(name, "notes"):
return ClientUnknown
default:
return ClientUnknown
}
}
Loading
Loading