Skip to content

Commit

Permalink
Supports filling elements through templates for expression
Browse files Browse the repository at this point in the history
Signed-off-by: cai.zhang <[email protected]>
  • Loading branch information
xiaocai2333 committed Oct 27, 2024
1 parent b3eec5d commit 4e9080b
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 5 deletions.
5 changes: 5 additions & 0 deletions src/main/java/io/milvus/v2/service/vector/VectorService.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -235,6 +236,10 @@ public DeleteResp delete(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStu
builder.putExprTemplateValues(key, vectorUtils.deduceTemplateValue(value));
});
}
Map<String, TemplateValue> templateValues = new HashMap<>();
convertUtils.processFilterTemplateValues(builder.getExprTemplateValuesMap(), request.getFilterTemplateValues());
builder.putAllExprTemplateValues(templateValues);

MutationResult response = blockingStub.delete(builder.build());
rpcUtils.handleResponse(title, response.getStatus());
GTsDict.getInstance().updateCollectionTs(request.getCollectionName(), response.getTimestamp());
Expand Down
47 changes: 47 additions & 0 deletions src/main/java/io/milvus/v2/utils/ConvertUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,51 @@ public DescribeCollectionResp convertDescCollectionResp(DescribeCollectionRespon
.build();
return describeCollectionResp;
}

public TemplateValue convertExpressionValue(Object value) {
TemplateValue.Builder targetBuilder = TemplateValue.newBuilder();

if (value instanceof List<?>) {
TemplateArrayValue.Builder arrayBuilder = TemplateArrayValue.newBuilder();
boolean sameType = true;
DataType elementType = DataType.None;
int i = 0;

for (Object element : (List<?>) value) {
TemplateValue targetValue = convertExpressionValue(element);

if (elementType.equals(DataType.None)) {
elementType = targetValue.getType();
} else if (!elementType.equals(targetValue.getType())) {
sameType = false;
elementType = DataType.JSON;
}
arrayBuilder.setArray(i, targetValue);
}

arrayBuilder.setSameType(sameType);
arrayBuilder.setElementType(elementType);
targetBuilder.setType(DataType.Array).setArrayVal(arrayBuilder.build());
} else if (value instanceof Integer) {
targetBuilder.setType(DataType.Int64).setInt64Val(((Integer) value).longValue());
} else if (value instanceof Float) {
targetBuilder.setType(DataType.Float).setFloatVal(((Float) value).doubleValue());
} else if (value instanceof Double) {
targetBuilder.setType(DataType.Float).setFloatVal((Double) value);
} else if (value instanceof String) {
targetBuilder.setType(DataType.String).setStringVal((String) value);
} else {
throw new IllegalArgumentException("Unsupported type: " + value.getClass().getName());
}

return targetBuilder.build();
}

public void processFilterTemplateValues(Map<String, TemplateValue> templateValues, Map<String, Object> filterTemplateValues) {
if (filterTemplateValues != null && !filterTemplateValues.isEmpty()) {
filterTemplateValues.forEach((key, value) -> {
templateValues.put(key, convertExpressionValue(value));
});
}
}
}
9 changes: 7 additions & 2 deletions src/main/java/io/milvus/v2/utils/VectorUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import java.util.*;

public class VectorUtils {

public ConvertUtils convertUtils = new ConvertUtils();
public QueryRequest ConvertToGrpcQueryRequest(QueryReq request){
QueryRequest.Builder builder = QueryRequest.newBuilder()
.setCollectionName(request.getCollectionName())
Expand Down Expand Up @@ -82,7 +82,9 @@ public QueryRequest ConvertToGrpcQueryRequest(QueryReq request){
// .setKey(Constant.IGNORE_GROWING)
// .setValue(String.valueOf(request.isIgnoreGrowing()))
// .build());

Map<String, TemplateValue> templateValues = new HashMap<>();
convertUtils.processFilterTemplateValues(builder.getExprTemplateValuesMap(), request.getFilterTemplateValues());
builder.putAllExprTemplateValues(templateValues);
return builder.build();

}
Expand Down Expand Up @@ -203,6 +205,9 @@ public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) {
builder.setConsistencyLevelValue(request.getConsistencyLevel().getCode());
}

Map<String, TemplateValue> templateValues = new HashMap<>();
convertUtils.processFilterTemplateValues(builder.getExprTemplateValuesMap(), request.getFilterTemplateValues());
builder.putAllExprTemplateValues(templateValues);
return builder.build();
}

Expand Down
34 changes: 31 additions & 3 deletions src/test/java/io/milvus/v2/service/vector/VectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.*;

class VectorTest extends BaseTest {

Expand Down Expand Up @@ -105,6 +103,36 @@ void testSearch() {
logger.info(statusR.toString());
}

@Test
void testSearchWithTemplateExpression() {
List<Float> vectorList = new ArrayList<>();
vectorList.add(1.0f);
vectorList.add(2.0f);

Map<String, Map<String, Object>> expressionTemplateValues = new HashMap<>();
Map<String, Object> params = new HashMap<>();
params.put("max", 10);
expressionTemplateValues.put("id < {max}", params);

List<Object> list = Arrays.asList(1, 2, 3);
Map<String, Object> params2 = new HashMap<>();
params2.put("list", list);
expressionTemplateValues.put("id in {list}", params2);

expressionTemplateValues.forEach((key, value) -> {
SearchReq request = SearchReq.builder()
.collectionName("test")
.data(Collections.singletonList(new FloatVec(vectorList)))
.topK(10)
.offset(0L)
.filter(key)
.filterTemplateValues(value)
.build();
SearchResp statusR = client_v2.search(request);
logger.info(statusR.toString());
});
}

@Test
void testDelete() {
DeleteReq request = DeleteReq.builder()
Expand Down

0 comments on commit 4e9080b

Please sign in to comment.