Skip to content

Commit

Permalink
Merge pull request #9725 from IoannisPanagiotas/some-more-tests
Browse files Browse the repository at this point in the history
Improve testing a bit
  • Loading branch information
IoannisPanagiotas authored Oct 10, 2024
2 parents 8dffe68 + 159c5d8 commit 89e4a4b
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 3 deletions.
87 changes: 84 additions & 3 deletions algo/src/test/java/org/neo4j/gds/pricesteiner/GrowthPhaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,19 @@
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.neo4j.gds.Orientation;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.extension.GdlExtension;
import org.neo4j.gds.extension.GdlGraph;
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.extension.TestGraph;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.List;
import java.util.function.LongToDoubleFunction;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.InstanceOfAssertFactories.LIST;

@GdlExtension
class GrowthPhaseTest {
Expand Down Expand Up @@ -62,14 +65,22 @@ void shouldFindOptimalSolution() {
var result = growthPhase.grow();

assertThat(result.activeOriginalNodes().get(graph.toMappedNodeId("a1"))).isFalse();
assertThat(result.activeOriginalNodes().get(graph.toMappedNodeId("a2"))).isTrue();
assertThat(result.activeOriginalNodes().get(graph.toMappedNodeId("a3"))).isTrue();
var a2 = graph.toMappedNodeId("a2");
var a3= graph.toMappedNodeId("a3");
assertThat(result.activeOriginalNodes().get(a2)).isTrue();
assertThat(result.activeOriginalNodes().get(a3)).isTrue();
assertThat(result.activeOriginalNodes().get(graph.toMappedNodeId("a4"))).isFalse();

}
assertThat(result.numberOfTreeEdges()).isEqualTo(1L);

var treeEdges = result.treeEdges();
var treeEdgePairs = result.edgeParts();

var u =- treeEdgePairs.get(2*treeEdges.get(0));
var v = -treeEdgePairs.get(2*treeEdges.get(0) + 1);
assertThat(List.of(u,v)).asInstanceOf(LIST).containsExactlyInAnyOrder(a2,a3);

}

}

Expand Down Expand Up @@ -143,7 +154,77 @@ void shouldExecuteGrowthPhaseCorrectly() {
assertThat(activeNodes.get(2)).isTrue();
assertThat(activeNodes.get(3)).isTrue();

}

@Test
void shouldExecuteGrowthPhaseCorrectlyWithUniqueWeights() {

HugeLongArray prizes = HugeLongArray.newArray(graph.nodeCount());
prizes.set(graph.toMappedNodeId("a0"),9);
prizes.set(graph.toMappedNodeId("a1"),60);
prizes.set(graph.toMappedNodeId("a2"),30);
prizes.set(graph.toMappedNodeId("a3"),10);
prizes.set(graph.toMappedNodeId("a4"),110);

var growthPhase = new GrowthPhase(
graph,
prizes::get,
ProgressTracker.NULL_TRACKER,
TerminationFlag.RUNNING_TRUE
);
var growthResult = growthPhase.grow();
var clusterStructure = growthPhase.clusterStructure();


assertThat(clusterStructure.inactiveSince(0)).isEqualTo(5.0);
assertThat(clusterStructure.inactiveSince(1)).isEqualTo(5.0);

assertThat(clusterStructure.inactiveSince(2)).isEqualTo(7.5);
assertThat(clusterStructure.inactiveSince(3)).isEqualTo(7.5);

assertThat(clusterStructure.inactiveSince(4)).isEqualTo(27);
assertThat(clusterStructure.inactiveSince(5)).isEqualTo(27);

assertThat(clusterStructure.inactiveSince(6)).isEqualTo(31);
assertThat(clusterStructure.inactiveSince(7)).isEqualTo(31);


assertThat(clusterStructure.active(8)).isTrue();

assertThat(clusterStructure.moatAt(0,31)).isEqualTo(5);
assertThat(clusterStructure.moatAt(1,31)).isEqualTo(5);

assertThat(clusterStructure.moatAt(2,31)).isEqualTo(7.5);
assertThat(clusterStructure.moatAt(3,31)).isEqualTo(7.5);

assertThat(clusterStructure.moatAt(4,31)).isEqualTo(27);

assertThat(clusterStructure.moatAt(5,31)).isEqualTo(22);
assertThat(clusterStructure.moatAt(6,31)).isEqualTo(23.5);
assertThat(clusterStructure.moatAt(7,31)).isEqualTo(4);


for (long u=0;u< graph.nodeCount();++u){
assertThat(clusterStructure.sumOnEdgePart(u,31))
.satisfies( clusterMoatPair->{
assertThat(clusterMoatPair.totalMoat()).isEqualTo(31);
assertThat(clusterMoatPair.cluster()).isEqualTo(8);
});
}


BitSet activeNodes = growthResult.activeOriginalNodes();

assertThat(activeNodes.cardinality()).isEqualTo(5L);

var treeEdges = growthResult.treeEdges();
var treeEdgeWeights = growthResult.edgeCosts();

LongToDoubleFunction costSupplier = e -> treeEdgeWeights.get(treeEdges.get(e));
assertThat(List.of(costSupplier.applyAsDouble(0),
costSupplier.applyAsDouble(1),
costSupplier.applyAsDouble(2),
costSupplier.applyAsDouble(3))).asInstanceOf(LIST).containsExactlyInAnyOrder(10.0,62.0,54.0,15.0);

}
}
Expand Down
54 changes: 54 additions & 0 deletions algo/src/test/java/org/neo4j/gds/pricesteiner/PCSTFastTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.neo4j.gds.logging.GdsTestLog;

