diff --git a/pkg/client/clientpool.go b/pkg/client/clientpool.go index ddffe6668..31c54f79c 100644 --- a/pkg/client/clientpool.go +++ b/pkg/client/clientpool.go @@ -38,8 +38,7 @@ type Pool interface { } type PoolImpl struct { - mu sync.Mutex - pool map[uint]Client + pool sync.Map } var _ Pool = &PoolImpl{} @@ -58,11 +57,7 @@ var _ Pool = &PoolImpl{} // Returns: // - error: An error if any occurred during the process. func (c *PoolImpl) Put(client Client) error { - c.mu.Lock() - defer c.mu.Unlock() - - c.pool[client.ID()] = client - + c.pool.Store(client.ID(), client) return nil } @@ -82,14 +77,10 @@ func (c *PoolImpl) Put(client Client) error { // - ok: A boolean indicating whether the client was successfully removed from the pool. // - error: An error if any occurred during the process, including context cancellation or timeout. func (c *PoolImpl) Pop(id uint) (bool, error) { - c.mu.Lock() - defer c.mu.Unlock() - var err error - cl, ok := c.pool[id] + cl, ok := c.pool.LoadAndDelete(id) if ok { - err = cl.Close() - delete(c.pool, id) + err = cl.(Client).Close() } return ok, err @@ -107,16 +98,17 @@ func (c *PoolImpl) Pop(id uint) (bool, error) { // Returns: // - error: An error if any occurred during the shutdown process. func (c *PoolImpl) Shutdown() error { - c.mu.Lock() - defer c.mu.Unlock() - for _, cl := range c.pool { + c.pool.Range(func(key, value any) bool { + cl := value.(Client) go func(cl Client) { if err := cl.Shutdown(); err != nil { spqrlog.Zero.Error().Err(err).Msg("") } }(cl) - } + + return true + }) return nil } @@ -136,14 +128,16 @@ func (c *PoolImpl) Shutdown() error { // - error: An error if any occurred during the iteration. func (c *PoolImpl) ClientPoolForeach(cb func(client ClientInfo) error) error { - c.mu.Lock() - defer c.mu.Unlock() + c.pool.Range(func(key, value any) bool { + cl := value.(Client) - for _, cl := range c.pool { if err := cb(ClientInfoImpl{Client: cl, rAddr: "local"}); err != nil { spqrlog.Zero.Error().Err(err).Msg("") + return false } - } + + return true + }) return nil } @@ -160,7 +154,6 @@ func (c *PoolImpl) ClientPoolForeach(cb func(client ClientInfo) error) error { // - Pool: A pointer to the newly created PoolImpl instance. func NewClientPool() Pool { return &PoolImpl{ - pool: map[uint]Client{}, - mu: sync.Mutex{}, + pool: sync.Map{}, } }