From 4b1853b478127793f9a049538c8e05f47361817c Mon Sep 17 00:00:00 2001 From: Yi Duan Date: Wed, 28 Aug 2024 20:44:54 +0800 Subject: [PATCH] fix: omit data-race panic on `GC()` (#6) * fix: avoid data race between `GC()` and `Load()` * refactor: not hypohesis * chore: move unused `go:nocheckptr` * test: realword biz simulation * tmp: add write lock --- api_test.go | 132 ++++++++++++++++++++++++++++++++++++++++++++++++++++ manager.go | 14 +++--- session.go | 2 +- 3 files changed, 141 insertions(+), 7 deletions(-) diff --git a/api_test.go b/api_test.go index 1b59a43..a0d6df2 100644 --- a/api_test.go +++ b/api_test.go @@ -16,6 +16,8 @@ package localsession import ( "context" + "fmt" + "math" "os" "runtime/pprof" "sync" @@ -389,6 +391,32 @@ func TestSessionManager_GC(t *testing.T) { require.Equal(t, N/2, sum) } +func TestRace(t *testing.T) { + manager := NewSessionManager(ManagerOptions{ + ShardNumber: 1, + GCInterval: time.Second, + }) + var N = 1000 + var start sync.RWMutex + start.Lock() + wg := sync.WaitGroup{} + for i := 0; i < N; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + s := NewSessionMap(map[interface{}]interface{}{}).WithValue("a", "b") + start.RLock() + manager.BindSession(SessionID(i), s) + ss, ok := manager.GetSession(SessionID(i)) + if !ok || ss.Get("a") != "b" { + t.Fatal("not equal") + } + }(i) + } + start.Unlock() + wg.Wait() +} + func BenchmarkSessionManager_CurSession(b *testing.B) { s := NewSessionCtx(context.Background()) @@ -569,3 +597,107 @@ func BenchmarkGLS_Set(b *testing.B) { }) }) } + +func emitLoops(m *SessionManager, ctx context.Context, N int, s *stat) { + for i := 0; i < N; i++ { + go func() { + for { + if ctx.Err() != nil { + return + } + start := time.Now() + session := NewSessionCtx(ctx) + ss := session.WithValue("a", "b") + m.BindSession(SessionID(goID()), ss) + sss, _ := m.GetSession(SessionID(goID())) + if val := sss.Get("a"); val != "b" { + panic(fmt.Sprintf("unexpected val: %#v", val)) + } + m.UnbindSession(SessionID(goID())) + cost := time.Now().Sub(start) + s.Update(cost) + for a := 0; a < 10; a++ { + time.Sleep(time.Microsecond * 50) + for b := 0; b < 100000; b++ { + _ = b + } + } + } + }() + } +} + +func BenchmarkLoops(b *testing.B) { + for i := 0; i < b.N; i++ { + for b := 0; b < 100000; b++ { + _ = b + } + } +} + +type stat struct { + max time.Duration + min time.Duration + sum time.Duration + count int + + mux sync.RWMutex +} + +func (st *stat) Update(cost time.Duration) { + st.mux.Lock() + defer st.mux.Unlock() + if cost > st.max { + st.max = cost + } else if cost < st.min { + st.min = cost + } + st.count++ + st.sum += cost + return +} + +func (st *stat) String() string { + st.mux.RLock() + defer st.mux.RUnlock() + return fmt.Sprintf("min:%dns, max:%dns, avg:%dns", st.min, st.max, st.sum/time.Duration(st.count)) +} + +func TestRealBizGLS(t *testing.T) { + var runner = func(N int) { + m := NewSessionManager(ManagerOptions{ + ShardNumber: 100, + GCInterval: time.Second, + }) + s := &stat{ + min: time.Duration(math.MaxInt64), + } + ctx, _ := context.WithTimeout(context.Background(), time.Second*60) + emitLoops(&m, ctx, N, s) + go func(ctx context.Context) { + tt := time.NewTicker(time.Second) + for { + select { + case <-tt.C: + { + println(s.String()) + } + case <-ctx.Done(): + return + } + + } + }(ctx) + <-ctx.Done() + } + + t.Run("10", func(t *testing.T) { + runner(10) + }) + t.Run("100", func(t *testing.T) { + runner(100) + }) + t.Run("1000", func(t *testing.T) { + runner(1000) + }) +} diff --git a/manager.go b/manager.go index 7dfa000..8e73c0c 100644 --- a/manager.go +++ b/manager.go @@ -26,7 +26,7 @@ type ManagerOptions struct { // current session to children goroutines // // WARNING: Once this option enables, if you want to use `pprof.Do()`, it must be called before `BindSession()`, - // otherwise transmitting will be dysfunctional + // otherwise transmitting will be disfunctional EnableImplicitlyTransmitAsync bool // ShardNumber is used to shard session id, it must be larger than zero @@ -87,9 +87,12 @@ func (self SessionManager) Options() ManagerOptions { // SessionID is the identity of a session type SessionID uint64 -//go:nocheckptr func (s *shard) Load(id SessionID) (Session, bool) { s.lock.RLock() + + // p := atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&s.m))) + // m := *(*map[SessionID]Session)(unsafe.Pointer(p)) + session, ok := s.m[id] s.lock.RUnlock() return session, ok @@ -157,15 +160,13 @@ func (self *SessionManager) UnbindSession(id SessionID) { } // GC sweep invalid sessions and release unused memory -// -//go:nocheckptr func (self SessionManager) GC() { if !atomic.CompareAndSwapUint32(&self.inGC, 0, 1) { return } for _, shard := range self.shards { - shard.lock.RLock() + shard.lock.Lock() n := shard.m m := make(map[SessionID]Session, len(n)) for id, s := range n { @@ -174,8 +175,9 @@ func (self SessionManager) GC() { m[id] = s } } + // atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&shard.m)), unsafe.Pointer(&m)) shard.m = m - shard.lock.RUnlock() + shard.lock.Unlock() } atomic.StoreUint32(&self.inGC, 0) diff --git a/session.go b/session.go index 2b647b8..9fc4df0 100644 --- a/session.go +++ b/session.go @@ -161,7 +161,7 @@ func (self *SessionMap) Get(key interface{}) interface{} { } // Set value for specific key,and return itself -func (self *SessionMap) WithValue(key interface{}, val interface{}) Session { +func (self *SessionMap) WithValue(key, val interface{}) Session { if self == nil { return nil }