Skip to content

Commit

Permalink
Add support for inverse index in NodeFilteredGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
s1ck committed Dec 21, 2023
1 parent 15d3ffe commit f87bac6
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 10 deletions.
83 changes: 74 additions & 9 deletions core/src/main/java/org/neo4j/gds/core/huge/NodeFilteredGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.neo4j.gds.core.huge;

import org.apache.commons.lang3.mutable.MutableInt;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.api.CSRGraph;
import org.neo4j.gds.api.CSRGraphAdapter;
Expand All @@ -32,10 +33,10 @@
import org.neo4j.gds.api.RelationshipWithPropertyConsumer;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.api.schema.GraphSchema;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.collections.primitive.PrimitiveLongIterable;
import org.neo4j.gds.config.ConcurrencyConfig;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.utils.CloseableThreadLocal;
Expand All @@ -56,16 +57,30 @@ public class NodeFilteredGraph extends CSRGraphAdapter implements FilteredIdMap
private final FilteredIdMap filteredIdMap;
private long relationshipCount;
private final HugeIntArray degreeCache;
private final HugeIntArray degreeInverseCache;
private final CloseableThreadLocal<Graph> threadLocalGraph;

public NodeFilteredGraph(CSRGraph originalGraph, FilteredIdMap filteredIdMap) {
this(originalGraph, filteredIdMap, emptyDegreeCache(filteredIdMap), -1);
this(
originalGraph,
filteredIdMap,
emptyDegreeCache(filteredIdMap),
originalGraph.characteristics().isInverseIndexed() ? emptyDegreeCache(filteredIdMap) : null,
-1
);
}

