Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Get shard client failed by client is closed #37729

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions internal/proxy/shard_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,11 @@
}
}

func (n *shardClient) DecRef() bool {
if n.refCnt.Dec() == 0 {
n.Close()
return true
func (n *shardClient) DecRef() {
ret := n.refCnt.Dec()
if ret <= 0 {
log.Warn("unexpected client ref count zero, please check the release call", zap.Int64("refCount", ret), zap.Stack("caller"))
}
return false
}

func (n *shardClient) IncRef() {
Expand All @@ -94,6 +93,9 @@
n.clients = nil
}

// Notice: close client should only be called by shard client manager. and after close, the client must be removed from the manager.
// 1. the client hasn't been used for a long time
// 2. shard client manager has been closed.
func (n *shardClient) Close() {
n.Lock()
defer n.Unlock()
Expand Down Expand Up @@ -137,8 +139,8 @@

// roundRobinSelectClient selects a client in a round-robin manner
func (n *shardClient) roundRobinSelectClient() (types.QueryNodeClient, error) {
n.Lock()
defer n.Unlock()
n.RLock()
defer n.RUnlock()
if n.isClosed {
return nil, errClosed
}
Expand All @@ -159,6 +161,11 @@
SetClientCreatorFunc(creator queryNodeCreatorFunc)
}

// Defaults for the shard client manager's purge loop: defaultPurgeInterval is
// how often PurgeClient rescans the cached clients, and defaultPurgeExpiredAge
// is how many consecutive scans a client must be absent from the shard
// locations (and unreferenced) before it is closed and removed.
const (
defaultPurgeInterval = 600 * time.Second
defaultPurgeExpiredAge = 3
)

type shardClientMgrImpl struct {
clients struct {
sync.RWMutex
Expand All @@ -167,6 +174,9 @@
clientCreator queryNodeCreatorFunc

closeCh chan struct{}

purgeInterval time.Duration
purgeExpiredAge int
}

// SessionOpt provides a way to set params in SessionManager
Expand All @@ -187,8 +197,10 @@
sync.RWMutex
data map[UniqueID]*shardClient
}{data: make(map[UniqueID]*shardClient)},
clientCreator: defaultQueryNodeClientCreator,
closeCh: make(chan struct{}),
clientCreator: defaultQueryNodeClientCreator,
closeCh: make(chan struct{}),
purgeInterval: defaultPurgeInterval,
purgeExpiredAge: defaultPurgeExpiredAge,
}
for _, opt := range options {
opt(s)
Expand Down Expand Up @@ -227,9 +239,14 @@
return client.getClient(ctx)
}

// PurgeClient purges client if it is not used for a long time
func (c *shardClientMgrImpl) PurgeClient() {
ticker := time.NewTicker(600 * time.Second)
ticker := time.NewTicker(c.purgeInterval)
defer ticker.Stop()

// record node's age, if node reach 3 consecutive failures, try to purge it
nodeAges := make(map[int64]int, 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is leaked.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to count time?

we can record last active time as a mark


for {
select {
case <-c.closeCh:
Expand All @@ -239,8 +256,16 @@
c.clients.Lock()
for nodeID, client := range c.clients.data {
if _, ok := shardLocations[nodeID]; !ok {
client.DecRef()
delete(c.clients.data, nodeID)
nodeAges[nodeID] += 1
if nodeAges[nodeID] > c.purgeExpiredAge {
if client.refCnt.Load() <= 1 {
client.Close()
delete(c.clients.data, nodeID)
log.Info("remove client due to not used for long time", zap.Int64("nodeID", nodeID))
}
}
} else {
nodeAges[nodeID] = 0

Check warning on line 268 in internal/proxy/shard_client.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/shard_client.go#L267-L268

Added lines #L267 - L268 were not covered by tests
}
}
c.clients.Unlock()
Expand Down
54 changes: 51 additions & 3 deletions internal/proxy/shard_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ package proxy

import (
"context"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
Expand Down Expand Up @@ -40,7 +44,7 @@ func TestShardClient(t *testing.T) {
}

qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Close().Return(nil)
qn.EXPECT().Close().Return(nil).Maybe()
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return qn, nil
}
Expand All @@ -60,7 +64,51 @@ func TestShardClient(t *testing.T) {
shardClient.DecRef()
assert.Equal(t, int64(1), shardClient.refCnt.Load())

// only shard client manager can close shard client
shardClient.DecRef()
shardClient.DecRef()
shardClient.DecRef()
assert.Equal(t, int64(0), shardClient.refCnt.Load())
assert.Equal(t, true, shardClient.isClosed)
shardClient.DecRef()
assert.Equal(t, false, shardClient.isClosed)
}

// TestPurgeClient verifies the purge loop of shardClientMgrImpl: a cached
// query-node client is kept while a caller still holds a reference, and is
// removed only after the reference is released and the client stays idle
// across enough purge ticks (purgeExpiredAge with purgeInterval of 1s here,
// so the 5s sleeps span several ticks).
func TestPurgeClient(t *testing.T) {
nodeInfo := nodeInfo{
nodeID: 1,
}

// Close is .Maybe() because whether purge actually closes the mock
// depends on timing within the loop.
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Close().Return(nil).Maybe()
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return qn, nil
}

// Build the manager directly (not via the constructor) so the test can
// shrink purgeInterval from the 600s default to 1s.
s := &shardClientMgrImpl{
clients: struct {
sync.RWMutex
data map[UniqueID]*shardClient
}{data: make(map[UniqueID]*shardClient)},
clientCreator: creator,
closeCh: make(chan struct{}),
purgeInterval: 1 * time.Second,
purgeExpiredAge: 3,
}
// PurgeClient consults the meta cache for current shard locations, so the
// cache must be initialized; ListPolicy is stubbed to let init succeed.
mockQC := mocks.NewMockQueryCoordClient(t)
mockRC := mocks.NewMockRootCoordClient(t)
mockRC.EXPECT().ListPolicy(mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{}, nil)
InitMetaCache(context.TODO(), mockRC, mockQC, s)

go s.PurgeClient()
defer s.Close()
// test client has been used
_, err := s.GetClient(context.Background(), nodeInfo)
assert.Nil(t, err)
// NOTE(review): real sleeps make this test slow and timing-sensitive;
// consider assert.Eventually instead — TODO confirm with maintainers.
time.Sleep(5 * time.Second)
// expected client has not been purged
assert.Equal(t, len(s.clients.data), 1)

s.ReleaseClientRef(1)
time.Sleep(5 * time.Second)
// expected client has been purged
assert.Equal(t, len(s.clients.data), 0)
}
Loading