diff --git a/ratelimit/ratelimit.go b/ratelimit/ratelimit.go new file mode 100644 index 0000000..8928e5f --- /dev/null +++ b/ratelimit/ratelimit.go @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2025 LabStack and Echo contributors + +package ratelimit + +import ( + "net/http" + "sync" + "time" + + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" +) + +type ( + // Config defines the config for RateLimit middleware. + Config struct { + // Skipper defines a function to skip middleware. + Skipper middleware.Skipper + + // Limit is the maximum number of requests allowed within the defined window. + // Required. + Limit int + + // Window defines the time window for the rate limit (in seconds). + // Default is 60 seconds (1 minute). + Window time.Duration + + // KeyExtractor is a function used to generate a key for each request. + // Default implementation uses the client IP address. + KeyExtractor func(c echo.Context) string + + // ErrorHandler is a function to handle errors returned by the middleware. + ErrorHandler func(c echo.Context, err error) error + + // ExceedHandler is a function called when rate limit is exceeded. + // Default returns 429 Too Many Requests. + ExceedHandler func(c echo.Context) error + } + + // Store is an interface for storing rate limit data + Store interface { + // Increment increments the count for a key and returns the current count + Increment(key string, window time.Duration) (int, error) + + // Get returns the current count for a key + Get(key string) (int, error) + + // Cleanup removes expired entries + Cleanup() + } + + // MemoryStore implements in-memory storage for rate limiting + MemoryStore struct { + entries map[string]*entry + mu sync.RWMutex + } + + entry struct { + count int + expireAt time.Time + } +) + +var ( + // DefaultConfig is the default RateLimit middleware config. + DefaultConfig = Config{ + Skipper: middleware.DefaultSkipper, + Window: 60 * time.Second, // 1 minute + KeyExtractor: func(c echo.Context) string { + return c.RealIP() + }, + ErrorHandler: func(c echo.Context, err error) error { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + }, + ExceedHandler: func(c echo.Context) error { + return echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") + }, + } + + // DefaultStore is the default in-memory store for rate limiting + DefaultStore Store +) + +// NewMemoryStore creates a new in-memory store for rate limiting +func NewMemoryStore() *MemoryStore { + store := &MemoryStore{ + entries: make(map[string]*entry), + } + + go func() { + // Clean up expired entries every minute + for { + time.Sleep(time.Minute) + store.Cleanup() + } + }() + + return store +} + +// Increment increments the count for a key and returns the current count +func (s *MemoryStore) Increment(key string, window time.Duration) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + if s.entries == nil { + s.entries = make(map[string]*entry) + } + + e, exists := s.entries[key] + if !exists || now.After(e.expireAt) { + s.entries[key] = &entry{ + count: 1, + expireAt: now.Add(window), + } + return 1, nil + } + + e.count++ + return e.count, nil +} + +// Get returns the current count for a key +func (s *MemoryStore) Get(key string) (int, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + now := time.Now() + e, exists := s.entries[key] + if !exists { + return 0, nil + } + + if now.After(e.expireAt) { + return 0, nil + } + + return e.count, nil +} + +// Cleanup removes expired entries from the memory store +func (s *MemoryStore) Cleanup() { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + for key, e := range s.entries { + if now.After(e.expireAt) { + delete(s.entries, key) + } + } +} + +// Initialize the default store +func init() { + DefaultStore = NewMemoryStore() +} + +// Middleware returns a RateLimit middleware. +func Middleware(limit int) echo.MiddlewareFunc { + c := DefaultConfig + c.Limit = limit + return MiddlewareWithConfig(c) +} + +// MiddlewareWithConfig returns a RateLimit middleware with config. +func MiddlewareWithConfig(config Config) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultConfig.Skipper + } + if config.Window == 0 { + config.Window = DefaultConfig.Window + } + if config.KeyExtractor == nil { + config.KeyExtractor = DefaultConfig.KeyExtractor + } + if config.ErrorHandler == nil { + config.ErrorHandler = DefaultConfig.ErrorHandler + } + if config.ExceedHandler == nil { + config.ExceedHandler = DefaultConfig.ExceedHandler + } + if config.Limit <= 0 { + panic("echo: rate limit middleware requires limit > 0") + } + + store := DefaultStore + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + + key := config.KeyExtractor(c) + count, err := store.Increment(key, config.Window) + if err != nil { + return config.ErrorHandler(c, err) + } + + // Set rate limit headers + c.Response().Header().Set("X-RateLimit-Limit", string(rune(config.Limit))) + c.Response().Header().Set("X-RateLimit-Remaining", string(rune(config.Limit-count))) + + if count > config.Limit { + return config.ExceedHandler(c) + } + + return next(c) + } + } +} + +// WithStore returns a RateLimit middleware with a custom store. +func WithStore(store Store) echo.MiddlewareFunc { + DefaultStore = store + return Middleware(DefaultConfig.Limit) +} + +// WithStoreAndConfig returns a RateLimit middleware with a custom store and config. +func WithStoreAndConfig(store Store, config Config) echo.MiddlewareFunc { + DefaultStore = store + return MiddlewareWithConfig(config) +} diff --git a/ratelimit/ratelimit_test.go b/ratelimit/ratelimit_test.go new file mode 100644 index 0000000..dea453e --- /dev/null +++ b/ratelimit/ratelimit_test.go @@ -0,0 +1,211 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2025 LabStack and Echo contributors + +package ratelimit + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestRateLimit(t *testing.T) { + e := echo.New() + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + } + + // Create rate limiter middleware with limit of 3 requests + limiter := Middleware(3) + h := limiter(handler) + + // Reset the default store for testing + DefaultStore = NewMemoryStore() + + // Create a test request with IP 192.0.2.1 + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "192.0.2.1:1234" + + // First request should pass + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + // Second request should pass + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + err = h(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + // Third request should pass + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + err = h(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + // Fourth request should fail with 429 + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + err = h(c) + he, ok := err.(*echo.HTTPError) + assert.True(t, ok) + assert.Equal(t, http.StatusTooManyRequests, he.Code) +} + +func TestRateLimitWithCustomConfig(t *testing.T) { + e := echo.New() + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + } + + // Custom config with shorter window and custom key extractor + config := Config{ + Limit: 2, + Window: 100 * time.Millisecond, + KeyExtractor: func(c echo.Context) string { + return "test-key" + }, + } + + limiter := MiddlewareWithConfig(config) + h := limiter(handler) + + // Reset the default store for testing + DefaultStore = NewMemoryStore() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + + // First request should pass + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + // Second request should pass + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + err = h(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + // Third request should fail + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + err = h(c) + he, ok := err.(*echo.HTTPError) + assert.True(t, ok) + assert.Equal(t, http.StatusTooManyRequests, he.Code) + + // Wait for window to expire + time.Sleep(150 * time.Millisecond) + + // Request after window expiry should pass + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + err = h(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestSkipper(t *testing.T) { + e := echo.New() + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + } + + // Custom config with skipper + config := Config{ + Limit: 1, + Skipper: func(c echo.Context) bool { + return c.Path() == "/skip" + }, + } + + limiter := MiddlewareWithConfig(config) + h := limiter(handler) + + // Reset the default store for testing + DefaultStore = NewMemoryStore() + + // First request should pass + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + // Second request should fail with 429 + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + err = h(c) + he, ok := err.(*echo.HTTPError) + assert.True(t, ok) + assert.Equal(t, http.StatusTooManyRequests, he.Code) + + // Request to skipped path should always pass + req = httptest.NewRequest(http.MethodGet, "/skip", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + c.SetPath("/skip") + err = h(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestMemoryStoreCleanup(t *testing.T) { + store := NewMemoryStore() + + // Add an entry that will expire soon + _, err := store.Increment("test-key", 50*time.Millisecond) + assert.NoError(t, err) + + // Verify the entry exists + count, err := store.Get("test-key") + assert.NoError(t, err) + assert.Equal(t, 1, count) + + // Wait for expiration + time.Sleep(100 * time.Millisecond) + + // Manually trigger cleanup + store.Cleanup() + + // Verify the entry is removed + count, err = store.Get("test-key") + assert.NoError(t, err) + assert.Equal(t, 0, count) +} + +// Mock store for testing custom stores +type mockStore struct { + counts map[string]int +} + +func newMockStore() *mockStore { + return &mockStore{ + counts: make(map[string]int), + } +} + +func (s *mockStore) Increment(key string, _ time.Duration) (int, error) { + s.counts[key]++ + return s.counts[key], nil +} + +func (s *mockStore) Get(key string) (int, error) { + return s.counts[key], nil +} + +func (s *mockStore) Cleanup() { + // No-op for mock +}