private NodeFilteredGraph(CSRGraph originalGraph, FilteredIdMap filteredIdMap, HugeIntArray degreeCache, long relationshipCount) {
private NodeFilteredGraph(
CSRGraph originalGraph,
FilteredIdMap filteredIdMap,
HugeIntArray degreeCache,
@Nullable HugeIntArray degreeInverseCache,
long relationshipCount
) {
super(originalGraph);

this.degreeCache = degreeCache;
this.degreeInverseCache = degreeInverseCache;
this.filteredIdMap = filteredIdMap;
this.relationshipCount = relationshipCount;
this.threadLocalGraph = CloseableThreadLocal.withInitial(this::concurrentCopy);
Expand Down Expand Up @@ -104,18 +119,18 @@ public void forEachNode(LongPredicate consumer) {

@Override
public int degree(long nodeId) {
int cachedDegree = degreeCache.get(nodeId);
int cachedDegree = this.degreeCache.get(nodeId);
if (cachedDegree != NO_DEGREE) {
return cachedDegree;
}

var degree = new MutableInt();

threadLocalGraph.get().forEachRelationship(nodeId, (s, t) -> {
this.threadLocalGraph.get().forEachRelationship(nodeId, (s, t) -> {
degree.increment();
return true;
});
degreeCache.set(nodeId, degree.intValue());
this.degreeCache.set(nodeId, degree.intValue());

return degree.intValue();
}
Expand All @@ -130,6 +145,24 @@ public int degreeWithoutParallelRelationships(long nodeId) {
return degreeCounter.degree;
}

@Override
public int degreeInverse(long nodeId) {
int cachedDegree = this.degreeInverseCache.get(nodeId);
if (cachedDegree != NO_DEGREE) {
return cachedDegree;
}

var degree = new MutableInt();

this.threadLocalGraph.get().forEachInverseRelationship(nodeId, (s, t) -> {
degree.increment();
return true;
});
this.degreeInverseCache.set(nodeId, degree.intValue());

return degree.intValue();
}

@Override
public long nodeCount() {
return filteredIdMap.nodeCount();
Expand Down Expand Up @@ -218,9 +251,30 @@ public void forEachRelationship(long nodeId, double fallbackValue, RelationshipW
);
}

@Override
public void forEachInverseRelationship(long nodeId, RelationshipConsumer consumer) {
super.forEachInverseRelationship(
filteredIdMap.toRootNodeId(nodeId),
(s, t) -> filterAndConsume(s, t, consumer)
);
}

@Override
public void forEachInverseRelationship(
long nodeId,
double fallbackValue,
RelationshipWithPropertyConsumer consumer
) {
super.forEachInverseRelationship(
filteredIdMap.toRootNodeId(nodeId),
fallbackValue,
(s, t, p) -> filterAndConsume(s, t, p, consumer)
);
}

@Override
public Stream<RelationshipCursor> streamRelationships(long nodeId, double fallbackValue) {
if (! filteredIdMap.containsRootNodeId(filteredIdMap.toRootNodeId(nodeId))) {
if (!filteredIdMap.containsRootNodeId(filteredIdMap.toRootNodeId(nodeId))) {
return Stream.empty();
}

Expand Down Expand Up @@ -266,7 +320,13 @@ public double relationshipProperty(long sourceNodeId, long targetNodeId) {

@Override
public CSRGraph concurrentCopy() {
return new NodeFilteredGraph(csrGraph.concurrentCopy(), filteredIdMap, degreeCache, relationshipCount);
return new NodeFilteredGraph(
csrGraph.concurrentCopy(),
filteredIdMap,
degreeCache,
degreeInverseCache,
relationshipCount
);
}

@Override
Expand Down Expand Up @@ -312,7 +372,12 @@ private boolean filterAndConsume(long source, long target, RelationshipConsumer
return true;
}

private boolean filterAndConsume(long source, long target, double propertyValue, RelationshipWithPropertyConsumer consumer) {
private boolean filterAndConsume(
long source,
long target,
double propertyValue,
RelationshipWithPropertyConsumer consumer
) {
if (filteredIdMap.containsRootNodeId(source) && filteredIdMap.containsRootNodeId(target)) {
long internalSourceId = filteredIdMap.toFilteredNodeId(source);
long internalTargetId = filteredIdMap.toFilteredNodeId(target);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,18 @@
@GdlExtension
class NodeFilteredGraphTest {

@GdlGraph(idOffset = 1337)
@GdlGraph(idOffset = 1337, indexInverse = true)
static String GDL = " (x:Ignore)," +
" (a:Person)," +
" (b:Ignore:Person)," +
" (c:Ignore:Person)," +
" (d:Person)," +
" (e:Ignore)," +
" (x)-->(d)," + // ignored for index inverse
" (a)-->(b)," +
" (a)-->(e)," +
" (b)-->(c)," +
" (x)-->(c)," + // ignored for index inverse
" (b)-->(d)," +
" (c)-->(e)";

Expand Down Expand Up @@ -122,6 +124,32 @@ void filterDegreeWithoutParallelRelationships() {
assertThat(graph.degreeWithoutParallelRelationships(filteredIdFunction(graph).apply("a"))).isEqualTo(1L);
}

@Test
void filterDegreeInverse() {
var graph = graphStore.getGraph(
NodeLabel.of("Person"),
RelationshipType.ALL_RELATIONSHIPS,
Optional.empty()
);

assertThat(graph.degreeInverse(filteredIdFunction(graph).apply("d"))).isEqualTo(1L);
}

@Test
void foreachInverseRelationship() {
var graph = graphStore.getGraph(
NodeLabel.of("Person"),
RelationshipType.ALL_RELATIONSHIPS,
Optional.empty()
);

graph.forEachInverseRelationship(filteredIdFunction(graph).apply("d"), (source, target) -> {
assertThat(source).isEqualTo(filteredIdFunction(graph).apply("d"));
assertThat(target).isEqualTo(filteredIdFunction(graph).apply("b"));
return true;
});
}

@Test
void filterStreamRelationships() {
var graph = graphStore.getGraph(
Expand Down

0 comments on commit f87bac6

Please sign in to comment.