Skip to content

Commit

Permalink
fix: omit data-race panic on GC() (#6)
Browse files Browse the repository at this point in the history
* fix: avoid data race between `GC()` and `Load()`

* refactor: not hypohesis

* chore: move unused `go:nocheckptr`

* test: realword biz simulation

* tmp: add write lock
  • Loading branch information
AsterDY authored Aug 28, 2024
1 parent dd21ab6 commit 4b1853b
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 7 deletions.
132 changes: 132 additions & 0 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package localsession

import (
"context"
"fmt"
"math"
"os"
"runtime/pprof"
"sync"
Expand Down Expand Up @@ -389,6 +391,32 @@ func TestSessionManager_GC(t *testing.T) {
require.Equal(t, N/2, sum)
}

func TestRace(t *testing.T) {
manager := NewSessionManager(ManagerOptions{
ShardNumber: 1,
GCInterval: time.Second,
})
var N = 1000
var start sync.RWMutex
start.Lock()
wg := sync.WaitGroup{}
for i := 0; i < N; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
s := NewSessionMap(map[interface{}]interface{}{}).WithValue("a", "b")
start.RLock()
manager.BindSession(SessionID(i), s)
ss, ok := manager.GetSession(SessionID(i))
if !ok || ss.Get("a") != "b" {
t.Fatal("not equal")
}
}(i)
}
start.Unlock()
wg.Wait()
}

func BenchmarkSessionManager_CurSession(b *testing.B) {
s := NewSessionCtx(context.Background())

Expand Down Expand Up @@ -569,3 +597,107 @@ func BenchmarkGLS_Set(b *testing.B) {
})
})
}

func emitLoops(m *SessionManager, ctx context.Context, N int, s *stat) {
for i := 0; i < N; i++ {
go func() {
for {
if ctx.Err() != nil {
return
}
start := time.Now()
session := NewSessionCtx(ctx)
ss := session.WithValue("a", "b")
m.BindSession(SessionID(goID()), ss)
sss, _ := m.GetSession(SessionID(goID()))
if val := sss.Get("a"); val != "b" {
panic(fmt.Sprintf("unexpected val: %#v", val))
}
m.UnbindSession(SessionID(goID()))
cost := time.Now().Sub(start)
s.Update(cost)
for a := 0; a < 10; a++ {
time.Sleep(time.Microsecond * 50)
for b := 0; b < 100000; b++ {
_ = b
}
}
}
}()
}
}

func BenchmarkLoops(b *testing.B) {
for i := 0; i < b.N; i++ {
for b := 0; b < 100000; b++ {
_ = b
}
}
}

type stat struct {
max time.Duration
min time.Duration
sum time.Duration
count int

mux sync.RWMutex
}

func (st *stat) Update(cost time.Duration) {
st.mux.Lock()
defer st.mux.Unlock()
if cost > st.max {
st.max = cost
} else if cost < st.min {
st.min = cost
}
st.count++
st.sum += cost
return
}

func (st *stat) String() string {
st.mux.RLock()
defer st.mux.RUnlock()
return fmt.Sprintf("min:%dns, max:%dns, avg:%dns", st.min, st.max, st.sum/time.Duration(st.count))
}

func TestRealBizGLS(t *testing.T) {
var runner = func(N int) {
m := NewSessionManager(ManagerOptions{
ShardNumber: 100,
GCInterval: time.Second,
})
s := &stat{
min: time.Duration(math.MaxInt64),
}
ctx, _ := context.WithTimeout(context.Background(), time.Second*60)
emitLoops(&m, ctx, N, s)
go func(ctx context.Context) {
tt := time.NewTicker(time.Second)
for {
select {
case <-tt.C:
{
println(s.String())
}
case <-ctx.Done():
return
}

}
}(ctx)
<-ctx.Done()
}

t.Run("10", func(t *testing.T) {
runner(10)
})
t.Run("100", func(t *testing.T) {
runner(100)
})
t.Run("1000", func(t *testing.T) {
runner(1000)
})
}
14 changes: 8 additions & 6 deletions manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type ManagerOptions struct {
// current session to children goroutines
//
// WARNING: Once this option enables, if you want to use `pprof.Do()`, it must be called before `BindSession()`,
// otherwise transmitting will be dysfunctional
// otherwise transmitting will be disfunctional
EnableImplicitlyTransmitAsync bool

// ShardNumber is used to shard session id, it must be larger than zero
Expand Down Expand Up @@ -87,9 +87,12 @@ func (self SessionManager) Options() ManagerOptions {
// SessionID is the identity of a session
type SessionID uint64

//go:nocheckptr
func (s *shard) Load(id SessionID) (Session, bool) {
s.lock.RLock()

// p := atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&s.m)))
// m := *(*map[SessionID]Session)(unsafe.Pointer(p))

session, ok := s.m[id]
s.lock.RUnlock()
return session, ok
Expand Down Expand Up @@ -157,15 +160,13 @@ func (self *SessionManager) UnbindSession(id SessionID) {
}

// GC sweep invalid sessions and release unused memory
//
//go:nocheckptr
func (self SessionManager) GC() {
if !atomic.CompareAndSwapUint32(&self.inGC, 0, 1) {
return
}

for _, shard := range self.shards {
shard.lock.RLock()
shard.lock.Lock()
n := shard.m
m := make(map[SessionID]Session, len(n))
for id, s := range n {
Expand All @@ -174,8 +175,9 @@ func (self SessionManager) GC() {
m[id] = s
}
}
// atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&shard.m)), unsafe.Pointer(&m))
shard.m = m
shard.lock.RUnlock()
shard.lock.Unlock()
}

atomic.StoreUint32(&self.inGC, 0)
Expand Down
2 changes: 1 addition & 1 deletion session.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (self *SessionMap) Get(key interface{}) interface{} {
}

// Set value for specific key,and return itself
func (self *SessionMap) WithValue(key interface{}, val interface{}) Session {
func (self *SessionMap) WithValue(key, val interface{}) Session {
if self == nil {
return nil
}
Expand Down

0 comments on commit 4b1853b

Please sign in to comment.