Skip to content

Commit

Permalink
[NOID] changes after cherry-picks
Browse files Browse the repository at this point in the history
  • Loading branch information
vga91 committed Jan 15, 2025
1 parent 7e40c2f commit 3e2c6f3
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 110 deletions.
38 changes: 15 additions & 23 deletions full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,11 @@ public static void setUp() throws Exception {
COLL_ID.set((String) value.get("id"));
});

testCall(db, """
CALL apoc.vectordb.chroma.upsert($host, $collection,
[
{id: '1', vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: "Berlin", foo: "one"}, text: 'ajeje'},
{id: '2', vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: "London", foo: "two"}, text: 'brazorf'}
])
""",
testCall(db, "CALL apoc.vectordb.chroma.upsert($host, $collection,\n" +
"[\n" +
" {id: '1', vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: \"Berlin\", foo: \"one\"}, text: 'ajeje'},\n" +
" {id: '2', vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: \"London\", foo: \"two\"}, text: 'brazorf'}\n" +
"])",
map("host", HOST, "collection", COLL_ID.get()),
r -> {
assertNull(r.get("value"));
Expand Down Expand Up @@ -163,12 +161,10 @@ public void getVectorsWithoutVectorResult() {

@Test
public void deleteVector() {
testCall(db, """
CALL apoc.vectordb.chroma.upsert($host, $collection,
[
{id: 3, embedding: [0.19, 0.81, 0.75, 0.11], metadata: {foo: "baz"}}
])
""",
testCall(db, "CALL apoc.vectordb.chroma.upsert($host, $collection,\n" +
"[\n" +
" {id: 3, embedding: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}}\n" +
"])",
map("host", HOST, "collection", COLL_ID.get()),
r -> {
assertNull(r.get("value"));
Expand Down Expand Up @@ -241,8 +237,7 @@ public void queryVectorsWithYield() {

@Test
public void queryVectorsWithFilter() {
testResult(db, """
CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {city: 'London'}, 5, $conf) YIELD metadata, id""",
testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {city: 'London'}, 5, $conf) YIELD metadata, id",
map("host", HOST, "collection", COLL_ID.get(), "conf", map(ALL_RESULTS_KEY, true)),
r -> {
assertLondonResult(r.next(), FALSE);
Expand All @@ -251,8 +246,7 @@ public void queryVectorsWithFilter() {

@Test
public void queryVectorsWithLimit() {
testResult(db, """
CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 1, $conf) YIELD metadata, id""",
testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 1, $conf) YIELD metadata, id",
map("host", HOST, "collection", COLL_ID.get(), "conf", map(ALL_RESULTS_KEY, true)),
r -> {
assertBerlinResult(r.next(), FALSE);
Expand Down Expand Up @@ -471,12 +465,10 @@ MAPPING_KEY, map(NODE_LABEL, "Rag",
);

testResult(db,
"""
CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, ['1', '2'], $conf) YIELD node, metadata, id, vector
WITH collect(node) as paths
CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value
RETURN value
"""
"CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, ['1', '2'], $conf) YIELD node, metadata, id, vector\n" +
"WITH collect(node) as paths\n" +
"CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" +
"RETURN value"
,
map(
"host", HOST,
Expand Down
44 changes: 18 additions & 26 deletions full-it/src/test/java/apoc/full/it/vectordb/MilvusTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,11 @@ public static void setUp() throws Exception {
assertEquals(200L, value.get("code"));
});

testCall(db, """
CALL apoc.vectordb.milvus.upsert($host, 'test_collection',
[
{id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: "Berlin", foo: "one"}},
{id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: "London", foo: "two"}}
])
""",
testCall(db, "CALL apoc.vectordb.milvus.upsert($host, 'test_collection',\n" +
"[\n" +
" {id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: \"Berlin\", foo: \"one\"}},\n" +
" {id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: \"London\", foo: \"two\"}}\n" +
"])",
map("host", HOST),
r -> {
Map value = (Map) r.get("value");
Expand Down Expand Up @@ -161,13 +159,11 @@ public void getVectorsWithoutVectorResult() {

@Test
public void deleteVector() {
testCall(db, """
CALL apoc.vectordb.milvus.upsert($host, 'test_collection',
[
{id: 3, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: "baz"}},
{id: 4, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: "baz"}}
])
""",
testCall(db, "CALL apoc.vectordb.milvus.upsert($host, 'test_collection',\n" +
"[\n" +
" {id: 3, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}},\n" +
" {id: 4, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}}\n" +
"])",
map("host", HOST),
r -> {
Map value = (Map) r.get("value");
Expand Down Expand Up @@ -234,10 +230,9 @@ public void queryVectorsWithYield() {

@Test
public void queryVectorsWithFilter() {
testResult(db, """
CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7],
'city == "London"',
5, $conf) YIELD metadata, id""",
testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7],\n" +
"'city == \"London\"',\n" +
"5, $conf) YIELD metadata, id",
map("host", HOST,
"conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true)
),
Expand All @@ -248,8 +243,7 @@ public void queryVectorsWithFilter() {

@Test
public void queryVectorsWithLimit() {
testResult(db, """
CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 1, $conf) YIELD metadata, id""",
testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 1, $conf) YIELD metadata, id",
map("host", HOST,
"conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true)
),
Expand Down Expand Up @@ -480,12 +474,10 @@ MAPPING_KEY, map(NODE_LABEL, "Rag",
);

testResult(db,
"""
CALL apoc.vectordb.milvus.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector
WITH collect(node) as paths
CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value
RETURN value
"""
"CALL apoc.vectordb.milvus.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector\n" +
"WITH collect(node) as paths\n" +
"CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" +
"RETURN value"
,
map(
"host", HOST,
Expand Down
52 changes: 22 additions & 30 deletions full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,12 @@ public static void setUp() throws Exception {
assertEquals("ok", value.get("status"));
});

testCall(db, """
CALL apoc.vectordb.qdrant.upsert($host, 'test_collection',
[
{id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: "Berlin", foo: "one"}},
{id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: "London", foo: "two"}}
],
$conf)
""",
testCall(db, "CALL apoc.vectordb.qdrant.upsert($host, 'test_collection',\n" +
"[\n" +
" {id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: \"Berlin\", foo: \"one\"}},\n" +
" {id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: \"London\", foo: \"two\"}}\n" +
"],\n" +
"$conf)",
map("host", HOST, "conf", ADMIN_HEADER_CONF),
r -> {
Map value = (Map) r.get("value");
Expand Down Expand Up @@ -196,14 +194,12 @@ public void getVectorsWithoutVectorResult() {

@Test
public void deleteVector() {
testCall(db, """
CALL apoc.vectordb.qdrant.upsert($host, 'test_collection',
[
{id: 3, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: "baz"}},
{id: 4, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: "baz"}}
],
$conf)
""",
testCall(db, "CALL apoc.vectordb.qdrant.upsert($host, 'test_collection',\n" +
"[\n" +
" {id: 3, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}},\n" +
" {id: 4, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}}\n" +
"],\n" +
"$conf)",
map("host", HOST, "conf", ADMIN_HEADER_CONF),
r -> {
Map value = (Map) r.get("value");
Expand Down Expand Up @@ -233,12 +229,10 @@ MAPPING_KEY, map(NODE_LABEL, "Rag",
);

testResult(db,
"""
CALL apoc.vectordb.qdrant.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector
WITH collect(node) as paths
CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value
RETURN value
"""
"CALL apoc.vectordb.qdrant.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector\n" +
"WITH collect(node) as paths\n" +
"CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" +
"RETURN value"
,
map(
"host", HOST,
Expand Down Expand Up @@ -306,12 +300,11 @@ public void queryVectorsWithYield() {

@Test
public void queryVectorsWithFilter() {
testResultEventually(db, """
CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7],
{ must:
[ { key: "city", match: { value: "London" } } ]
},
5, $conf) YIELD metadata, id""",
testResultEventually(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7],\n" +
"{ must:\n" +
" [ { key: \"city\", match: { value: \"London\" } } ]\n" +
"},\n" +
"5, $conf) YIELD metadata, id",
map("host", HOST,
"conf", map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION)
),
Expand All @@ -323,8 +316,7 @@ public void queryVectorsWithFilter() {

@Test
public void queryVectorsWithLimit() {
testResultEventually(db, """
CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 1, $conf) YIELD metadata, id""",
testResultEventually(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 1, $conf) YIELD metadata, id",
map("host", HOST,
"conf", map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION)
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import apoc.util.MapUtil;
import apoc.util.Neo4jContainerExtension;
import apoc.util.TestContainerUtil;
import apoc.util.WeaviateTestUtil;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.ClassRule;
Expand All @@ -15,26 +14,13 @@
import java.util.List;
import java.util.Map;

import static apoc.full.it.vectordb.WeaviateTestUtil.*;
import static apoc.ml.RestAPIConfig.HEADERS_KEY;
import static apoc.util.TestContainerUtil.createEnterpriseDB;
import static apoc.util.TestContainerUtil.testCall;
import static apoc.util.TestContainerUtil.testCallEmpty;
import static apoc.util.TestContainerUtil.testResult;
import static apoc.util.Util.map;
import static apoc.util.WeaviateTestUtil.ADMIN_AUTHORIZATION;
import static apoc.util.WeaviateTestUtil.ADMIN_HEADER_CONF;
import static apoc.util.WeaviateTestUtil.COLLECTION_NAME;
import static apoc.util.WeaviateTestUtil.FIELDS;
import static apoc.util.WeaviateTestUtil.HOST;
import static apoc.util.WeaviateTestUtil.ID_1;
import static apoc.util.WeaviateTestUtil.ID_2;
import static apoc.util.WeaviateTestUtil.WEAVIATE_CONTAINER;
import static apoc.util.WeaviateTestUtil.WEAVIATE_CREATE_COLLECTION_APOC;
import static apoc.util.WeaviateTestUtil.WEAVIATE_DELETE_COLLECTION_APOC;
import static apoc.util.WeaviateTestUtil.WEAVIATE_DELETE_VECTOR_APOC;
import static apoc.util.WeaviateTestUtil.WEAVIATE_PORT;
import static apoc.util.WeaviateTestUtil.WEAVIATE_QUERY_APOC;
import static apoc.util.WeaviateTestUtil.WEAVIATE_UPSERT_QUERY;
import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY;
import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY;
import static org.junit.Assert.assertEquals;
Expand All @@ -53,7 +39,7 @@ public static void setUp() throws Exception {
Network network = Network.newNetwork();

// We build the project, the artifact will be placed into ./build/libs
neo4jContainer = createEnterpriseDB(List.of(TestContainerUtil.ApocPackage.EXTENDED), true)
neo4jContainer = createEnterpriseDB(List.of(TestContainerUtil.ApocPackage.FULL), true)
.withNetwork(network)
.withNetworkAliases("neo4j");
neo4jContainer.start();
Expand Down
5 changes: 3 additions & 2 deletions full/src/main/java/apoc/vectordb/ChromaHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static apoc.util.MapUtil.map;

Expand Down Expand Up @@ -34,7 +35,7 @@ public <T> VectorEmbeddingConfig fromGet(Map<String, Object> config,
List<T> ids,
String collection) {

List<String> fields = procedureCallContext.outputFields().toList();
List<String> fields = procedureCallContext.outputFields().collect(Collectors.toList());

VectorEmbeddingConfig conf = new VectorEmbeddingConfig(config);
Map<String, Object> additionalBodies = map("ids", ids);
Expand All @@ -50,7 +51,7 @@ public VectorEmbeddingConfig fromQuery(Map<String, Object> config,
long limit,
String collection) {

List<String> fields = procedureCallContext.outputFields().toList();
List<String> fields = procedureCallContext.outputFields().collect(Collectors.toList());

VectorEmbeddingConfig conf = new VectorEmbeddingConfig(config);
Map<String, Object> additionalBodies = map("query_embeddings", List.of(vector),
Expand Down
1 change: 0 additions & 1 deletion full/src/main/java/apoc/vectordb/Qdrant.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import apoc.result.MapResult;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Transaction;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
Expand Down
5 changes: 3 additions & 2 deletions full/src/main/java/apoc/vectordb/VectorDb.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,9 @@ private static <T extends Entity> void setVectorProp(VectorMappingConfig mapping
}

if (embedding == null) {
String embeddingErrMsg = "The embedding value is null. Make sure you execute `YIELD embedding` on the procedure and you configured `%s: true`"
.formatted(ALL_RESULTS_KEY);
String embeddingErrMsg = String.format(
"The embedding value is null. Make sure you execute `YIELD embedding` on the procedure and you configured `%s: true`",
ALL_RESULTS_KEY);
throw new RuntimeException(embeddingErrMsg);
}

Expand Down
4 changes: 1 addition & 3 deletions full/src/main/java/apoc/vectordb/VectorDbUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@


import apoc.SystemPropertyKeys;
import apoc.util.CollectionUtils;
import apoc.util.ExtendedMapUtils;
import apoc.util.Util;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.neo4j.graphdb.Label;
Expand All @@ -35,7 +34,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static apoc.ml.RestAPIConfig.BASE_URL_KEY;
import static apoc.ml.RestAPIConfig.BODY_KEY;
Expand Down
Loading

0 comments on commit 3e2c6f3

Please sign in to comment.