Skip to content

Commit 66a8814

Browse files
authored
fix(middleware/session): mutex for thread safety (#3050)
* chore: Remove extra release and acquire ctx calls in session_test.go * feat: Remove unnecessary session mutex lock in decodeSessionData function * chore: Refactor session benchmark tests * fix(middleware/session): mutex for thread safety * feat: Add session mutex lock for thread safety * chore: Refactor releaseSession mutex
1 parent 6fa0e7c commit 66a8814

File tree

3 files changed

+276
-24
lines changed

3 files changed

+276
-24
lines changed

middleware/session/session.go

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
)
1515

1616
type Session struct {
17+
mu sync.RWMutex // Mutex to protect non-data fields
1718
id string // session id
1819
fresh bool // if new session
1920
ctx *fiber.Ctx // fiber context
@@ -42,6 +43,7 @@ func acquireSession() *Session {
4243
}
4344

4445
func releaseSession(s *Session) {
46+
s.mu.Lock()
4547
s.id = ""
4648
s.exp = 0
4749
s.ctx = nil
@@ -52,16 +54,21 @@ func releaseSession(s *Session) {
5254
if s.byteBuffer != nil {
5355
s.byteBuffer.Reset()
5456
}
57+
s.mu.Unlock()
5558
sessionPool.Put(s)
5659
}
5760

5861
// Fresh is true if the current session is new
5962
func (s *Session) Fresh() bool {
63+
s.mu.RLock()
64+
defer s.mu.RUnlock()
6065
return s.fresh
6166
}
6267

6368
// ID returns the session id
6469
func (s *Session) ID() string {
70+
s.mu.RLock()
71+
defer s.mu.RUnlock()
6572
return s.id
6673
}
6774

@@ -102,6 +109,9 @@ func (s *Session) Destroy() error {
102109
// Reset local data
103110
s.data.Reset()
104111

112+
s.mu.RLock()
113+
defer s.mu.RUnlock()
114+
105115
// Use external Storage if exist
106116
if err := s.config.Storage.Delete(s.id); err != nil {
107117
return err
@@ -114,6 +124,9 @@ func (s *Session) Destroy() error {
114124

115125
// Regenerate generates a new session id and delete the old one from Storage
116126
func (s *Session) Regenerate() error {
127+
s.mu.Lock()
128+
defer s.mu.Unlock()
129+
117130
// Delete old id from storage
118131
if err := s.config.Storage.Delete(s.id); err != nil {
119132
return err
@@ -131,6 +144,10 @@ func (s *Session) Reset() error {
131144
if s.data != nil {
132145
s.data.Reset()
133146
}
147+
148+
s.mu.Lock()
149+
defer s.mu.Unlock()
150+
134151
// Reset byte buffer
135152
if s.byteBuffer != nil {
136153
s.byteBuffer.Reset()
@@ -154,20 +171,24 @@ func (s *Session) Reset() error {
154171

155172
// refresh generates a new session, and set session.fresh to be true
156173
func (s *Session) refresh() {
157-
// Create a new id
158174
s.id = s.config.KeyGenerator()
159-
160-
// We assign a new id to the session, so the session must be fresh
161175
s.fresh = true
162176
}
163177

164178
// Save will update the storage and client cookie
179+
//
180+
// sess.Save() will save the session data to the storage and update the
181+
// client cookie, and it will release the session after saving.
182+
//
183+
// It's not safe to use the session after calling Save().
165184
func (s *Session) Save() error {
166185
// Better safe than sorry
167186
if s.data == nil {
168187
return nil
169188
}
170189

190+
s.mu.Lock()
191+
171192
// Check if session has your own expiration, otherwise use default value
172193
if s.exp <= 0 {
173194
s.exp = s.config.Expiration
@@ -177,25 +198,25 @@ func (s *Session) Save() error {
177198
s.setSession()
178199

179200
// Convert data to bytes
180-
mux.Lock()
181-
defer mux.Unlock()
182201
encCache := gob.NewEncoder(s.byteBuffer)
183202
err := encCache.Encode(&s.data.Data)
184203
if err != nil {
185204
return fmt.Errorf("failed to encode data: %w", err)
186205
}
187206

188-
// copy the data in buffer
207+
// Copy the data in buffer
189208
encodedBytes := make([]byte, s.byteBuffer.Len())
190209
copy(encodedBytes, s.byteBuffer.Bytes())
191210

192-
// pass copied bytes with session id to provider
211+
// Pass copied bytes with session id to provider
193212
if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil {
194213
return err
195214
}
196215

216+
s.mu.Unlock()
217+
197218
// Release session
198-
// TODO: It's not safe to use the Session after called Save()
219+
// TODO: It's not safe to use the Session after calling Save()
199220
releaseSession(s)
200221

201222
return nil
@@ -211,6 +232,8 @@ func (s *Session) Keys() []string {
211232

212233
// SetExpiry sets a specific expiration for this session
213234
func (s *Session) SetExpiry(exp time.Duration) {
235+
s.mu.Lock()
236+
defer s.mu.Unlock()
214237
s.exp = exp
215238
}
216239

@@ -276,3 +299,13 @@ func (s *Session) delSession() {
276299
fasthttp.ReleaseCookie(fcookie)
277300
}
278301
}
302+
303+
// decodeSessionData decodes the session data from raw bytes.
304+
func (s *Session) decodeSessionData(rawData []byte) error {
305+
_, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail
306+
encCache := gob.NewDecoder(s.byteBuffer)
307+
if err := encCache.Decode(&s.data.Data); err != nil {
308+
return fmt.Errorf("failed to decode session data: %w", err)
309+
}
310+
return nil
311+
}

middleware/session/session_test.go

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package session
22

33
import (
4+
"errors"
5+
"sync"
46
"testing"
57
"time"
68

@@ -673,3 +675,230 @@ func Benchmark_Session(b *testing.B) {
673675
utils.AssertEqual(b, nil, err)
674676
})
675677
}
678+
679+
// go test -v -run=^$ -bench=Benchmark_Session_Parallel -benchmem -count=4
680+
func Benchmark_Session_Parallel(b *testing.B) {
681+
b.Run("default", func(b *testing.B) {
682+
app, store := fiber.New(), New()
683+
b.ReportAllocs()
684+
b.ResetTimer()
685+
b.RunParallel(func(pb *testing.PB) {
686+
for pb.Next() {
687+
c := app.AcquireCtx(&fasthttp.RequestCtx{})
688+
c.Request().Header.SetCookie(store.sessionName, "12356789")
689+
690+
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
691+
sess.Set("john", "doe")
692+
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
693+
app.ReleaseCtx(c)
694+
}
695+
})
696+
})
697+
698+
b.Run("storage", func(b *testing.B) {
699+
app := fiber.New()
700+
store := New(Config{
701+
Storage: memory.New(),
702+
})
703+
b.ReportAllocs()
704+
b.ResetTimer()
705+
b.RunParallel(func(pb *testing.PB) {
706+
for pb.Next() {
707+
c := app.AcquireCtx(&fasthttp.RequestCtx{})
708+
c.Request().Header.SetCookie(store.sessionName, "12356789")
709+
710+
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
711+
sess.Set("john", "doe")
712+
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
713+
app.ReleaseCtx(c)
714+
}
715+
})
716+
})
717+
}
718+
719+
// go test -v -run=^$ -bench=Benchmark_Session_Asserted -benchmem -count=4
720+
func Benchmark_Session_Asserted(b *testing.B) {
721+
b.Run("default", func(b *testing.B) {
722+
app, store := fiber.New(), New()
723+
c := app.AcquireCtx(&fasthttp.RequestCtx{})
724+
defer app.ReleaseCtx(c)
725+
c.Request().Header.SetCookie(store.sessionName, "12356789")
726+
727+
b.ReportAllocs()
728+
b.ResetTimer()
729+
for n := 0; n < b.N; n++ {
730+
sess, err := store.Get(c)
731+
utils.AssertEqual(b, nil, err)
732+
sess.Set("john", "doe")
733+
err = sess.Save()
734+
utils.AssertEqual(b, nil, err)
735+
}
736+
})
737+
738+
b.Run("storage", func(b *testing.B) {
739+
app := fiber.New()
740+
store := New(Config{
741+
Storage: memory.New(),
742+
})
743+
c := app.AcquireCtx(&fasthttp.RequestCtx{})
744+
defer app.ReleaseCtx(c)
745+
c.Request().Header.SetCookie(store.sessionName, "12356789")
746+
747+
b.ReportAllocs()
748+
b.ResetTimer()
749+
for n := 0; n < b.N; n++ {
750+
sess, err := store.Get(c)
751+
utils.AssertEqual(b, nil, err)
752+
sess.Set("john", "doe")
753+
err = sess.Save()
754+
utils.AssertEqual(b, nil, err)
755+
}
756+
})
757+
}
758+
759+
// go test -v -run=^$ -bench=Benchmark_Session_Asserted_Parallel -benchmem -count=4
760+
func Benchmark_Session_Asserted_Parallel(b *testing.B) {
761+
b.Run("default", func(b *testing.B) {
762+
app, store := fiber.New(), New()
763+
b.ReportAllocs()
764+
b.ResetTimer()
765+
b.RunParallel(func(pb *testing.PB) {
766+
for pb.Next() {
767+
c := app.AcquireCtx(&fasthttp.RequestCtx{})
768+
c.Request().Header.SetCookie(store.sessionName, "12356789")
769+
770+
sess, err := store.Get(c)
771+
utils.AssertEqual(b, nil, err)
772+
sess.Set("john", "doe")
773+
utils.AssertEqual(b, nil, sess.Save())
774+
app.ReleaseCtx(c)
775+
}
776+
})
777+
})
778+
779+
b.Run("storage", func(b *testing.B) {
780+
app := fiber.New()
781+
store := New(Config{
782+
Storage: memory.New(),
783+
})
784+
b.ReportAllocs()
785+
b.ResetTimer()
786+
b.RunParallel(func(pb *testing.PB) {
787+
for pb.Next() {
788+
c := app.AcquireCtx(&fasthttp.RequestCtx{})
789+
c.Request().Header.SetCookie(store.sessionName, "12356789")
790+
791+
sess, err := store.Get(c)
792+
utils.AssertEqual(b, nil, err)
793+
sess.Set("john", "doe")
794+
utils.AssertEqual(b, nil, sess.Save())
795+
app.ReleaseCtx(c)
796+
}
797+
})
798+
})
799+
}
800+
801+
// go test -v -race -run Test_Session_Concurrency ./...
802+
func Test_Session_Concurrency(t *testing.T) {
803+
t.Parallel()
804+
app := fiber.New()
805+
store := New()
806+
807+
var wg sync.WaitGroup
808+
errChan := make(chan error, 10) // Buffered channel to collect errors
809+
const numGoroutines = 10 // Number of concurrent goroutines to test
810+
811+
// Start numGoroutines goroutines
812+
for i := 0; i < numGoroutines; i++ {
813+
wg.Add(1)
814+
go func() {
815+
defer wg.Done()
816+
817+
localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
818+
819+
sess, err := store.Get(localCtx)
820+
if err != nil {
821+
errChan <- err
822+
return
823+
}
824+
825+
// Set a value
826+
sess.Set("name", "john")
827+
828+
// get the session id
829+
id := sess.ID()
830+
831+
// Check if the session is fresh
832+
if !sess.Fresh() {
833+
errChan <- errors.New("session should be fresh")
834+
return
835+
}
836+
837+
// Save the session
838+
if err := sess.Save(); err != nil {
839+
errChan <- err
840+
return
841+
}
842+
843+
// Release the context
844+
app.ReleaseCtx(localCtx)
845+
846+
// Acquire a new context
847+
localCtx = app.AcquireCtx(&fasthttp.RequestCtx{})
848+
defer app.ReleaseCtx(localCtx)
849+
850+
// Set the session id in the header
851+
localCtx.Request().Header.SetCookie(store.sessionName, id)
852+
853+
// Get the session
854+
sess, err = store.Get(localCtx)
855+
if err != nil {
856+
errChan <- err
857+
return
858+
}
859+
860+
// Get the value
861+
name := sess.Get("name")
862+
if name != "john" {
863+
errChan <- errors.New("name should be john")
864+
return
865+
}
866+
867+
// Get ID from the session
868+
if sess.ID() != id {
869+
errChan <- errors.New("id should be the same")
870+
return
871+
}
872+
873+
// Check if the session is fresh
874+
if sess.Fresh() {
875+
errChan <- errors.New("session should not be fresh")
876+
return
877+
}
878+
879+
// Delete the key
880+
sess.Delete("name")
881+
882+
// Get the value
883+
name = sess.Get("name")
884+
if name != nil {
885+
errChan <- errors.New("name should be nil")
886+
return
887+
}
888+
889+
// Destroy the session
890+
if err := sess.Destroy(); err != nil {
891+
errChan <- err
892+
return
893+
}
894+
}()
895+
}
896+
897+
wg.Wait() // Wait for all goroutines to finish
898+
close(errChan) // Close the channel to signal no more errors will be sent
899+
900+
// Check for errors sent to errChan
901+
for err := range errChan {
902+
utils.AssertEqual(t, nil, err)
903+
}
904+
}

0 commit comments

Comments
 (0)