import java.util.function.LongToDoubleFunction;
import java.util.stream.LongStream;

import static org.assertj.core.api.Assertions.assertThat;
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
Expand Down Expand Up @@ -123,4 +124,57 @@ void shouldLogProgress() {
}

}

@Nested
@GdlExtension
class HouseGraph{

@GdlGraph(orientation = Orientation.UNDIRECTED)
private static final String DB_CYPHER =
"CREATE " +
" (a0:node)," +
" (a1:node)," +
" (a2:node)," +
" (a3:node)," +
" (a4:node)," +
"(a0)-[:R{w:10}]->(a1)," +
"(a0)-[:R{w:72}]->(a3)," +
"(a1)-[:R{w:74}]->(a2)," +
"(a1)-[:R{w:62}]->(a3)," +
"(a1)-[:R{w:54}]->(a4)," +
"(a2)-[:R{w:15}]->(a3)," +
"(a2)-[:R{w:62}]->(a4)";


@Inject
private TestGraph graph;

@Test
void shouldFindCorrectAnswer() {
LongToDoubleFunction prizes = (x) -> 20.0;

var pcst =new PCSTFast(graph,prizes,ProgressTracker.NULL_TRACKER);
var result =pcst.compute();

var a0 = graph.toMappedNodeId("a0");
var a1 = graph.toMappedNodeId("a1");

var parents =result.parentArray();

boolean case1 = parents.get(a0) == a1 && parents.get(a1) == PrizeSteinerTreeResult.ROOT;
boolean case2 = parents.get(a1) == a0 && parents.get(a0) == PrizeSteinerTreeResult.ROOT;

assertThat(
LongStream
.range(0, graph.nodeCount())
.filter(v -> v != a0 && v != a1)
.map(parents::get)
.filter(v -> v != PrizeSteinerTreeResult.PRUNED)
.count())
.isEqualTo(0l);

assertThat(case1 ^ case2).isTrue();

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

import java.util.Arrays;
import java.util.function.Function;
import java.util.function.LongToDoubleFunction;
import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -263,5 +264,80 @@ void shouldApplyDynamicProgramming() {

}

@Nested
@GdlExtension
class HouseGraph{

@GdlGraph(orientation = Orientation.UNDIRECTED)
private static final String DB_CYPHER =
"CREATE " +
" (a0:node)," +
" (a1:node)," +
" (a2:node)," +
" (a3:node)," +
" (a4:node)," +
"(a0)-[:R{w:10}]->(a1)," +
"(a2)-[:R{w:62}]->(a4)," +
"(a1)-[:R{w:54}]->(a4)," +
"(a2)-[:R{w:15}]->(a3),";

@Inject
private TestGraph graph;


@Test
void shouldFindOptimal(){
HugeLongArray prizes = HugeLongArray.newArray(graph.nodeCount());
var a0 = graph.toMappedNodeId("a0");
var a1 = graph.toMappedNodeId("a1");
var a2 = graph.toMappedNodeId("a2");
var a3 = graph.toMappedNodeId("a3");
var a4 = graph.toMappedNodeId("a4");

prizes.set(a0,9);
prizes.set(a1,60);
prizes.set(a2,30);
prizes.set(a3,10);
prizes.set(a4,110);

var bitSet = new BitSet(graph.nodeCount());
for (int i = 0; i < graph.nodeCount(); ++i) {
bitSet.set(i);
}

HugeLongArray degrees = HugeLongArray.newArray(graph.nodeCount());
degrees.setAll(v -> graph.degree(v));
var strongPruning = new StrongPruning(
new TreeStructure(graph, degrees, graph.nodeCount()),
bitSet,
prizes::get,
ProgressTracker.NULL_TRACKER,
TerminationFlag.RUNNING_TRUE
);
strongPruning.performPruning();

var sp=strongPruning.resultTree();
var parents = sp.parentArray();
var costs = sp.relationshipToParentCost();
LongToDoubleFunction costSupplier = e -> {
if (parents.get(e) == PrizeSteinerTreeResult.ROOT )
return 0;
if (parents.get(e) == PrizeSteinerTreeResult.PRUNED)
return Long.MIN_VALUE;
return costs.get(e);
};

assertThat(parents.get(a0)).isEqualTo(PrizeSteinerTreeResult.PRUNED);
assertThat(parents.get(a2)).isEqualTo(PrizeSteinerTreeResult.PRUNED);
assertThat(parents.get(a3)).isEqualTo(PrizeSteinerTreeResult.PRUNED);

double sum=costSupplier.applyAsDouble(a1) +costSupplier.applyAsDouble(a4);

assertThat(sum).isEqualTo(54);


}
}


}

0 comments on commit 89e4a4b

Please sign in to comment.