diff --git a/core/src/main/java/org/neo4j/gds/core/huge/NodeFilteredGraph.java b/core/src/main/java/org/neo4j/gds/core/huge/NodeFilteredGraph.java index fea21b13c7..bdf953c3c6 100644 --- a/core/src/main/java/org/neo4j/gds/core/huge/NodeFilteredGraph.java +++ b/core/src/main/java/org/neo4j/gds/core/huge/NodeFilteredGraph.java @@ -20,6 +20,7 @@ package org.neo4j.gds.core.huge; import org.apache.commons.lang3.mutable.MutableInt; +import org.jetbrains.annotations.Nullable; import org.neo4j.gds.NodeLabel; import org.neo4j.gds.api.CSRGraph; import org.neo4j.gds.api.CSRGraphAdapter; @@ -32,10 +33,10 @@ import org.neo4j.gds.api.RelationshipWithPropertyConsumer; import org.neo4j.gds.api.properties.nodes.NodePropertyValues; import org.neo4j.gds.api.schema.GraphSchema; +import org.neo4j.gds.collections.ha.HugeIntArray; import org.neo4j.gds.collections.primitive.PrimitiveLongIterable; import org.neo4j.gds.config.ConcurrencyConfig; import org.neo4j.gds.core.concurrency.RunWithConcurrency; -import org.neo4j.gds.collections.ha.HugeIntArray; import org.neo4j.gds.core.utils.partition.Partition; import org.neo4j.gds.core.utils.partition.PartitionUtils; import org.neo4j.gds.utils.CloseableThreadLocal; @@ -56,16 +57,30 @@ public class NodeFilteredGraph extends CSRGraphAdapter implements FilteredIdMap private final FilteredIdMap filteredIdMap; private long relationshipCount; private final HugeIntArray degreeCache; + private final HugeIntArray degreeInverseCache; private final CloseableThreadLocal threadLocalGraph; public NodeFilteredGraph(CSRGraph originalGraph, FilteredIdMap filteredIdMap) { - this(originalGraph, filteredIdMap, emptyDegreeCache(filteredIdMap), -1); + this( + originalGraph, + filteredIdMap, + emptyDegreeCache(filteredIdMap), + originalGraph.characteristics().isInverseIndexed() ? emptyDegreeCache(filteredIdMap) : null, + -1 + ); } - private NodeFilteredGraph(CSRGraph originalGraph, FilteredIdMap filteredIdMap, HugeIntArray degreeCache, long relationshipCount) { + private NodeFilteredGraph( + CSRGraph originalGraph, + FilteredIdMap filteredIdMap, + HugeIntArray degreeCache, + @Nullable HugeIntArray degreeInverseCache, + long relationshipCount + ) { super(originalGraph); this.degreeCache = degreeCache; + this.degreeInverseCache = degreeInverseCache; this.filteredIdMap = filteredIdMap; this.relationshipCount = relationshipCount; this.threadLocalGraph = CloseableThreadLocal.withInitial(this::concurrentCopy); @@ -104,18 +119,18 @@ public void forEachNode(LongPredicate consumer) { @Override public int degree(long nodeId) { - int cachedDegree = degreeCache.get(nodeId); + int cachedDegree = this.degreeCache.get(nodeId); if (cachedDegree != NO_DEGREE) { return cachedDegree; } var degree = new MutableInt(); - threadLocalGraph.get().forEachRelationship(nodeId, (s, t) -> { + this.threadLocalGraph.get().forEachRelationship(nodeId, (s, t) -> { degree.increment(); return true; }); - degreeCache.set(nodeId, degree.intValue()); + this.degreeCache.set(nodeId, degree.intValue()); return degree.intValue(); } @@ -130,6 +145,24 @@ public int degreeWithoutParallelRelationships(long nodeId) { return degreeCounter.degree; } + @Override + public int degreeInverse(long nodeId) { + int cachedDegree = this.degreeInverseCache.get(nodeId); + if (cachedDegree != NO_DEGREE) { + return cachedDegree; + } + + var degree = new MutableInt(); + + this.threadLocalGraph.get().forEachInverseRelationship(nodeId, (s, t) -> { + degree.increment(); + return true; + }); + this.degreeInverseCache.set(nodeId, degree.intValue()); + + return degree.intValue(); + } + @Override public long nodeCount() { return filteredIdMap.nodeCount(); @@ -218,9 +251,30 @@ public void forEachRelationship(long nodeId, double fallbackValue, RelationshipW ); } + @Override + public void forEachInverseRelationship(long nodeId, RelationshipConsumer consumer) { + super.forEachInverseRelationship( + filteredIdMap.toRootNodeId(nodeId), + (s, t) -> filterAndConsume(s, t, consumer) + ); + } + + @Override + public void forEachInverseRelationship( + long nodeId, + double fallbackValue, + RelationshipWithPropertyConsumer consumer + ) { + super.forEachInverseRelationship( + filteredIdMap.toRootNodeId(nodeId), + fallbackValue, + (s, t, p) -> filterAndConsume(s, t, p, consumer) + ); + } + @Override public Stream streamRelationships(long nodeId, double fallbackValue) { - if (! filteredIdMap.containsRootNodeId(filteredIdMap.toRootNodeId(nodeId))) { + if (!filteredIdMap.containsRootNodeId(filteredIdMap.toRootNodeId(nodeId))) { return Stream.empty(); } @@ -266,7 +320,13 @@ public double relationshipProperty(long sourceNodeId, long targetNodeId) { @Override public CSRGraph concurrentCopy() { - return new NodeFilteredGraph(csrGraph.concurrentCopy(), filteredIdMap, degreeCache, relationshipCount); + return new NodeFilteredGraph( + csrGraph.concurrentCopy(), + filteredIdMap, + degreeCache, + degreeInverseCache, + relationshipCount + ); } @Override @@ -312,7 +372,12 @@ private boolean filterAndConsume(long source, long target, RelationshipConsumer return true; } - private boolean filterAndConsume(long source, long target, double propertyValue, RelationshipWithPropertyConsumer consumer) { + private boolean filterAndConsume( + long source, + long target, + double propertyValue, + RelationshipWithPropertyConsumer consumer + ) { if (filteredIdMap.containsRootNodeId(source) && filteredIdMap.containsRootNodeId(target)) { long internalSourceId = filteredIdMap.toFilteredNodeId(source); long internalTargetId = filteredIdMap.toFilteredNodeId(target); diff --git a/core/src/test/java/org/neo4j/gds/core/huge/NodeFilteredGraphTest.java b/core/src/test/java/org/neo4j/gds/core/huge/NodeFilteredGraphTest.java index 2aaabe63fc..94548b5ddd 100644 --- a/core/src/test/java/org/neo4j/gds/core/huge/NodeFilteredGraphTest.java +++ b/core/src/test/java/org/neo4j/gds/core/huge/NodeFilteredGraphTest.java @@ -44,16 +44,18 @@ @GdlExtension class NodeFilteredGraphTest { - @GdlGraph(idOffset = 1337) + @GdlGraph(idOffset = 1337, indexInverse = true) static String GDL = " (x:Ignore)," + " (a:Person)," + " (b:Ignore:Person)," + " (c:Ignore:Person)," + " (d:Person)," + " (e:Ignore)," + + " (x)-->(d)," + // ignored for index inverse " (a)-->(b)," + " (a)-->(e)," + " (b)-->(c)," + + " (x)-->(c)," + // ignored for index inverse " (b)-->(d)," + " (c)-->(e)"; @@ -122,6 +124,32 @@ void filterDegreeWithoutParallelRelationships() { assertThat(graph.degreeWithoutParallelRelationships(filteredIdFunction(graph).apply("a"))).isEqualTo(1L); } + @Test + void filterDegreeInverse() { + var graph = graphStore.getGraph( + NodeLabel.of("Person"), + RelationshipType.ALL_RELATIONSHIPS, + Optional.empty() + ); + + assertThat(graph.degreeInverse(filteredIdFunction(graph).apply("d"))).isEqualTo(1L); + } + + @Test + void foreachInverseRelationship() { + var graph = graphStore.getGraph( + NodeLabel.of("Person"), + RelationshipType.ALL_RELATIONSHIPS, + Optional.empty() + ); + + graph.forEachInverseRelationship(filteredIdFunction(graph).apply("d"), (source, target) -> { + assertThat(source).isEqualTo(filteredIdFunction(graph).apply("d")); + assertThat(target).isEqualTo(filteredIdFunction(graph).apply("b")); + return true; + }); + } + @Test void filterStreamRelationships() { var graph = graphStore.getGraph(