diff --git a/proc/community/src/test/java/org/neo4j/gds/leiden/LeidenMutateProcTest.java b/proc/community/src/test/java/org/neo4j/gds/leiden/LeidenMutateProcTest.java index 9e7aa8912c..7614c9ddc7 100644 --- a/proc/community/src/test/java/org/neo4j/gds/leiden/LeidenMutateProcTest.java +++ b/proc/community/src/test/java/org/neo4j/gds/leiden/LeidenMutateProcTest.java @@ -31,9 +31,13 @@ import org.neo4j.gds.catalog.GraphProjectProc; import org.neo4j.gds.catalog.GraphStreamNodePropertiesProc; import org.neo4j.gds.core.loading.GraphStoreCatalog; +import org.neo4j.gds.extension.IdFunction; +import org.neo4j.gds.extension.Inject; import org.neo4j.gds.extension.Neo4jGraph; +import java.util.HashMap; import java.util.HashSet; +import java.util.function.Function; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNoException; @@ -69,6 +73,9 @@ class LeidenMutateProcTest extends BaseProcTest { " (a5)-[:R {weight: 1.0}]->(a7)," + " (a6)-[:R {weight: 1.0}]->(a7)"; + @Inject + IdFunction idFunction; + @BeforeEach void setUp() throws Exception { registerProcedures( @@ -84,16 +91,36 @@ void setUp() throws Exception { @ParameterizedTest @ValueSource(strings = {"gds.leiden","gds.beta.leiden"}) void mutate(String procedureName) { + var query = "CALL " + procedureName + ".mutate('leiden', {mutateProperty: 'communityId', concurrency: 1})"; assertLeidenMutateQuery(query); + Graph mutatedGraph = GraphStoreCatalog.get(getUsername(), DatabaseId.of(db.databaseName()), "leiden").graphStore().getUnion(); + var communities = mutatedGraph.nodeProperties("communityId"); - var communitySet = new HashSet(); + HashMap communitiesSet = new HashMap<>(); + mutatedGraph.forEachNode(nodeId -> { - communitySet.add(communities.longValue(nodeId)); + var community = communities.longValue(nodeId); + var neo4jId = mutatedGraph.toOriginalNodeId(nodeId); + communitiesSet.put(neo4jId, community); return true; }); - assertThat(communitySet).containsExactly(3L, 6L); + + Function map = node -> communitiesSet.get(idFunction.of(node)); + + //community 1 + assertThat(map.apply("a0")) + .isEqualTo(map.apply("a2")) + .isEqualTo(map.apply("a3")) + .isEqualTo(map.apply("a4")) + .isNotEqualTo(map.apply("a1")); + + //community 2 + assertThat(map.apply("a1")) + .isEqualTo(map.apply("a5")) + .isEqualTo(map.apply("a6")) + .isEqualTo(map.apply("a7")); } diff --git a/proc/community/src/test/java/org/neo4j/gds/leiden/LeidenWriteProcTest.java b/proc/community/src/test/java/org/neo4j/gds/leiden/LeidenWriteProcTest.java index 0422ff8fbb..c880be27fa 100644 --- a/proc/community/src/test/java/org/neo4j/gds/leiden/LeidenWriteProcTest.java +++ b/proc/community/src/test/java/org/neo4j/gds/leiden/LeidenWriteProcTest.java @@ -33,10 +33,14 @@ import org.neo4j.gds.api.DatabaseId; import org.neo4j.gds.catalog.GraphProjectProc; import org.neo4j.gds.core.loading.GraphStoreCatalog; +import org.neo4j.gds.extension.IdFunction; +import org.neo4j.gds.extension.Inject; import org.neo4j.gds.extension.Neo4jGraph; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.function.Function; import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; @@ -73,6 +77,10 @@ class LeidenWriteProcTest extends BaseProcTest { " (a5)-[:R {weight: 1.0}]->(a7)," + " (a6)-[:R {weight: 1.0}]->(a7)"; + @Inject + IdFunction idFunction; + + @BeforeEach void setUp() throws Exception { registerProcedures( @@ -100,16 +108,36 @@ void write(String procedureName) { var writeGraph = GraphStoreCatalog.get(getUsername(), DatabaseId.of(db.databaseName()), "writeGraph").graphStore().getUnion(); + var communities = writeGraph.nodeProperties("communityId"); - var communitySet = new HashSet(); + + HashMap communitiesSet = new HashMap<>(); + writeGraph.forEachNode(nodeId -> { - communitySet.add(communities.longValue(nodeId)); + var community = communities.longValue(nodeId); + var neo4jId = writeGraph.toOriginalNodeId(nodeId); + communitiesSet.put(neo4jId, community); return true; }); - assertThat(communitySet).containsExactly(3L, 6L); + + Function map = node -> communitiesSet.get(idFunction.of(node)); + + //community 1 + assertThat(map.apply("a0")) + .isEqualTo(map.apply("a2")) + .isEqualTo(map.apply("a3")) + .isEqualTo(map.apply("a4")) + .isNotEqualTo(map.apply("a1")); + + //community 2 + assertThat(map.apply("a1")) + .isEqualTo(map.apply("a5")) + .isEqualTo(map.apply("a6")) + .isEqualTo(map.apply("a7")); } + @Test void shouldWriteWithConsecutiveIds() { var query = "CALL gds.leiden.write('leiden', { writeProperty: 'communityId', consecutiveIds: true })";