Skip to content

Commit

Permalink
enhance: support retry search when topk is reduced and result not eno…
Browse files Browse the repository at this point in the history
…ugh (#35645)

issue: #35576 

This pr is to cover those cases when queryHook optimize search params
and make the result size insufficient, add retry search mechanism and
add related metrics for alarming.

---------

Signed-off-by: chasingegg <[email protected]>
  • Loading branch information
chasingegg authored Oct 23, 2024
1 parent 04343d1 commit 1d61b60
Show file tree
Hide file tree
Showing 12 changed files with 218 additions and 48 deletions.
2 changes: 2 additions & 0 deletions internal/proto/internal.proto
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ message SearchRequest {
int64 group_by_field_id = 23;
int64 group_size = 24;
int64 field_id = 25;
bool is_topk_reduce = 26;
}

message SubSearchResults {
Expand Down Expand Up @@ -161,6 +162,7 @@ message SearchResults {
repeated SubSearchResults sub_results = 15;
bool is_advanced = 16;
int64 all_search_count = 17;
bool is_topk_reduce = 18;
}

message CostAggregation {
Expand Down
96 changes: 72 additions & 24 deletions internal/proxy/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -2899,9 +2899,30 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
rsp := &milvuspb.SearchResults{
Status: merr.Success(),
}

optimizedSearch := true
resultSizeInsufficient := false
isTopkReduce := false
err2 := retry.Handle(ctx, func() (bool, error) {
rsp, err = node.
search(ctx, request)
rsp, resultSizeInsufficient, isTopkReduce, err = node.search(ctx, request, optimizedSearch)
if merr.Ok(rsp.GetStatus()) && optimizedSearch && resultSizeInsufficient && isTopkReduce && paramtable.Get().AutoIndexConfig.EnableResultLimitCheck.GetAsBool() {
// without optimize search
optimizedSearch = false
rsp, resultSizeInsufficient, isTopkReduce, err = node.search(ctx, request, optimizedSearch)
metrics.ProxyRetrySearchCount.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel,
request.GetCollectionName(),
).Inc()
// result size still insufficient
if resultSizeInsufficient {
metrics.ProxyRetrySearchResultInsufficientCount.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel,
request.GetCollectionName(),
).Inc()
}
}
if errors.Is(merr.Error(rsp.GetStatus()), merr.ErrInconsistentRequery) {
return true, merr.Error(rsp.GetStatus())
}
Expand All @@ -2913,11 +2934,13 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
return rsp, err
}

func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
metrics.GetStats(ctx).
SetNodeID(paramtable.GetNodeID()).
SetInboundLabel(metrics.SearchLabel).
SetCollectionName(request.GetCollectionName())
func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, optimizedSearch bool) (*milvuspb.SearchResults, bool, bool, error) {
receiveSize := proto.Size(request)
metrics.ProxyReceiveBytes.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel,
request.GetCollectionName(),
).Add(float64(receiveSize))

metrics.ProxyReceivedNQ.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
Expand All @@ -2928,7 +2951,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return &milvuspb.SearchResults{
Status: merr.Status(err),
}, nil
}, false, false, nil
}

method := "Search"
Expand All @@ -2949,7 +2972,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)
if err != nil {
return &milvuspb.SearchResults{
Status: merr.Status(err),
}, nil
}, false, false, nil
}

request.PlaceholderGroup = placeholderGroupBytes
Expand All @@ -2963,7 +2986,8 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)
commonpbutil.WithMsgType(commonpb.MsgType_Search),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
ReqID: paramtable.GetNodeID(),
IsTopkReduce: optimizedSearch,
},
request: request,
tr: timerecord.NewTimeRecorder("search"),
Expand Down Expand Up @@ -3017,7 +3041,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)

return &milvuspb.SearchResults{
Status: merr.Status(err),
}, nil
}, false, false, nil
}
tr.CtxRecord(ctx, "search request enqueue")

Expand All @@ -3043,7 +3067,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)

return &milvuspb.SearchResults{
Status: merr.Status(err),
}, nil
}, false, false, nil
}

