diff --git a/loadbalance/loadbalance.go b/loadbalance/loadbalance.go index 92752f1..99ce9d9 100644 --- a/loadbalance/loadbalance.go +++ b/loadbalance/loadbalance.go @@ -6,9 +6,10 @@ var ErrEmptyLoadBalancer = errors.New("empty load balancer") type LoadBalancer[E any] interface { Next() E + Close() } type Weighted[E any] struct { Item E - Weight int + Weight int64 } diff --git a/loadbalance/random.go b/loadbalance/random.go index 4bf53d0..9fd5ccf 100644 --- a/loadbalance/random.go +++ b/loadbalance/random.go @@ -1,56 +1,67 @@ package loadbalance -import ( - "math/rand/v2" - "sync" -) +import "math/rand/v2" -var _ LoadBalancer[struct{}] = &random[struct{}]{} +var _ LoadBalancer[any] = &random[any]{} type random[E any] struct { - sync.Mutex - items []E - c chan E + rr *roundrobin[E] + next chan int64 + c chan struct{} } func Random[E any](items ...E) (LoadBalancer[E], error) { - if len(items) == 0 { - return nil, ErrEmptyLoadBalancer + rr, err := newRoundRobin[E](items) + if err != nil { + return nil, err } - return &random[E]{items: items, c: make(chan E, len(items))}, nil + return &random[E]{rr: rr, c: make(chan struct{})}, nil } -func WeightedRandom[E any](items ...Weighted[E]) (LoadBalancer[E], error) { - var pool []E - for _, i := range items { - for n := i.Weight; n > 0; n-- { - pool = append(pool, i.Item) - } - } - if len(pool) == 0 { - return nil, ErrEmptyLoadBalancer +func WeightedRandom[E any](items ...*Weighted[E]) (LoadBalancer[E], error) { + rr, err := newRoundRobin[E](items) + if err != nil { + return nil, err } - return Random(pool...) + return &random[E]{rr: rr, c: make(chan struct{})}, nil } -func (r *random[E]) load() { - length := len(r.items) - var s []int - for i := range length { - s = append(s, i) - } - rand.Shuffle(length, func(i, j int) { s[i], s[j] = s[j], s[i] }) - for _, i := range s { - r.c <- r.items[i] - } +func (r *random[E]) init() { + r.next = make(chan int64, r.rr.n) + go func() { + for { + if _, ok := <-r.c; !ok { + return + } + var s []int64 + for i := range r.rr.n { + s = append(s, i) + } + rand.Shuffle(len(s), func(i, j int) { s[i], s[j] = s[j], s[i] }) + for _, i := range s { + r.next <- i + } + } + }() } func (r *random[E]) Next() E { - r.Lock() - defer r.Unlock() + if r.rr == nil { + panic("load balancer is closed") + } + if r.next == nil { + r.init() + } + if len(r.next) <= int(r.rr.n/4) { + r.c <- struct{}{} + } + return r.rr.get(<-r.next) +} - if len(r.c) == 0 { - r.load() +func (r *random[E]) Close() { + if r.next != nil { + close(r.next) } - return <-r.c + close(r.c) + r.rr = nil } diff --git a/loadbalance/random_test.go b/loadbalance/random_test.go index 8679d65..c9fa0ef 100644 --- a/loadbalance/random_test.go +++ b/loadbalance/random_test.go @@ -27,7 +27,7 @@ func TestRandom(t *testing.T) { } } - loadbalancer, err = WeightedRandom([]Weighted[string]{{a, 2}, {b, 1}, {c, 1}}...) + loadbalancer, err = WeightedRandom([]*Weighted[string]{{a, 2}, {b, 1}, {c, 1}}...) if err != nil { t.Error(err) } else { diff --git a/loadbalance/roundrobin.go b/loadbalance/roundrobin.go index 3155d40..8bb7422 100644 --- a/loadbalance/roundrobin.go +++ b/loadbalance/roundrobin.go @@ -1,35 +1,68 @@ package loadbalance -import "sync/atomic" +import ( + "sync/atomic" -var _ LoadBalancer[struct{}] = &roundrobin[struct{}]{} + "github.com/sunshineplan/utils/cache" +) + +var _ LoadBalancer[any] = &roundrobin[any]{} type roundrobin[E any] struct { - items []E - next atomic.Int64 + m *cache.Map[[2]int64, *Weighted[E]] + n int64 + next atomic.Int64 } -func RoundRobin[E any](items ...E) (LoadBalancer[E], error) { +func newRoundRobin[E any, Items []E | []*Weighted[E]](items Items) (*roundrobin[E], error) { if len(items) == 0 { return nil, ErrEmptyLoadBalancer } - return &roundrobin[E]{items: items}, nil -} - -func WeightedRoundRobin[E any](items ...Weighted[E]) (LoadBalancer[E], error) { - var pool []E - for _, i := range items { - for n := i.Weight; n > 0; n-- { - pool = append(pool, i.Item) + var s []*Weighted[E] + switch items := any(items).(type) { + case []E: + for _, i := range items { + s = append(s, &Weighted[E]{i, 1}) } + case []*Weighted[E]: + s = items } - if len(pool) == 0 { - return nil, ErrEmptyLoadBalancer + r := new(roundrobin[E]) + r.m = new(cache.Map[[2]int64, *Weighted[E]]) + for _, i := range s { + r.m.Store([2]int64{r.n, r.n + i.Weight}, i) + r.n += i.Weight } - return &roundrobin[E]{items: pool}, nil + return r, nil +} + +func RoundRobin[E any](items ...E) (LoadBalancer[E], error) { + return newRoundRobin[E](items) +} + +func WeightedRoundRobin[E any](items ...*Weighted[E]) (LoadBalancer[E], error) { + return newRoundRobin[E](items) +} + +func (r *roundrobin[E]) get(n int64) (e E) { + if r.m == nil { + panic("load balancer is closed") + } + r.m.Range(func(i [2]int64, w *Weighted[E]) bool { + if n >= i[0] && n < i[1] { + e = w.Item + return false + } + return true + }) + return +} + +func (r *roundrobin[E]) Next() (next E) { + return r.get(r.next.Swap((r.next.Load() + 1) % r.n)) + } -func (r *roundrobin[E]) Next() E { - n := r.next.Add(1) - return r.items[(int(n)-1)%len(r.items)] +func (r *roundrobin[E]) Close() { + r.m = nil } diff --git a/loadbalance/roundrobin_test.go b/loadbalance/roundrobin_test.go index a3be176..5e74307 100644 --- a/loadbalance/roundrobin_test.go +++ b/loadbalance/roundrobin_test.go @@ -25,15 +25,15 @@ func TestRoundRobin(t *testing.T) { } } - loadbalancer, err = WeightedRoundRobin([]Weighted[string]{{a, 2}, {b, 1}, {c, 1}}...) + loadbalancer, err = WeightedRoundRobin([]*Weighted[string]{{a, 2}, {b, 1}, {c, 1}}...) if err != nil { t.Error(err) } else { var res []string - for range 8 { + for range 12 { res = append(res, loadbalancer.Next()) } - if expect := []string{a, a, b, c, a, a, b, c}; !reflect.DeepEqual(res, expect) { + if expect := []string{a, a, b, c, a, a, b, c, a, a, b, c}; !reflect.DeepEqual(res, expect) { t.Errorf("want %v, got %v", expect, res) } }