diff --git a/algo/src/test/java/org/neo4j/gds/pricesteiner/GrowthPhaseTest.java b/algo/src/test/java/org/neo4j/gds/pricesteiner/GrowthPhaseTest.java index acfd5a87f6..7d1690c3ed 100644 --- a/algo/src/test/java/org/neo4j/gds/pricesteiner/GrowthPhaseTest.java +++ b/algo/src/test/java/org/neo4j/gds/pricesteiner/GrowthPhaseTest.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.neo4j.gds.Orientation; +import org.neo4j.gds.collections.ha.HugeLongArray; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.extension.GdlExtension; import org.neo4j.gds.extension.GdlGraph; @@ -30,9 +31,11 @@ import org.neo4j.gds.extension.TestGraph; import org.neo4j.gds.termination.TerminationFlag; +import java.util.List; import java.util.function.LongToDoubleFunction; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.LIST; @GdlExtension class GrowthPhaseTest { @@ -62,14 +65,22 @@ void shouldFindOptimalSolution() { var result = growthPhase.grow(); assertThat(result.activeOriginalNodes().get(graph.toMappedNodeId("a1"))).isFalse(); - assertThat(result.activeOriginalNodes().get(graph.toMappedNodeId("a2"))).isTrue(); - assertThat(result.activeOriginalNodes().get(graph.toMappedNodeId("a3"))).isTrue(); + var a2 = graph.toMappedNodeId("a2"); + var a3= graph.toMappedNodeId("a3"); + assertThat(result.activeOriginalNodes().get(a2)).isTrue(); + assertThat(result.activeOriginalNodes().get(a3)).isTrue(); assertThat(result.activeOriginalNodes().get(graph.toMappedNodeId("a4"))).isFalse(); - } + assertThat(result.numberOfTreeEdges()).isEqualTo(1L); + var treeEdges = result.treeEdges(); + var treeEdgePairs = result.edgeParts(); + var u =- treeEdgePairs.get(2*treeEdges.get(0)); + var v = -treeEdgePairs.get(2*treeEdges.get(0) + 1); + assertThat(List.of(u,v)).asInstanceOf(LIST).containsExactlyInAnyOrder(a2,a3); + } } @@ -143,7 +154,77 @@ void shouldExecuteGrowthPhaseCorrectly() { assertThat(activeNodes.get(2)).isTrue(); assertThat(activeNodes.get(3)).isTrue(); + } + + @Test + void shouldExecuteGrowthPhaseCorrectlyWithUniqueWeights() { + + HugeLongArray prizes = HugeLongArray.newArray(graph.nodeCount()); + prizes.set(graph.toMappedNodeId("a0"),9); + prizes.set(graph.toMappedNodeId("a1"),60); + prizes.set(graph.toMappedNodeId("a2"),30); + prizes.set(graph.toMappedNodeId("a3"),10); + prizes.set(graph.toMappedNodeId("a4"),110); + + var growthPhase = new GrowthPhase( + graph, + prizes::get, + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE + ); + var growthResult = growthPhase.grow(); + var clusterStructure = growthPhase.clusterStructure(); + + + assertThat(clusterStructure.inactiveSince(0)).isEqualTo(5.0); + assertThat(clusterStructure.inactiveSince(1)).isEqualTo(5.0); + + assertThat(clusterStructure.inactiveSince(2)).isEqualTo(7.5); + assertThat(clusterStructure.inactiveSince(3)).isEqualTo(7.5); + + assertThat(clusterStructure.inactiveSince(4)).isEqualTo(27); + assertThat(clusterStructure.inactiveSince(5)).isEqualTo(27); + + assertThat(clusterStructure.inactiveSince(6)).isEqualTo(31); + assertThat(clusterStructure.inactiveSince(7)).isEqualTo(31); + + + assertThat(clusterStructure.active(8)).isTrue(); + + assertThat(clusterStructure.moatAt(0,31)).isEqualTo(5); + assertThat(clusterStructure.moatAt(1,31)).isEqualTo(5); + + assertThat(clusterStructure.moatAt(2,31)).isEqualTo(7.5); + assertThat(clusterStructure.moatAt(3,31)).isEqualTo(7.5); + + assertThat(clusterStructure.moatAt(4,31)).isEqualTo(27); + + assertThat(clusterStructure.moatAt(5,31)).isEqualTo(22); + assertThat(clusterStructure.moatAt(6,31)).isEqualTo(23.5); + assertThat(clusterStructure.moatAt(7,31)).isEqualTo(4); + + + for (long u=0;u< graph.nodeCount();++u){ + assertThat(clusterStructure.sumOnEdgePart(u,31)) + .satisfies( clusterMoatPair->{ + assertThat(clusterMoatPair.totalMoat()).isEqualTo(31); + assertThat(clusterMoatPair.cluster()).isEqualTo(8); + }); + } + + + BitSet activeNodes = growthResult.activeOriginalNodes(); + + assertThat(activeNodes.cardinality()).isEqualTo(5L); + + var treeEdges = growthResult.treeEdges(); + var treeEdgeWeights = growthResult.edgeCosts(); + LongToDoubleFunction costSupplier = e -> treeEdgeWeights.get(treeEdges.get(e)); + assertThat(List.of(costSupplier.applyAsDouble(0), + costSupplier.applyAsDouble(1), + costSupplier.applyAsDouble(2), + costSupplier.applyAsDouble(3))).asInstanceOf(LIST).containsExactlyInAnyOrder(10.0,62.0,54.0,15.0); } } diff --git a/algo/src/test/java/org/neo4j/gds/pricesteiner/PCSTFastTest.java b/algo/src/test/java/org/neo4j/gds/pricesteiner/PCSTFastTest.java index 8a730934ab..09fdb05551 100644 --- a/algo/src/test/java/org/neo4j/gds/pricesteiner/PCSTFastTest.java +++ b/algo/src/test/java/org/neo4j/gds/pricesteiner/PCSTFastTest.java @@ -35,6 +35,7 @@ import org.neo4j.gds.logging.GdsTestLog; import java.util.function.LongToDoubleFunction; +import java.util.stream.LongStream; import static org.assertj.core.api.Assertions.assertThat; import static org.neo4j.gds.assertj.Extractors.removingThreadId; @@ -123,4 +124,57 @@ void shouldLogProgress() { } } + + @Nested + @GdlExtension + class HouseGraph{ + + @GdlGraph(orientation = Orientation.UNDIRECTED) + private static final String DB_CYPHER = + "CREATE " + + " (a0:node)," + + " (a1:node)," + + " (a2:node)," + + " (a3:node)," + + " (a4:node)," + + "(a0)-[:R{w:10}]->(a1)," + + "(a0)-[:R{w:72}]->(a3)," + + "(a1)-[:R{w:74}]->(a2)," + + "(a1)-[:R{w:62}]->(a3)," + + "(a1)-[:R{w:54}]->(a4)," + + "(a2)-[:R{w:15}]->(a3)," + + "(a2)-[:R{w:62}]->(a4)"; + + + @Inject + private TestGraph graph; + + @Test + void shouldFindCorrectAnswer() { + LongToDoubleFunction prizes = (x) -> 20.0; + + var pcst =new PCSTFast(graph,prizes,ProgressTracker.NULL_TRACKER); + var result =pcst.compute(); + + var a0 = graph.toMappedNodeId("a0"); + var a1 = graph.toMappedNodeId("a1"); + + var parents =result.parentArray(); + + boolean case1 = parents.get(a0) == a1 && parents.get(a1) == PrizeSteinerTreeResult.ROOT; + boolean case2 = parents.get(a1) == a0 && parents.get(a0) == PrizeSteinerTreeResult.ROOT; + + assertThat( + LongStream + .range(0, graph.nodeCount()) + .filter(v -> v != a0 && v != a1) + .map(parents::get) + .filter(v -> v != PrizeSteinerTreeResult.PRUNED) + .count()) + .isEqualTo(0l); + + assertThat(case1 ^ case2).isTrue(); + + } + } } diff --git a/algo/src/test/java/org/neo4j/gds/pricesteiner/StrongPruningTest.java b/algo/src/test/java/org/neo4j/gds/pricesteiner/StrongPruningTest.java index 5173065276..92c257e7df 100644 --- a/algo/src/test/java/org/neo4j/gds/pricesteiner/StrongPruningTest.java +++ b/algo/src/test/java/org/neo4j/gds/pricesteiner/StrongPruningTest.java @@ -37,6 +37,7 @@ import java.util.Arrays; import java.util.function.Function; +import java.util.function.LongToDoubleFunction; import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; @@ -263,5 +264,80 @@ void shouldApplyDynamicProgramming() { } + @Nested + @GdlExtension + class HouseGraph{ + + @GdlGraph(orientation = Orientation.UNDIRECTED) + private static final String DB_CYPHER = + "CREATE " + + " (a0:node)," + + " (a1:node)," + + " (a2:node)," + + " (a3:node)," + + " (a4:node)," + + "(a0)-[:R{w:10}]->(a1)," + + "(a2)-[:R{w:62}]->(a4)," + + "(a1)-[:R{w:54}]->(a4)," + + "(a2)-[:R{w:15}]->(a3),"; + + @Inject + private TestGraph graph; + + + @Test + void shouldFindOptimal(){ + HugeLongArray prizes = HugeLongArray.newArray(graph.nodeCount()); + var a0 = graph.toMappedNodeId("a0"); + var a1 = graph.toMappedNodeId("a1"); + var a2 = graph.toMappedNodeId("a2"); + var a3 = graph.toMappedNodeId("a3"); + var a4 = graph.toMappedNodeId("a4"); + + prizes.set(a0,9); + prizes.set(a1,60); + prizes.set(a2,30); + prizes.set(a3,10); + prizes.set(a4,110); + + var bitSet = new BitSet(graph.nodeCount()); + for (int i = 0; i < graph.nodeCount(); ++i) { + bitSet.set(i); + } + + HugeLongArray degrees = HugeLongArray.newArray(graph.nodeCount()); + degrees.setAll(v -> graph.degree(v)); + var strongPruning = new StrongPruning( + new TreeStructure(graph, degrees, graph.nodeCount()), + bitSet, + prizes::get, + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE + ); + strongPruning.performPruning(); + + var sp=strongPruning.resultTree(); + var parents = sp.parentArray(); + var costs = sp.relationshipToParentCost(); + LongToDoubleFunction costSupplier = e -> { + if (parents.get(e) == PrizeSteinerTreeResult.ROOT ) + return 0; + if (parents.get(e) == PrizeSteinerTreeResult.PRUNED) + return Long.MIN_VALUE; + return costs.get(e); + }; + + assertThat(parents.get(a0)).isEqualTo(PrizeSteinerTreeResult.PRUNED); + assertThat(parents.get(a2)).isEqualTo(PrizeSteinerTreeResult.PRUNED); + assertThat(parents.get(a3)).isEqualTo(PrizeSteinerTreeResult.PRUNED); + + double sum=costSupplier.applyAsDouble(a1) +costSupplier.applyAsDouble(a4); + + assertThat(sum).isEqualTo(54); + + + } + } + }