span := tr.CtxRecord(ctx, "wait search result")
Expand Down Expand Up @@ -3100,16 +3124,37 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest)
metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeSearch, dbName, username).Add(float64(v))
}
}
return qt.result, nil
return qt.result, qt.resultSizeInsufficient, qt.isTopkReduce, nil
}

func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
var err error
rsp := &milvuspb.SearchResults{
Status: merr.Success(),
}
optimizedSearch := true
resultSizeInsufficient := false
isTopkReduce := false
err2 := retry.Handle(ctx, func() (bool, error) {
rsp, err = node.hybridSearch(ctx, request)
rsp, resultSizeInsufficient, isTopkReduce, err = node.hybridSearch(ctx, request, optimizedSearch)
if merr.Ok(rsp.GetStatus()) && optimizedSearch && resultSizeInsufficient && isTopkReduce && paramtable.Get().AutoIndexConfig.EnableResultLimitCheck.GetAsBool() {
// without optimize search
optimizedSearch = false
rsp, resultSizeInsufficient, isTopkReduce, err = node.hybridSearch(ctx, request, optimizedSearch)
metrics.ProxyRetrySearchCount.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.HybridSearchLabel,
request.GetCollectionName(),
).Inc()
// result size still insufficient
if resultSizeInsufficient {
metrics.ProxyRetrySearchResultInsufficientCount.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.HybridSearchLabel,
request.GetCollectionName(),
).Inc()
}
}
if errors.Is(merr.Error(rsp.GetStatus()), merr.ErrInconsistentRequery) {
return true, merr.Error(rsp.GetStatus())
}
Expand All @@ -3121,16 +3166,18 @@ func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSea
return rsp, err
}

func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
metrics.GetStats(ctx).
SetNodeID(paramtable.GetNodeID()).
SetInboundLabel(metrics.HybridSearchLabel).
SetCollectionName(request.GetCollectionName())
func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest, optimizedSearch bool) (*milvuspb.SearchResults, bool, bool, error) {
receiveSize := proto.Size(request)
metrics.ProxyReceiveBytes.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.HybridSearchLabel,
request.GetCollectionName(),
).Add(float64(receiveSize))

if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return &milvuspb.SearchResults{
Status: merr.Status(err),
}, nil
}, false, false, nil
}

method := "HybridSearch"
Expand All @@ -3154,7 +3201,8 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
commonpbutil.WithMsgType(commonpb.MsgType_Search),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
ReqID: paramtable.GetNodeID(),
IsTopkReduce: optimizedSearch,
},
request: newSearchReq,
tr: timerecord.NewTimeRecorder(method),
Expand Down Expand Up @@ -3203,7 +3251,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea

return &milvuspb.SearchResults{
Status: merr.Status(err),
}, nil
}, false, false, nil
}
tr.CtxRecord(ctx, "hybrid search request enqueue")

Expand All @@ -3228,7 +3276,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea

return &milvuspb.SearchResults{
Status: merr.Status(err),
}, nil
}, false, false, nil
}

span := tr.CtxRecord(ctx, "wait hybrid search result")
Expand Down Expand Up @@ -3285,7 +3333,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeHybridSearch, dbName, username).Add(float64(v))
}
}
return qt.result, nil
return qt.result, qt.resultSizeInsufficient, qt.isTopkReduce, nil
}

