diff --git a/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java b/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java index 82a367c275..cc55c73f82 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/ChromaDbTest.java @@ -88,13 +88,11 @@ public static void setUp() throws Exception { COLL_ID.set((String) value.get("id")); }); - testCall(db, """ - CALL apoc.vectordb.chroma.upsert($host, $collection, - [ - {id: '1', vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: "Berlin", foo: "one"}, text: 'ajeje'}, - {id: '2', vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: "London", foo: "two"}, text: 'brazorf'} - ]) - """, + testCall(db, "CALL apoc.vectordb.chroma.upsert($host, $collection,\n" + + "[\n" + + " {id: '1', vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: \"Berlin\", foo: \"one\"}, text: 'ajeje'},\n" + + " {id: '2', vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: \"London\", foo: \"two\"}, text: 'brazorf'}\n" + + "])", map("host", HOST, "collection", COLL_ID.get()), r -> { assertNull(r.get("value")); @@ -163,12 +161,10 @@ public void getVectorsWithoutVectorResult() { @Test public void deleteVector() { - testCall(db, """ - CALL apoc.vectordb.chroma.upsert($host, $collection, - [ - {id: 3, embedding: [0.19, 0.81, 0.75, 0.11], metadata: {foo: "baz"}} - ]) - """, + testCall(db, "CALL apoc.vectordb.chroma.upsert($host, $collection,\n" + + "[\n" + + " {id: 3, embedding: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}}\n" + + "])", map("host", HOST, "collection", COLL_ID.get()), r -> { assertNull(r.get("value")); @@ -241,8 +237,7 @@ public void queryVectorsWithYield() { @Test public void queryVectorsWithFilter() { - testResult(db, """ - CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {city: 'London'}, 5, $conf) YIELD metadata, id""", + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {city: 'London'}, 5, $conf) YIELD metadata, id", map("host", HOST, "collection", COLL_ID.get(), "conf", map(ALL_RESULTS_KEY, true)), r -> { assertLondonResult(r.next(), FALSE); @@ -251,8 +246,7 @@ public void queryVectorsWithFilter() { @Test public void queryVectorsWithLimit() { - testResult(db, """ - CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 1, $conf) YIELD metadata, id""", + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 1, $conf) YIELD metadata, id", map("host", HOST, "collection", COLL_ID.get(), "conf", map(ALL_RESULTS_KEY, true)), r -> { assertBerlinResult(r.next(), FALSE); @@ -471,12 +465,10 @@ MAPPING_KEY, map(NODE_LABEL, "Rag", ); testResult(db, - """ - CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, ['1', '2'], $conf) YIELD node, metadata, id, vector - WITH collect(node) as paths - CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value - RETURN value - """ + "CALL apoc.vectordb.chroma.getAndUpdate($host, $collection, ['1', '2'], $conf) YIELD node, metadata, id, vector\n" + + "WITH collect(node) as paths\n" + + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" + + "RETURN value" , map( "host", HOST, diff --git a/full-it/src/test/java/apoc/full/it/vectordb/MilvusTest.java b/full-it/src/test/java/apoc/full/it/vectordb/MilvusTest.java index 7649d8e243..7007a4c5dc 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/MilvusTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/MilvusTest.java @@ -93,13 +93,11 @@ public static void setUp() throws Exception { assertEquals(200L, value.get("code")); }); - testCall(db, """ - CALL apoc.vectordb.milvus.upsert($host, 'test_collection', - [ - {id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: "Berlin", foo: "one"}}, - {id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: "London", foo: "two"}} - ]) - """, + testCall(db, "CALL apoc.vectordb.milvus.upsert($host, 'test_collection',\n" + + "[\n" + + " {id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: \"Berlin\", foo: \"one\"}},\n" + + " {id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: \"London\", foo: \"two\"}}\n" + + "])", map("host", HOST), r -> { Map value = (Map) r.get("value"); @@ -161,13 +159,11 @@ public void getVectorsWithoutVectorResult() { @Test public void deleteVector() { - testCall(db, """ - CALL apoc.vectordb.milvus.upsert($host, 'test_collection', - [ - {id: 3, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: "baz"}}, - {id: 4, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: "baz"}} - ]) - """, + testCall(db, "CALL apoc.vectordb.milvus.upsert($host, 'test_collection',\n" + + "[\n" + + " {id: 3, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}},\n" + + " {id: 4, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}}\n" + + "])", map("host", HOST), r -> { Map value = (Map) r.get("value"); @@ -234,10 +230,9 @@ public void queryVectorsWithYield() { @Test public void queryVectorsWithFilter() { - testResult(db, """ - CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], - 'city == "London"', - 5, $conf) YIELD metadata, id""", + testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7],\n" + + "'city == \"London\"',\n" + + "5, $conf) YIELD metadata, id", map("host", HOST, "conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true) ), @@ -248,8 +243,7 @@ public void queryVectorsWithFilter() { @Test public void queryVectorsWithLimit() { - testResult(db, """ - CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 1, $conf) YIELD metadata, id""", + testResult(db, "CALL apoc.vectordb.milvus.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], null, 1, $conf) YIELD metadata, id", map("host", HOST, "conf", map(FIELDS_KEY, FIELDS, ALL_RESULTS_KEY, true) ), @@ -480,12 +474,10 @@ MAPPING_KEY, map(NODE_LABEL, "Rag", ); testResult(db, - """ - CALL apoc.vectordb.milvus.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector - WITH collect(node) as paths - CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value - RETURN value - """ + "CALL apoc.vectordb.milvus.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector\n" + + "WITH collect(node) as paths\n" + + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" + + "RETURN value" , map( "host", HOST, diff --git a/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java b/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java index cca52b6e89..93762ddd5c 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/QdrantTest.java @@ -99,14 +99,12 @@ public static void setUp() throws Exception { assertEquals("ok", value.get("status")); }); - testCall(db, """ - CALL apoc.vectordb.qdrant.upsert($host, 'test_collection', - [ - {id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: "Berlin", foo: "one"}}, - {id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: "London", foo: "two"}} - ], - $conf) - """, + testCall(db, "CALL apoc.vectordb.qdrant.upsert($host, 'test_collection',\n" + + "[\n" + + " {id: 1, vector: [0.05, 0.61, 0.76, 0.74], metadata: {city: \"Berlin\", foo: \"one\"}},\n" + + " {id: 2, vector: [0.19, 0.81, 0.75, 0.11], metadata: {city: \"London\", foo: \"two\"}}\n" + + "],\n" + + "$conf)", map("host", HOST, "conf", ADMIN_HEADER_CONF), r -> { Map value = (Map) r.get("value"); @@ -196,14 +194,12 @@ public void getVectorsWithoutVectorResult() { @Test public void deleteVector() { - testCall(db, """ - CALL apoc.vectordb.qdrant.upsert($host, 'test_collection', - [ - {id: 3, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: "baz"}}, - {id: 4, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: "baz"}} - ], - $conf) - """, + testCall(db, "CALL apoc.vectordb.qdrant.upsert($host, 'test_collection',\n" + + "[\n" + + " {id: 3, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}},\n" + + " {id: 4, vector: [0.19, 0.81, 0.75, 0.11], metadata: {foo: \"baz\"}}\n" + + "],\n" + + "$conf)", map("host", HOST, "conf", ADMIN_HEADER_CONF), r -> { Map value = (Map) r.get("value"); @@ -233,12 +229,10 @@ MAPPING_KEY, map(NODE_LABEL, "Rag", ); testResult(db, - """ - CALL apoc.vectordb.qdrant.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector - WITH collect(node) as paths - CALL apoc.ml.rag(paths, $attributes, "Which city has foo equals to one?", $confPrompt) YIELD value - RETURN value - """ + "CALL apoc.vectordb.qdrant.getAndUpdate($host, 'test_collection', [1, 2], $conf) YIELD node, metadata, id, vector\n" + + "WITH collect(node) as paths\n" + + "CALL apoc.ml.rag(paths, $attributes, \"Which city has foo equals to one?\", $confPrompt) YIELD value\n" + + "RETURN value" , map( "host", HOST, @@ -306,12 +300,11 @@ public void queryVectorsWithYield() { @Test public void queryVectorsWithFilter() { - testResultEventually(db, """ - CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], - { must: - [ { key: "city", match: { value: "London" } } ] - }, - 5, $conf) YIELD metadata, id""", + testResultEventually(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7],\n" + + "{ must:\n" + + " [ { key: \"city\", match: { value: \"London\" } } ]\n" + + "},\n" + + "5, $conf) YIELD metadata, id", map("host", HOST, "conf", map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION) ), @@ -323,8 +316,7 @@ public void queryVectorsWithFilter() { @Test public void queryVectorsWithLimit() { - testResultEventually(db, """ - CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 1, $conf) YIELD metadata, id""", + testResultEventually(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 1, $conf) YIELD metadata, id", map("host", HOST, "conf", map(ALL_RESULTS_KEY, true, HEADERS_KEY, ADMIN_AUTHORIZATION) ), diff --git a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateEnterpriseTest.java b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateEnterpriseTest.java index d38cd3bbbf..4eb2321730 100644 --- a/full-it/src/test/java/apoc/full/it/vectordb/WeaviateEnterpriseTest.java +++ b/full-it/src/test/java/apoc/full/it/vectordb/WeaviateEnterpriseTest.java @@ -3,7 +3,6 @@ import apoc.util.MapUtil; import apoc.util.Neo4jContainerExtension; import apoc.util.TestContainerUtil; -import apoc.util.WeaviateTestUtil; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -15,26 +14,13 @@ import java.util.List; import java.util.Map; +import static apoc.full.it.vectordb.WeaviateTestUtil.*; import static apoc.ml.RestAPIConfig.HEADERS_KEY; import static apoc.util.TestContainerUtil.createEnterpriseDB; import static apoc.util.TestContainerUtil.testCall; import static apoc.util.TestContainerUtil.testCallEmpty; import static apoc.util.TestContainerUtil.testResult; import static apoc.util.Util.map; -import static apoc.util.WeaviateTestUtil.ADMIN_AUTHORIZATION; -import static apoc.util.WeaviateTestUtil.ADMIN_HEADER_CONF; -import static apoc.util.WeaviateTestUtil.COLLECTION_NAME; -import static apoc.util.WeaviateTestUtil.FIELDS; -import static apoc.util.WeaviateTestUtil.HOST; -import static apoc.util.WeaviateTestUtil.ID_1; -import static apoc.util.WeaviateTestUtil.ID_2; -import static apoc.util.WeaviateTestUtil.WEAVIATE_CONTAINER; -import static apoc.util.WeaviateTestUtil.WEAVIATE_CREATE_COLLECTION_APOC; -import static apoc.util.WeaviateTestUtil.WEAVIATE_DELETE_COLLECTION_APOC; -import static apoc.util.WeaviateTestUtil.WEAVIATE_DELETE_VECTOR_APOC; -import static apoc.util.WeaviateTestUtil.WEAVIATE_PORT; -import static apoc.util.WeaviateTestUtil.WEAVIATE_QUERY_APOC; -import static apoc.util.WeaviateTestUtil.WEAVIATE_UPSERT_QUERY; import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY; import static apoc.vectordb.VectorEmbeddingConfig.FIELDS_KEY; import static org.junit.Assert.assertEquals; @@ -53,7 +39,7 @@ public static void setUp() throws Exception { Network network = Network.newNetwork(); // We build the project, the artifact will be placed into ./build/libs - neo4jContainer = createEnterpriseDB(List.of(TestContainerUtil.ApocPackage.EXTENDED), true) + neo4jContainer = createEnterpriseDB(List.of(TestContainerUtil.ApocPackage.FULL), true) .withNetwork(network) .withNetworkAliases("neo4j"); neo4jContainer.start(); diff --git a/full/src/main/java/apoc/vectordb/ChromaHandler.java b/full/src/main/java/apoc/vectordb/ChromaHandler.java index d99b4e657a..76e486879c 100644 --- a/full/src/main/java/apoc/vectordb/ChromaHandler.java +++ b/full/src/main/java/apoc/vectordb/ChromaHandler.java @@ -6,6 +6,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static apoc.util.MapUtil.map; @@ -34,7 +35,7 @@ public VectorEmbeddingConfig fromGet(Map config, List ids, String collection) { - List fields = procedureCallContext.outputFields().toList(); + List fields = procedureCallContext.outputFields().collect(Collectors.toList()); VectorEmbeddingConfig conf = new VectorEmbeddingConfig(config); Map additionalBodies = map("ids", ids); @@ -50,7 +51,7 @@ public VectorEmbeddingConfig fromQuery(Map config, long limit, String collection) { - List fields = procedureCallContext.outputFields().toList(); + List fields = procedureCallContext.outputFields().collect(Collectors.toList()); VectorEmbeddingConfig conf = new VectorEmbeddingConfig(config); Map additionalBodies = map("query_embeddings", List.of(vector), diff --git a/full/src/main/java/apoc/vectordb/Qdrant.java b/full/src/main/java/apoc/vectordb/Qdrant.java index 779524013d..1c26deb008 100644 --- a/full/src/main/java/apoc/vectordb/Qdrant.java +++ b/full/src/main/java/apoc/vectordb/Qdrant.java @@ -5,7 +5,6 @@ import apoc.result.MapResult; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.Transaction; -import org.neo4j.graphdb.security.URLAccessChecker; import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; diff --git a/full/src/main/java/apoc/vectordb/VectorDb.java b/full/src/main/java/apoc/vectordb/VectorDb.java index 638e29b052..848b17dc6c 100644 --- a/full/src/main/java/apoc/vectordb/VectorDb.java +++ b/full/src/main/java/apoc/vectordb/VectorDb.java @@ -201,8 +201,9 @@ private static void setVectorProp(VectorMappingConfig mapping } if (embedding == null) { - String embeddingErrMsg = "The embedding value is null. Make sure you execute `YIELD embedding` on the procedure and you configured `%s: true`" - .formatted(ALL_RESULTS_KEY); + String embeddingErrMsg = String.format( + "The embedding value is null. Make sure you execute `YIELD embedding` on the procedure and you configured `%s: true`", + ALL_RESULTS_KEY); throw new RuntimeException(embeddingErrMsg); } diff --git a/full/src/main/java/apoc/vectordb/VectorDbUtil.java b/full/src/main/java/apoc/vectordb/VectorDbUtil.java index deccb1f0db..640bc03212 100644 --- a/full/src/main/java/apoc/vectordb/VectorDbUtil.java +++ b/full/src/main/java/apoc/vectordb/VectorDbUtil.java @@ -20,9 +20,8 @@ import apoc.SystemPropertyKeys; -import apoc.util.CollectionUtils; -import apoc.util.ExtendedMapUtils; import apoc.util.Util; +import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.MapUtils; import org.apache.commons.lang3.StringUtils; import org.neo4j.graphdb.Label; @@ -35,7 +34,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import static apoc.ml.RestAPIConfig.BASE_URL_KEY; import static apoc.ml.RestAPIConfig.BODY_KEY; diff --git a/full/src/main/java/apoc/vectordb/Weaviate.java b/full/src/main/java/apoc/vectordb/Weaviate.java index 7bacafeff7..7456b0e279 100644 --- a/full/src/main/java/apoc/vectordb/Weaviate.java +++ b/full/src/main/java/apoc/vectordb/Weaviate.java @@ -7,7 +7,6 @@ import apoc.util.UrlResolver; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.Transaction; -import org.neo4j.graphdb.security.URLAccessChecker; import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; @@ -18,6 +17,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import java.util.stream.Stream; import static apoc.ml.RestAPIConfig.METHOD_KEY; @@ -141,7 +141,7 @@ public Stream delete(@Name("hostOrKey") String hostOrKey, List objects = ids.stream() .peek(id -> { - String endpoint = "%s/objects/%s/%s".formatted(restAPIConfig.getBaseUrl(), collection, id); + String endpoint = String.format("%s/objects/%s/%s", restAPIConfig.getBaseUrl(), collection, id); restAPIConfig.setEndpoint(endpoint); try { executeRequest(restAPIConfig); @@ -149,7 +149,7 @@ public Stream delete(@Name("hostOrKey") String hostOrKey, throw new RuntimeException(e); } }) - .toList(); + .collect(Collectors.toList()); return Stream.of(new ListResult(objects)); } @@ -183,7 +183,7 @@ private Stream getCommon(String hostOrKey, String collection, L */ config.putIfAbsent(METHOD_KEY, null); - List fields = procedureCallContext.outputFields().toList(); + List fields = procedureCallContext.outputFields().collect(Collectors.toList()); VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromGet(config, procedureCallContext, ids, collection); boolean hasEmbedding = fields.contains("vector") && conf.isAllResults(); @@ -194,7 +194,7 @@ private Stream getCommon(String hostOrKey, String collection, L return ids.stream() .flatMap(id -> { - String endpoint = "%s/objects/%s/%s".formatted(conf.getApiConfig().getBaseUrl(), collection, id) + suffix; + String endpoint = String.format("%s/objects/%s/%s", conf.getApiConfig().getBaseUrl(), collection, id) + suffix; conf.getApiConfig().setEndpoint(endpoint); try { return executeRequest(conf.getApiConfig()) @@ -235,7 +235,7 @@ private Stream queryCommon(String hostOrKey, String collection, VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection); - return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, + return getEmbeddingResultStream(conf, procedureCallContext, tx, v -> { Object getValue = ((Map) v).get("data").get("Get"); Object collectionValue = ((Map) getValue).get(collection); diff --git a/full/src/test/java/apoc/util/ExtendedTestUtil.java b/full/src/test/java/apoc/util/ExtendedTestUtil.java index 91f712677d..0671b275ca 100644 --- a/full/src/test/java/apoc/util/ExtendedTestUtil.java +++ b/full/src/test/java/apoc/util/ExtendedTestUtil.java @@ -58,10 +58,15 @@ public static void testRetryCallEventually( */ public static void testResultEventually( GraphDatabaseService db, String call, Consumer resultConsumer, long timeout) { + testResultEventually(db, call, Collections.emptyMap(), resultConsumer, timeout); + } + + public static void testResultEventually( + GraphDatabaseService db, String call, Map params, Consumer resultConsumer, long timeout) { assertEventually( () -> { try { - return db.executeTransactionally(call, Map.of(), r -> { + return db.executeTransactionally(call, params, r -> { resultConsumer.accept(r); return true; });