Skip to content

Commit cc1b93d

Browse files
committed
feat: Add rate limiting
Adds a rate limiting mechanism, that will send HTTP/429 responses once a defined limit is reached (Token Bucket) Signed-off-by: Manuel Rüger <[email protected]>
1 parent fb7682a commit cc1b93d

File tree

7 files changed

+77
-25
lines changed

7 files changed

+77
-25
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ require (
99
github.com/prometheus/common v0.64.0
1010
golang.org/x/crypto v0.39.0
1111
golang.org/x/sync v0.15.0
12+
golang.org/x/time v0.12.0
1213
gopkg.in/yaml.v2 v2.4.0
1314
)
1415

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
6161
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
6262
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
6363
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
64+
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
65+
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
6466
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
6567
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
6668
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

web/handler.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"sync"
2525

2626
"golang.org/x/crypto/bcrypt"
27+
"golang.org/x/time/rate"
2728
)
2829

2930
// extraHTTPHeaders is a map of HTTP headers that can be added to HTTP
@@ -80,6 +81,7 @@ type webHandler struct {
8081
handler http.Handler
8182
logger *slog.Logger
8283
cache *cache
84+
limiter *rate.Limiter
8385
// bcryptMtx is there to ensure that bcrypt.CompareHashAndPassword is run
8486
// only once in parallel as this is CPU intensive.
8587
bcryptMtx sync.Mutex
@@ -93,6 +95,11 @@ func (u *webHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
9395
return
9496
}
9597

98+
if u.limiter != nil && !u.limiter.Allow() {
99+
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
100+
return
101+
}
102+
96103
// Configure http headers.
97104
for k, v := range c.HTTPConfig.Header {
98105
w.Header().Set(k, v)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
rate_limiter_config:
2+
rate: 1
3+
burst: 1
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
rate_limiter_config:
2+
rate: 0
3+
burst: 0

web/tls_config.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"github.com/mdlayher/vsock"
3232
config_util "github.com/prometheus/common/config"
3333
"golang.org/x/sync/errgroup"
34+
"golang.org/x/time/rate"
3435
"gopkg.in/yaml.v2"
3536
)
3637

@@ -40,9 +41,10 @@ var (
4041
)
4142

4243
type Config struct {
43-
TLSConfig TLSConfig `yaml:"tls_server_config"`
44-
HTTPConfig HTTPConfig `yaml:"http_server_config"`
45-
Users map[string]config_util.Secret `yaml:"basic_auth_users"`
44+
TLSConfig TLSConfig `yaml:"tls_server_config"`
45+
HTTPConfig HTTPConfig `yaml:"http_server_config"`
46+
RateLimiterConfig RateLimiterConfig `yaml:"rate_limiter_config"`
47+
Users map[string]config_util.Secret `yaml:"basic_auth_users"`
4648
}
4749

4850
type TLSConfig struct {
@@ -109,6 +111,11 @@ type HTTPConfig struct {
109111
Header map[string]string `yaml:"headers,omitempty"`
110112
}
111113

114+
type RateLimiterConfig struct {
115+
Burst int `yaml:"burst"`
116+
Rate int `yaml:"rate"`
117+
}
118+
112119
func getConfig(configPath string) (*Config, error) {
113120
content, err := os.ReadFile(configPath)
114121
if err != nil {
@@ -366,11 +373,19 @@ func Serve(l net.Listener, server *http.Server, flags *FlagConfig, logger *slog.
366373
return err
367374
}
368375

376+
var limiter *rate.Limiter
377+
// Setup Rate Limiter
378+
if c.RateLimiterConfig.Rate != 0 && c.RateLimiterConfig.Burst != 0 {
379+
limiter = rate.NewLimiter(rate.Limit(c.RateLimiterConfig.Rate), c.RateLimiterConfig.Burst)
380+
logger.Info("Rate Limiter is enabled.", "burst", c.RateLimiterConfig.Burst, "rate", c.RateLimiterConfig.Rate)
381+
}
382+
369383
server.Handler = &webHandler{
370384
tlsConfigPath: tlsConfigPath,
371385
logger: logger,
372386
handler: handler,
373387
cache: newCache(),
388+
limiter: limiter,
374389
}
375390

376391
config, err := ConfigToTLSConfig(&c.TLSConfig)

web/tls_config_test.go

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ var (
7272
// Introduced in Go 1.21
7373
"Certificate required": regexp.MustCompile(`certificate required`),
7474
"Unknown CA": regexp.MustCompile(`unknown certificate authority`),
75+
"Too Many Requests": regexp.MustCompile(`Too Many Requests`),
7576
}
7677
)
7778

@@ -98,6 +99,7 @@ type TestInputs struct {
9899
Username string
99100
Password string
100101
ClientCertificate string
102+
Requests int
101103
}
102104

103105
func TestYAMLFiles(t *testing.T) {
@@ -364,6 +366,19 @@ func TestServerBehaviour(t *testing.T) {
364366
ClientCertificate: "client2_selfsigned",
365367
ExpectedError: ErrorMap["Invalid client cert"],
366368
},
369+
{
370+
Name: "valid rate limiter that doesn't block",
371+
YAMLConfigPath: "testdata/web_config_rate_limiter_nonblocking.yaml",
372+
UseTLSClient: false,
373+
ExpectedError: nil,
374+
},
375+
{
376+
Name: "valid rate limiter with a capacity of one",
377+
YAMLConfigPath: "testdata/web_config_rate_limiter_capacity_one.yaml",
378+
UseTLSClient: false,
379+
Requests: 10,
380+
ExpectedError: ErrorMap["Too Many Requests"],
381+
},
367382
}
368383
for _, testInputs := range testTables {
369384
t.Run(testInputs.Name, testInputs.Test)
@@ -511,35 +526,41 @@ func (test *TestInputs) Test(t *testing.T) {
511526
if test.Username != "" {
512527
req.SetBasicAuth(test.Username, test.Password)
513528
}
529+
514530
return client.Do(req)
515531
}
516532
go func() {
517533
time.Sleep(250 * time.Millisecond)
518-
r, err := ClientConnection()
519-
if err != nil {
520-
recordConnectionError(err)
521-
return
522-
}
523534

524-
if test.ActualCipher != 0 {
525-
if r.TLS.CipherSuite != test.ActualCipher {
526-
recordConnectionError(
527-
fmt.Errorf("bad cipher suite selected. Expected: %s, got: %s",
528-
tls.CipherSuiteName(test.ActualCipher),
529-
tls.CipherSuiteName(r.TLS.CipherSuite),
530-
),
531-
)
535+
for req := 0; req <= test.Requests; req++ {
536+
537+
r, err := ClientConnection()
538+
539+
if err != nil {
540+
recordConnectionError(err)
541+
return
532542
}
533-
}
534543

535-
body, err := io.ReadAll(r.Body)
536-
if err != nil {
537-
recordConnectionError(err)
538-
return
539-
}
540-
if string(body) != "Hello World!" {
541-
recordConnectionError(errors.New(string(body)))
542-
return
544+
if test.ActualCipher != 0 {
545+
if r.TLS.CipherSuite != test.ActualCipher {
546+
recordConnectionError(
547+
fmt.Errorf("bad cipher suite selected. Expected: %s, got: %s",
548+
tls.CipherSuiteName(test.ActualCipher),
549+
tls.CipherSuiteName(r.TLS.CipherSuite),
550+
),
551+
)
552+
}
553+
}
554+
555+
body, err := io.ReadAll(r.Body)
556+
if err != nil {
557+
recordConnectionError(err)
558+
return
559+
}
560+
if string(body) != "Hello World!" {
561+
recordConnectionError(errors.New(string(body)))
562+
return
563+
}
543564
}
544565
recordConnectionError(nil)
545566
}()

0 commit comments

Comments
 (0)