func (node *Proxy) getVectorPlaceholderGroupForSearchByPks(ctx context.Context, request *milvuspb.SearchRequest) ([]byte, error) {
Expand Down
18 changes: 18 additions & 0 deletions internal/proxy/task_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ type searchTask struct {
partitionKeyMode bool
enableMaterializedView bool
mustUsePartitionKey bool
resultSizeInsufficient bool
isTopkReduce bool

userOutputFields []string
userDynamicFields []string
Expand Down Expand Up @@ -644,7 +646,11 @@ func (t *searchTask) PostExecute(ctx context.Context) error {

t.queryChannelsTs = make(map[string]uint64)
t.relatedDataSize = 0
isTopkReduce := false
for _, r := range toReduceResults {
if r.GetIsTopkReduce() {
isTopkReduce = true
}
t.relatedDataSize += r.GetCostAggregation().GetTotalRelatedDataSize()
for ch, ts := range r.GetChannelsMvcc() {
t.queryChannelsTs[ch] = ts
Expand All @@ -657,6 +663,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
return err
}

// reduce
if t.SearchRequest.GetIsAdvanced() {
multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs()))
for _, searchResult := range toReduceResults {
Expand Down Expand Up @@ -713,6 +720,17 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
}
}

// reduce done, get final result
limit := t.SearchRequest.GetTopk() - t.SearchRequest.GetOffset()
resultSizeInsufficient := false
for _, topk := range t.result.Results.Topks {
if topk < limit {
resultSizeInsufficient = true
break
}
}
t.resultSizeInsufficient = resultSizeInsufficient
t.isTopkReduce = isTopkReduce
t.result.CollectionName = t.collectionName
t.fillInFieldInfo()

Expand Down
6 changes: 5 additions & 1 deletion internal/proxy/task_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ func TestSearchTask_PostExecute(t *testing.T) {
task := &searchTask{
ctx: ctx,
collectionName: collName,
SearchRequest: &internalpb.SearchRequest{},
SearchRequest: &internalpb.SearchRequest{
IsTopkReduce: true,
},
request: &milvuspb.SearchRequest{
CollectionName: collName,
Nq: 1,
Expand All @@ -98,6 +100,8 @@ func TestSearchTask_PostExecute(t *testing.T) {
err := qt.PostExecute(context.TODO())
assert.NoError(t, err)
assert.Equal(t, qt.result.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
assert.Equal(t, qt.resultSizeInsufficient, true)
assert.Equal(t, qt.isTopkReduce, false)
})
}

Expand Down
1 change: 1 addition & 0 deletions internal/querynodev2/delegator/delegator.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
GroupByFieldId: subReq.GetGroupByFieldId(),
GroupSize: subReq.GetGroupSize(),
FieldId: subReq.GetFieldId(),
IsTopkReduce: req.GetReq().GetIsTopkReduce(),
}
future := conc.Go(func() (*internalpb.SearchResults, error) {
searchReq := &querypb.SearchRequest{
Expand Down
8 changes: 6 additions & 2 deletions internal/querynodev2/optimizers/query_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type QueryHook interface {
func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, queryHook QueryHook, numSegments int) (*querypb.SearchRequest, error) {
// no hook applied or disabled, just return
if queryHook == nil || !paramtable.Get().AutoIndexConfig.Enable.GetAsBool() {
req.Req.IsTopkReduce = false
return req, nil
}

Expand Down Expand Up @@ -67,7 +68,7 @@ func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, query
common.SegmentNumKey: estSegmentNum,
common.WithFilterKey: withFilter,
common.DataTypeKey: int32(plan.GetVectorAnns().GetVectorType()),
common.WithOptimizeKey: paramtable.Get().AutoIndexConfig.EnableOptimize.GetAsBool(),
common.WithOptimizeKey: paramtable.Get().AutoIndexConfig.EnableOptimize.GetAsBool() && req.GetReq().GetIsTopkReduce(),
common.CollectionKey: req.GetReq().GetCollectionID(),
}
if withFilter && channelNum > 1 {
Expand All @@ -78,14 +79,17 @@ func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, query
log.Warn("failed to execute queryHook", zap.Error(err))
return nil, merr.WrapErrServiceUnavailable(err.Error(), "queryHook execution failed")
}
queryInfo.Topk = params[common.TopKKey].(int64)
finalTopk := params[common.TopKKey].(int64)
isTopkReduce := req.GetReq().GetIsTopkReduce() && (finalTopk < queryInfo.GetTopk())
queryInfo.Topk = finalTopk
queryInfo.SearchParams = params[common.SearchParamKey].(string)
serializedExprPlan, err := proto.Marshal(&plan)
if err != nil {
log.Warn("failed to marshal optimized plan", zap.Error(err))
return nil, merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error())
}
req.Req.SerializedExprPlan = serializedExprPlan
req.Req.IsTopkReduce = isTopkReduce
log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo))
default:
log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
Expand Down
Loading

0 comments on commit 1d61b60

Please sign in to comment.