diff --git a/algo/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityFactory.java b/algo/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityFactory.java
index 95c0b5bd4f..710de17fe9 100644
--- a/algo/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityFactory.java
+++ b/algo/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityFactory.java
@@ -24,6 +24,7 @@
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.collections.ha.HugeObjectArray;
+import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
@@ -74,7 +75,6 @@ public MemoryEstimation memoryEstimation(CONFIG config) {
int topK = Math.abs(config.normalizedK());
MemoryEstimations.Builder builder = MemoryEstimations.builder(NodeSimilarity.class.getSimpleName())
- .perNode("components", HugeLongArray::memoryEstimation)
.perNode("node filter", nodeCount -> sizeOfLongArray(BitSet.bits2words(nodeCount)))
.add(
"vectors",
@@ -97,6 +97,13 @@ public MemoryEstimation memoryEstimation(CONFIG config) {
.rangePerNode("array", nodeCount -> MemoryRange.of(0, nodeCount * averageVectorSize))
.build();
}));
+ if (config.considerComponents()) {
+ builder.perNode("nodes sorted by component", HugeLongArray::memoryEstimation);
+ builder.perNode("upper bound per component", HugeAtomicLongArray::memoryEstimation);
+ }
+ if (config.considerComponents() && config.componentProperty() != null) {
+ builder.perNode("component mapping", HugeLongArray::memoryEstimation);
+ }
if (config.computeToGraph() && !config.hasTopK()) {
builder.add(
"similarity graph",
diff --git a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/ComponentNodes.java b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/ComponentNodes.java
new file mode 100644
index 0000000000..8c8683464a
--- /dev/null
+++ b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/ComponentNodes.java
@@ -0,0 +1,170 @@
+/*
+ * Copyright (c) "Neo4j"
+ * Neo4j Sweden AB [http://neo4j.com]
+ *
+ * This file is part of Neo4j.
+ *
+ * Neo4j is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.neo4j.gds.similarity.nodesim;
+
+import org.neo4j.gds.collections.ha.HugeLongArray;
+import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
+import org.neo4j.gds.core.concurrency.ParallelUtil;
+import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator;
+import org.neo4j.gds.termination.TerminationFlag;
+
+import java.util.NoSuchElementException;
+import java.util.PrimitiveIterator;
+import java.util.Spliterator;
+import java.util.Spliterators;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.LongUnaryOperator;
+
+/**
+ * Manages nodes sorted by component. Produces an iterator over all nodes in a given component.
+ * <p>
+ * Component ids are used as indices into an array with one entry per node, so {@code components}
+ * must only return values in {@code [0, nodeCount)}.
+ */
+public final class ComponentNodes {
+    private final LongUnaryOperator components;
+    private final HugeAtomicLongArray upperBoundPerComponent;
+    private final HugeLongArray nodesSorted;
+
+    private ComponentNodes(
+        LongUnaryOperator components,
+        HugeAtomicLongArray upperBoundPerComponent,
+        HugeLongArray nodesSorted
+    ) {
+        this.components = components;
+        this.upperBoundPerComponent = upperBoundPerComponent;
+        this.nodesSorted = nodesSorted;
+    }
+
+    /**
+     * Groups the node ids {@code 0 .. nodeCount - 1} by component.
+     *
+     * @param components  maps a node id to its component id
+     * @param nodeCount   number of nodes
+     * @param concurrency concurrency used for the parallel passes
+     * @return the grouped nodes, ready to be iterated per component
+     */
+    public static ComponentNodes create(LongUnaryOperator components, long nodeCount, int concurrency) {
+        var upperBoundPerComponent = computeIndexUpperBoundPerComponent(components, nodeCount, concurrency);
+        var nodesSorted = computeNodesSortedByComponent(components, upperBoundPerComponent, concurrency);
+        return new ComponentNodes(components, upperBoundPerComponent, nodesSorted);
+    }
+
+    /**
+     * Iterates the nodes of one component.
+     *
+     * @param componentId the component to iterate
+     * @param offset      when {@code >= 1}, nodes with an id smaller than {@code offset} are skipped
+     * @return an iterator over the matching node ids; the ids are not returned in ascending order
+     */
+    public PrimitiveIterator.OfLong iterator(long componentId, long offset) {
+        return new Iterator(componentId, offset);
+    }
+
+    /**
+     * Same contract as {@link #iterator(long, long)}, exposed as a spliterator.
+     * <p>
+     * {@code SORTED} is deliberately not reported: inside a component the nodes are stored in the order in
+     * which the parallel fill claimed their slots, not in natural (ascending id) order.
+     */
+    public Spliterator.OfLong spliterator(long componentId, long offset) {
+        return Spliterators.spliteratorUnknownSize(
+            iterator(componentId, offset),
+            Spliterator.ORDERED | Spliterator.IMMUTABLE | Spliterator.NONNULL | Spliterator.DISTINCT
+        );
+    }
+
+    LongUnaryOperator getComponents() {
+        return components;
+    }
+
+    HugeAtomicLongArray getUpperBoundPerComponent() {
+        return upperBoundPerComponent;
+    }
+
+    HugeLongArray getNodesSorted() {
+        return nodesSorted;
+    }
+
+    /**
+     * Computes, for every component id, the highest index its nodes occupy in the component-sorted node array.
+     * Entries of ids that have no nodes stay {@code 0}.
+     */
+    static HugeAtomicLongArray computeIndexUpperBoundPerComponent(
+        LongUnaryOperator components,
+        long nodeCount,
+        int concurrency
+    ) {
+        var upperBoundPerComponent = HugeAtomicLongArray.of(nodeCount, ParalleLongPageCreator.passThrough(concurrency));
+
+        // init coordinate array to contain the nr of nodes per component
+        // i.e. comp1 containing 3 nodes, comp2 containing 20 nodes: {(comp1, 3), (comp2, 20)}
+        ParallelUtil.parallelForEachNode(
+            nodeCount,
+            concurrency,
+            TerminationFlag.RUNNING_TRUE,
+            nodeId -> upperBoundPerComponent.getAndAdd(components.applyAsLong(nodeId), 1)
+        );
+
+        AtomicLong atomicNodeSum = new AtomicLong();
+        // modify coordinate array to contain the upper bound of the global index for each component
+        // i.e. comp1 containing 3 nodes, comp2 containing 20 nodes, comp1 randomly accessed prior to comp2:
+        // {(comp1, 2), (comp2, 22)}
+        ParallelUtil.parallelForEachNode(nodeCount, concurrency, TerminationFlag.RUNNING_TRUE, componentId -> {
+            // each index is visited exactly once, so the count can be read a single time
+            long nodesInComponent = upperBoundPerComponent.get(componentId);
+            if (nodesInComponent > 0) {
+                long nodeSum = atomicNodeSum.addAndGet(nodesInComponent);
+                upperBoundPerComponent.set(componentId, nodeSum - 1);
+            }
+        });
+
+        return upperBoundPerComponent;
+    }
+
+    /**
+     * Places every node id into the index range of its component. The given upper bounds are copied first and
+     * are not modified.
+     */
+    static HugeLongArray computeNodesSortedByComponent(
+        LongUnaryOperator components,
+        HugeAtomicLongArray idxUpperBoundPerComponent,
+        int concurrency
+    ) {
+        // initialized to its max possible size of 1 node <=> 1 component in a disconnected graph
+        long nodeCount = idxUpperBoundPerComponent.size();
+        var nodesSortedByComponent = HugeLongArray.newArray(nodeCount);
+        var nodeIdxProviderArray = HugeAtomicLongArray.of(nodeCount, ParalleLongPageCreator.passThrough(concurrency));
+        idxUpperBoundPerComponent.copyTo(nodeIdxProviderArray, nodeCount);
+
+        // fill nodesSortedByComponent with nodeId per component-sorted, unique index
+        // i.e. comp1 containing 3 nodes, comp2 containing 20 nodes, named in order of processing:
+        // {(0, n3), (1, n2), (2, n1), (3, n23), .., (22, n4)}
+        ParallelUtil.parallelForEachNode(nodeCount, concurrency, TerminationFlag.RUNNING_TRUE, indexId -> {
+            long nodeId = nodeCount - indexId - 1;
+            long componentId = components.applyAsLong(nodeId);
+            // claim the next free slot of this component, counting down from its upper bound
+            long nodeIdx = nodeIdxProviderArray.getAndAdd(componentId, -1);
+            nodesSortedByComponent.set(nodeIdx, nodeId);
+        });
+
+        return nodesSortedByComponent;
+    }
+
+    /**
+     * Walks the sorted node array downwards from the upper bound of the component until a node of another
+     * component (or the start of the array) is reached.
+     * <p>
+     * State is read through the accessors of the enclosing instance rather than its fields;
+     * ComponentNodesTest stubs those accessors on a mock.
+     */
+    private final class Iterator implements PrimitiveIterator.OfLong {
+        private final long offset;
+        long runningIdx;
+        final long componentId;
+
+        Iterator(long componentId, long offset) {
+            this.componentId = componentId;
+            this.runningIdx = getUpperBoundPerComponent().get(componentId);
+            this.offset = offset;
+        }
+
+        @Override
+        public boolean hasNext() {
+            while (runningIdx > -1) {
+                long nodeId = getNodesSorted().get(runningIdx);
+                if (getComponents().applyAsLong(nodeId) != componentId) {
+                    // left the index range of this component
+                    return false;
+                }
+                if (offset < 1L || nodeId >= offset) {
+                    return true;
+                }
+                // node is below the requested offset: skip it
+                runningIdx--;
+            }
+            return false;
+        }
+
+        @Override
+        public long nextLong() {
+            if (!hasNext()) {
+                throw new NoSuchElementException();
+            }
+            return getNodesSorted().get(runningIdx--);
+        }
+    }
+}
diff --git a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarity.java b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarity.java
index a787ff5f91..4c1a0b43f3 100644
--- a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarity.java
+++ b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarity.java
@@ -28,6 +28,7 @@
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.SetBitsIterable;
+import org.neo4j.gds.core.utils.paged.HugeLongLongMap;
import org.neo4j.gds.core.utils.paged.dss.DisjointSetStruct;
import org.neo4j.gds.core.utils.progress.BatchingProgressLogger;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
@@ -43,14 +44,20 @@
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+import java.util.function.LongUnaryOperator;
import java.util.stream.LongStream;
import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
public class NodeSimilarity extends Algorithm {
private final Graph graph;
- private final boolean sortVectors;
private final NodeSimilarityBaseConfig config;
+ private final boolean sortVectors;
+ private final boolean weighted;
private final BitSet sourceNodes;
private final BitSet targetNodes;
@@ -60,12 +67,12 @@ public class NodeSimilarity extends Algorithm {
private final ExecutorService executorService;
private final int concurrency;
private final MetricSimilarityComputer similarityComputer;
+
private HugeObjectArray neighbors;
private HugeObjectArray weights;
- private HugeLongArray components;
- private SimilarityPairTriConsumer similarityConsumer;
-
- private final boolean weighted;
+ // nodeId -> componentId
+ private LongUnaryOperator components;
+ // offset -> stream of source node ids
+ private Function<Long, LongStream> sourceNodesStream;
+ // (componentId, offset) -> stream of target node ids
+ private BiFunction<Long, Long, LongStream> targetNodesStream;
public static NodeSimilarity create(
Graph graph,
@@ -197,47 +204,16 @@ private SimilarityGraphResult computeToGraph() {
private void prepare() {
progressTracker.beginSubTask();
- initComponents();
+ components = initComponents();
+ sourceNodesStream = initSourceNodesStream();
+ targetNodesStream = initTargetNodesStream();
if (config.runWCC()) {
progressTracker.beginSubTask();
}
- neighbors = HugeObjectArray.newArray(long[].class, graph.nodeCount());
- if (weighted) {
- weights = HugeObjectArray.newArray(double[].class, graph.nodeCount());
- }
- DegreeComputer degreeComputer = new DegreeComputer();
- VectorComputer vectorComputer = VectorComputer.of(graph, weighted);
- DegreeFilter degreeFilter = new DegreeFilter(config.degreeCutoff(), config.upperDegreeCutoff());
- neighbors.setAll(node -> {
- graph.forEachRelationship(node, degreeComputer);
- int degree = degreeComputer.degree;
- degreeComputer.reset();
- vectorComputer.reset(degree);
+ initNodeSpecificFields();
- progressTracker.logProgress(graph.degree(node));
- if (degreeFilter.apply(degree)) {
- if (sourceNodeFilter.test(node)) {
- sourceNodes.set(node);
- }
- if (targetNodeFilter.test(node)) {
- targetNodes.set(node);
- }
-
- // TODO: we don't need to do the rest of the prepare for a node that isn't going to be used in the computation
- vectorComputer.forEachRelationship(node);
-
- if (sortVectors) {
- vectorComputer.sortTargetIds();
- }
- if (weighted) {
- weights.set(node, vectorComputer.getWeights());
- }
- return vectorComputer.targetIds.buffer;
- }
- return null;
- });
if (config.runWCC()) {
progressTracker.endSubTask();
}
@@ -262,24 +238,16 @@ private Stream computeParallel() {
}
}
- private void initComponents() {
- components = HugeLongArray.newArray(graph.nodeCount());
- if (!config.isEnableComponentOptimization()) {
+ private LongUnaryOperator initComponents() {
+ if (!config.considerComponents()) {
// considering everything as within the same component
- similarityConsumer = this::computeSimilarityForSingleComponent;
- return;
+ return n -> 0;
}
- similarityConsumer = this::computeSimilarityForComponents;
-
if (config.componentProperty() != null) {
- NodePropertyValues nodeProperties = graph.nodeProperties(config.componentProperty());
// extract component info from property
- graph.forEachNode(n -> {
- components.set(n, nodeProperties.longValue(n));
- return true;
- });
- return;
+ NodePropertyValues nodeProperties = graph.nodeProperties(config.componentProperty());
+ return initComponentIdMapping(graph, nodeProperties::longValue);
}
// run WCC to determine components
@@ -293,17 +261,53 @@ private void initComponents() {
Wcc wcc = new WccAlgorithmFactory<>().build(graph, wccConfig, ProgressTracker.NULL_TRACKER);
DisjointSetStruct disjointSets = wcc.compute();
- graph.forEachNode(n -> {
- components.set(n, disjointSets.setIdOf(n));
- return true;
- });
progressTracker.endSubTask();
+ return disjointSets::setIdOf;
+ }
+
+ /**
+  * Allocates {@code neighbors} (and {@code weights} for weighted runs) and fills them node by node. Nodes that
+  * pass the degree filter are also registered in the source and target bit sets according to their filters.
+  */
+ private void initNodeSpecificFields() {
+     final long nodeCount = graph.nodeCount();
+     neighbors = HugeObjectArray.newArray(long[].class, nodeCount);
+     if (weighted) {
+         weights = HugeObjectArray.newArray(double[].class, nodeCount);
+     }
+
+     var degreeCounter = new DegreeComputer();
+     VectorComputer vectors = VectorComputer.of(graph, weighted);
+     var cutoffFilter = new DegreeFilter(config.degreeCutoff(), config.upperDegreeCutoff());
+     neighbors.setAll(node -> {
+         // read the degree before the shared counter is reset for the next node
+         graph.forEachRelationship(node, degreeCounter);
+         final int degree = degreeCounter.degree;
+         degreeCounter.reset();
+         vectors.reset(degree);
+
+         progressTracker.logProgress(graph.degree(node));
+         if (!cutoffFilter.apply(degree)) {
+             // rejected by the degree filter: no vector is stored for this node
+             return null;
+         }
+
+         if (sourceNodeFilter.test(node)) {
+             sourceNodes.set(node);
+         }
+         if (targetNodeFilter.test(node)) {
+             targetNodes.set(node);
+         }
+
+         // TODO: we don't need to do the rest of the prepare for a node that isn't going to be used in the computation
+         vectors.forEachRelationship(node);
+
+         if (sortVectors) {
+             vectors.sortTargetIds();
+         }
+         if (weighted) {
+             weights.set(node, vectors.getWeights());
+         }
+         return vectors.targetIds.buffer;
+     });
}
private Stream computeAll() {
progressTracker.beginSubTask(calculateWorkload());
- var similarityResultStream = loggableAndTerminatableSourceNodeStream()
+ var similarityResultStream = loggableAndTerminableSourceNodeStream()
.boxed()
.flatMap(this::computeSimilaritiesForNode);
progressTracker.endSubTask();
@@ -312,7 +316,7 @@ private Stream computeAll() {
private Stream computeAllParallel() {
return ParallelUtil.parallelStream(
- loggableAndTerminatableSourceNodeStream(), concurrency, stream -> stream
+ loggableAndTerminableSourceNodeStream(), concurrency, stream -> stream
.boxed()
.flatMap(this::computeSimilaritiesForNode)
);
@@ -326,20 +330,20 @@ private TopKMap computeTopKMap() {
: SimilarityResult.ASCENDING;
var topKMap = new TopKMap(neighbors.size(), sourceNodes, Math.abs(config.normalizedK()), comparator);
- loggableAndTerminatableSourceNodeStream()
+ loggableAndTerminableSourceNodeStream()
.forEach(sourceNodeId -> {
if (sourceNodeFilter.equals(NodeFilter.noOp)) {
- targetNodesStream(sourceNodeId + 1)
- .forEach(targetNodeId -> similarityConsumer.accept(sourceNodeId, targetNodeId,
+ targetNodesStream.apply(components.applyAsLong(sourceNodeId), sourceNodeId + 1)
+ .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId,
(source, target, similarity) -> {
topKMap.put(source, target, similarity);
topKMap.put(target, source, similarity);
}
));
} else {
- targetNodesStream()
+ targetNodesStream.apply(components.applyAsLong(sourceNodeId), 0L)
.filter(targetNodeId -> sourceNodeId != targetNodeId)
- .forEach(targetNodeId -> similarityConsumer.accept(sourceNodeId, targetNodeId, topKMap::put));
+ .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topKMap::put));
}
});
progressTracker.endSubTask();
@@ -355,7 +359,7 @@ private TopKMap computeTopKMapParallel() {
var topKMap = new TopKMap(neighbors.size(), sourceNodes, Math.abs(config.normalizedK()), comparator);
ParallelUtil.parallelStreamConsume(
- loggableAndTerminatableSourceNodeStream(),
+ loggableAndTerminableSourceNodeStream(),
concurrency,
terminationFlag,
stream -> stream
@@ -366,9 +370,9 @@ private TopKMap computeTopKMapParallel() {
// into these queues is not considered to be thread-safe.
// Hence, we need to ensure that down the stream, exactly one queue
// within the TopKMap processes all pairs for a single node.
- targetNodesStream()
+ targetNodesStream.apply(components.applyAsLong(sourceNodeId), 0L)
.filter(targetNodeId -> sourceNodeId != targetNodeId)
- .forEach(targetNodeId -> similarityConsumer.accept(sourceNodeId, targetNodeId, topKMap::put))
+ .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topKMap::put))
)
);
@@ -380,15 +384,15 @@ private Stream computeTopN() {
progressTracker.beginSubTask(calculateWorkload());
var topNList = new TopNList(config.normalizedN());
- loggableAndTerminatableSourceNodeStream()
+ loggableAndTerminableSourceNodeStream()
.forEach(sourceNodeId -> {
if (sourceNodeFilter.equals(NodeFilter.noOp)) {
- targetNodesStream(sourceNodeId + 1)
- .forEach(targetNodeId -> similarityConsumer.accept(sourceNodeId, targetNodeId, topNList::add));
+ targetNodesStream.apply(components.applyAsLong(sourceNodeId), sourceNodeId + 1)
+ .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topNList::add));
} else {
- targetNodesStream()
+ targetNodesStream.apply(components.applyAsLong(sourceNodeId), 0L)
.filter(targetNodeId -> sourceNodeId != targetNodeId)
- .forEach(targetNodeId -> similarityConsumer.accept(sourceNodeId, targetNodeId, topNList::add));
+ .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topNList::add));
}
});
@@ -402,31 +406,30 @@ private Stream computeTopN(TopKMap topKMap) {
return topNList.stream();
}
- private LongStream sourceNodesStream(long offset) {
- return new SetBitsIterable(sourceNodes, offset).stream();
- }
-
- private LongStream sourceNodesStream() {
- return sourceNodesStream(0);
+ /**
+  * Builds the factory used to stream source nodes: for a given offset it streams {@code sourceNodes} through a
+  * {@link SetBitsIterable} created with that offset.
+  */
+ private Function<Long, LongStream> initSourceNodesStream() {
+     return offset -> new SetBitsIterable(sourceNodes, offset).stream();
}
- private LongStream loggableAndTerminatableSourceNodeStream() {
- return checkProgress(sourceNodesStream());
- }
+ /**
+  * Builds the factory used to stream target nodes for a {@code (componentId, offset)} pair.
+  * <p>
+  * Without component consideration the component id is ignored and {@code targetNodes} is streamed from the
+  * offset. Otherwise only the nodes of the given component are streamed, restricted to {@code targetNodes}.
+  */
+ private BiFunction<Long, Long, LongStream> initTargetNodesStream() {
+     if (!config.considerComponents()) {
+         return (componentId, offset) -> new SetBitsIterable(targetNodes, offset).stream();
+     }
- private LongStream targetNodesStream(long offset) {
- return new SetBitsIterable(targetNodes, offset).stream();
+     var componentNodes = ComponentNodes.create(components, graph.nodeCount(), concurrency);
+     // Sequential on purpose: callers feed this stream into the top-K / top-N collectors, and all pairs of
+     // one source node must be processed by a single thread. Parallelism comes from the source node stream.
+     return (componentId, offset) -> StreamSupport
+         .longStream(componentNodes.spliterator(componentId, offset), false)
+         .filter(targetNodes::get);
}
- private LongStream targetNodesStream() {
- return targetNodesStream(0);
+ // Streams all source nodes (offset 0) wrapped by checkProgress.
+ private LongStream loggableAndTerminableSourceNodeStream() {
+     final var allSourceNodes = sourceNodesStream.apply(0L);
+     return checkProgress(allSourceNodes);
}
private Stream computeSimilaritiesForNode(long sourceNodeId) {
- return targetNodesStream(sourceNodeId + 1)
+ return targetNodesStream.apply(components.applyAsLong(sourceNodeId), sourceNodeId + 1)
.mapToObj(targetNodeId -> {
var resultHolder = new SimilarityResult[]{null};
- similarityConsumer.accept(
+ computeSimilarityFor(
sourceNodeId,
targetNodeId,
(source, target, similarity) -> resultHolder[0] = new SimilarityResult(source, target, similarity)
@@ -436,11 +439,30 @@ private Stream computeSimilaritiesForNode(long sourceNodeId) {
.filter(Objects::nonNull);
}
+ /**
+  * Re-maps arbitrary component ids onto the dense range {@code 0, 1, 2, ...}, numbered in the order in which
+  * {@code graph.forEachNode} first encounters each original id.
+  *
+  * @param graph                   the graph whose nodes are visited
+  * @param originComponentIdMapper maps a node id to its original component id
+  * @return a lookup from node id to the dense component id
+  */
+ private static LongUnaryOperator initComponentIdMapping(Graph graph, LongUnaryOperator originComponentIdMapper) {
+     var componentIdMappings = new HugeLongLongMap();
+     var nextMappedComponentId = new AtomicLong(0L);
+     var mappedComponentIdPerNode = HugeLongArray.newArray(graph.nodeCount());
+     graph.forEachNode(n -> {
+         long originComponentIdForNode = originComponentIdMapper.applyAsLong(n);
+         long mappedComponentIdForNode;
+         if (componentIdMappings.containsKey(originComponentIdForNode)) {
+             mappedComponentIdForNode = componentIdMappings.getOrDefault(originComponentIdForNode, -1L);
+         } else {
+             // draw a new id only for a component seen for the first time, so the ids stay dense
+             mappedComponentIdForNode = nextMappedComponentId.getAndIncrement();
+             componentIdMappings.put(originComponentIdForNode, mappedComponentIdForNode);
+         }
+         mappedComponentIdPerNode.set(n, mappedComponentIdForNode);
+         return true;
+     });
+
+     return mappedComponentIdPerNode::get;
+ }
+
interface SimilarityConsumer {
void accept(long sourceNodeId, long targetNodeId, double similarity);
}
- private void computeSimilarityForSingleComponent(long sourceNodeId, long targetNodeId, SimilarityConsumer consumer) {
+ private void computeSimilarityFor(long sourceNodeId, long targetNodeId, SimilarityConsumer consumer) {
double similarity;
var sourceNodeNeighbors = neighbors.get(sourceNodeId);
var targetNodeNeighbors = neighbors.get(targetNodeId);
@@ -456,15 +478,6 @@ private void computeSimilarityForSingleComponent(long sourceNodeId, long targetN
}
}
- private void computeSimilarityForComponents(long sourceNodeId, long targetNodeId, SimilarityConsumer consumer) {
- if (components.get(sourceNodeId) != components.get(targetNodeId)) {
- consumer.accept(sourceNodeId, targetNodeId, 0);
- return;
- }
-
- computeSimilarityForSingleComponent(sourceNodeId, targetNodeId, consumer);
- }
-
private double computeWeightedSimilarity(
long[] sourceNodeNeighbors,
long[] targetNodeNeighbors,
diff --git a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityBaseConfig.java b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityBaseConfig.java
index dd18f23245..74d3df8383 100644
--- a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityBaseConfig.java
+++ b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityBaseConfig.java
@@ -52,8 +52,8 @@ public interface NodeSimilarityBaseConfig extends AlgoBaseConfig, RelationshipWe
String COMPONENT_PROPERTY_KEY = "componentProperty";
- String ENABLE_COMPONENT_OPTIMIZATION_KEY = "enableComponentOptimization";
- boolean ENABLE_COMPONENT_OPTIMIZATION = false;
+ String CONSIDER_COMPONENTS_KEY = "considerComponents";
+ boolean CONSIDER_COMPONENTS = false;
@Value.Default
@Configuration.DoubleRange(min = 0, max = 1)
@@ -114,8 +114,8 @@ default int bottomN() {
default @Nullable String componentProperty() { return null; }
@Value.Default
- @Configuration.Key(ENABLE_COMPONENT_OPTIMIZATION_KEY)
- default boolean isEnableComponentOptimization() { return ENABLE_COMPONENT_OPTIMIZATION; }
+ @Configuration.Key(CONSIDER_COMPONENTS_KEY)
+ default boolean considerComponents() { return CONSIDER_COMPONENTS; } // off by default; when on, target nodes are taken only from the source node's component
@Configuration.Ignore
@Value.Derived
@@ -203,10 +203,7 @@ default void validateComponentProperty(
@Value.Derived
default boolean runWCC() {
- if (isEnableComponentOptimization() && componentProperty() == null) {
- return true;
- }
- return false;
+ return considerComponents() && componentProperty() == null; // WCC is only needed when components are considered and no property provides them
}
}
diff --git a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityFactory.java b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityFactory.java
index ca51f5bbff..eb5ff2aaae 100644
--- a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityFactory.java
+++ b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityFactory.java
@@ -24,6 +24,7 @@
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.collections.ha.HugeObjectArray;
+import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
@@ -69,7 +70,6 @@ public MemoryEstimation memoryEstimation(CONFIG config) {
int topK = Math.abs(config.normalizedK());
MemoryEstimations.Builder builder = MemoryEstimations.builder(NodeSimilarity.class.getSimpleName())
- .perNode("components", HugeLongArray::memoryEstimation)
.perNode("node filter", nodeCount -> sizeOfLongArray(BitSet.bits2words(nodeCount)))
.add(
"vectors",
@@ -92,6 +92,13 @@ public MemoryEstimation memoryEstimation(CONFIG config) {
.rangePerNode("array", nodeCount -> MemoryRange.of(0, nodeCount * averageVectorSize))
.build();
}));
+ if (config.considerComponents()) {
+ builder.perNode("nodes sorted by component", HugeLongArray::memoryEstimation);
+ builder.perNode("upper bound per component", HugeAtomicLongArray::memoryEstimation);
+ }
+ if (config.considerComponents() && config.componentProperty() != null) {
+ builder.perNode("component mapping", HugeLongArray::memoryEstimation);
+ }
if (config.computeToGraph() && !config.hasTopK()) {
builder.add(
"similarity graph",
diff --git a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/SimilarityPairTriConsumer.java b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/SimilarityPairTriConsumer.java
deleted file mode 100644
index 52a7a1bbfb..0000000000
--- a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/SimilarityPairTriConsumer.java
+++ /dev/null
@@ -1,27 +0,0 @@
-/*
- * Copyright (c) "Neo4j"
- * Neo4j Sweden AB [http://neo4j.com]
- *
- * This file is part of Neo4j.
- *
- * Neo4j is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-package org.neo4j.gds.similarity.nodesim;
-
-@FunctionalInterface
-public interface SimilarityPairTriConsumer {
-
- void accept(long k, long v, NodeSimilarity.SimilarityConsumer similarityConsumer);
-
-}
diff --git a/algo/src/test/java/org/neo4j/gds/similarity/nodesim/ComponentNodesTest.java b/algo/src/test/java/org/neo4j/gds/similarity/nodesim/ComponentNodesTest.java
new file mode 100644
index 0000000000..54f0a32c76
--- /dev/null
+++ b/algo/src/test/java/org/neo4j/gds/similarity/nodesim/ComponentNodesTest.java
@@ -0,0 +1,268 @@
+/*
+ * Copyright (c) "Neo4j"
+ * Neo4j Sweden AB [http://neo4j.com]
+ *
+ * This file is part of Neo4j.
+ *
+ * Neo4j is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.neo4j.gds.similarity.nodesim;
+
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+import org.neo4j.gds.collections.ha.HugeLongArray;
+import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
+import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator;
+import org.neo4j.gds.core.utils.shuffle.ShuffleUtil;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.SplittableRandom;
+import java.util.function.LongUnaryOperator;
+import java.util.stream.Collectors;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+class ComponentNodesTest {
+
+ // Node -> component mapping with seven components of pairwise distinct sizes (for node ids 0..27).
+ private LongUnaryOperator prepare7DistinctSizeComponents() {
+     // exclusive upper node-id bound of components 0..5 (sizes 3, 5, 4, 6, 1, 2);
+     // every node id from 21 upwards belongs to component 6 (size 7 with 28 nodes)
+     final long[] exclusiveUpperBounds = {3L, 8L, 12L, 18L, 19L, 21L};
+     return nodeId -> {
+         for (int componentId = 0; componentId < exclusiveUpperBounds.length; componentId++) {
+             if (nodeId < exclusiveUpperBounds[componentId]) {
+                 return componentId;
+             }
+         }
+         return exclusiveUpperBounds.length;
+     };
+ }
+
+ @Test
+ void shouldDetermineIndexUpperBound() {
+     // nodeId -> componentId
+     var components = prepare7DistinctSizeComponents();
+     // number of nodes per component id, as defined by prepare7DistinctSizeComponents
+     long[] expectedSizePerComponent = {3L, 5L, 4L, 6L, 1L, 2L, 7L};
+     // componentId, upperBound
+     var idxUpperBoundPerComponent = ComponentNodes.computeIndexUpperBoundPerComponent(components, 28, 4);
+     // we cannot infer which component follows another, but the range must match in size for the component
+     Map<Long, Long> componentPerIdxUpperBound = new HashMap<>(7);
+     for (int i = 0; i < 7; i++) {
+         componentPerIdxUpperBound.put(idxUpperBoundPerComponent.get(i), (long) i);
+     }
+     // two components sharing an upper bound would collapse into one map entry
+     assertEquals(7, componentPerIdxUpperBound.size());
+     // component id 7 has no nodes, its entry must be untouched
+     assertThat(idxUpperBoundPerComponent.get(7)).isEqualTo(0L);
+
+     long previousUpperBound = -1L;
+     for (long upperBound : componentPerIdxUpperBound.keySet().stream().sorted().collect(Collectors.toList())) {
+         long componentId = componentPerIdxUpperBound.get(upperBound);
+         // the gap to the previous upper bound is the size of the index range owned by this component
+         assertThat(upperBound - previousUpperBound).isEqualTo(expectedSizePerComponent[(int) componentId]);
+         previousUpperBound = upperBound;
+     }
+     // the ranges must cover exactly the indices 0..27
+     assertThat(previousUpperBound).isEqualTo(27L);
+ }
+
+ @Test
+ void shouldComputeNodesSortedByComponent() {
+     // nodeId -> componentId
+     var components = prepare7DistinctSizeComponents();
+     // componentId, upperIdx of component
+     var upperBoundPerComponent = HugeAtomicLongArray.of(28, ParalleLongPageCreator.passThrough(4));
+     upperBoundPerComponent.set(0, 2);
+     upperBoundPerComponent.set(1, 7);
+     upperBoundPerComponent.set(2, 11);
+     upperBoundPerComponent.set(3, 17);
+     upperBoundPerComponent.set(4, 18);
+     upperBoundPerComponent.set(5, 20);
+     upperBoundPerComponent.set(6, 27);
+
+     var nodesSortedByComponent = ComponentNodes.computeNodesSortedByComponent(components,
+         upperBoundPerComponent, 4);
+
+     // nodes may occur in arbitrary order within components, but every node must sit at an index inside the
+     // index range of its component, and every node must occur exactly once
+     assertEquals(28, nodesSortedByComponent.size());
+     Set<Long> seenNodes = new HashSet<>();
+     for (int i = 0; i < 28; i++) {
+         long nodeId = nodesSortedByComponent.get(i);
+         long currentComp = components.applyAsLong(nodeId);
+         long lowerBoundExclusive = currentComp == 0 ? -1L : upperBoundPerComponent.get(currentComp - 1);
+
+         // check the position, not the node id: the id lies within these bounds for any permutation
+         assertThat((long) i).isGreaterThan(lowerBoundExclusive);
+         assertThat((long) i).isLessThanOrEqualTo(upperBoundPerComponent.get(currentComp));
+         assertTrue(seenNodes.add(nodeId));
+     }
+ }
+
+ @Test
+ void shouldComputeNodesSortedByComponentsNotConsecutive() {
+     // nodeId -> componentId; the component ids are not consecutive and do not follow the node id order
+     LongUnaryOperator components = nodeId -> {
+         if (nodeId < 4) {
+             return 3; // size 4
+         } else if (nodeId < 6) {
+             return 5; // size 2
+         } else {
+             return 1; // size 5
+         }
+     };
+     // componentId, upperIdx of component: component 3 owns indices 0..3, component 1 owns 4..8,
+     // component 5 owns 9..10
+     var upperBoundPerComponent = HugeAtomicLongArray.of(11, ParalleLongPageCreator.passThrough(4));
+     upperBoundPerComponent.set(3, 3);
+     upperBoundPerComponent.set(5, 10);
+     upperBoundPerComponent.set(1, 8);
+
+     var nodesSortedByComponent = ComponentNodes.computeNodesSortedByComponent(components,
+         upperBoundPerComponent, 4);
+
+     // nodes may occur in arbitrary order within components, but with the given assignment, nodeIds must be within
+     // component index bounds
+     assertEquals(11, nodesSortedByComponent.size());
+     Collection<Long> values = new ArrayList<>();
+     int end = 0;
+     while (end < 11) {
+         int start = end;
+         // the first node of a segment tells which component the segment belongs to
+         if (nodesSortedByComponent.get(start) < 4) {
+             // next 4 nodes must be of component 3
+             end += 4;
+             values.addAll(List.of(0L, 1L, 2L, 3L));
+         } else if (nodesSortedByComponent.get(start) < 6) {
+             // next 2 nodes must be of component 5
+             end += 2;
+             values.addAll(List.of(4L, 5L));
+         } else {
+             // next 5 nodes must be of component 1
+             end += 5;
+             values.addAll(List.of(6L, 7L, 8L, 9L, 10L));
+         }
+         for (int i = start; i < end; i++) {
+             long nodeId = nodesSortedByComponent.get(i);
+             assertTrue(values.remove(nodeId));
+         }
+     }
+
+     // the iterator reads its state through the accessors, so a mock with stubbed accessors is sufficient
+     ComponentNodes componentNodesMock = Mockito.mock(ComponentNodes.class);
+     Mockito.doReturn(components).when(componentNodesMock).getComponents();
+     Mockito.doReturn(upperBoundPerComponent).when(componentNodesMock).getUpperBoundPerComponent();
+     Mockito.doReturn(nodesSortedByComponent).when(componentNodesMock).getNodesSorted();
+
+     // no component with id 0
+     Mockito.doCallRealMethod().when(componentNodesMock).iterator(0L, 0L);
+     Iterator<Long> iterator = componentNodesMock.iterator(0L, 0L);
+     assertFalse(iterator.hasNext());
+
+     // 5 nodes for component with id 1
+     Mockito.doCallRealMethod().when(componentNodesMock).iterator(1L, 0L);
+     iterator = componentNodesMock.iterator(1L, 0L);
+     values.addAll(List.of(6L, 7L, 8L, 9L, 10L));
+     for (int i = 0; i < 5; i++) {
+         assertTrue(iterator.hasNext());
+         long nodeId = iterator.next();
+         assertTrue(values.remove(nodeId));
+     }
+     assertFalse(iterator.hasNext());
+ }
+
+ /**
+  * Checks that {@code ComponentNodes#iterator} yields exactly the nodes of the requested component, using
+  * hand-built sorted-node and upper-bound arrays for two components.
+  */
+ @Test
+ void shouldReturnNodesForComponent() {
+     // nodeId -> componentId: nodes 0..2 form component 0 (size 3), nodes 3..7 form component 1 (size 5)
+     LongUnaryOperator components = nodeId -> nodeId < 3 ? 0L : 1L;
+
+     var nodesSorted = HugeLongArray.newArray(8);
+     // uniqueIdx, nodeId
+     nodesSorted.set(0, 2);
+     nodesSorted.set(1, 1);
+     nodesSorted.set(2, 0);
+     nodesSorted.set(3, 7);
+     nodesSorted.set(4, 6);
+     nodesSorted.set(5, 5);
+     nodesSorted.set(6, 4);
+     nodesSorted.set(7, 3);
+
+     var upperBoundPerComponent = HugeAtomicLongArray.of(8, ParalleLongPageCreator.passThrough(4));
+     // componentId, upperBound
+     upperBoundPerComponent.set(0, 2);
+     upperBoundPerComponent.set(1, 7);
+
+     // only iterator(..) runs real code; the getters it is stubbed alongside return the data prepared above
+     ComponentNodes componentNodesMock = Mockito.mock(ComponentNodes.class);
+     Mockito.doReturn(components).when(componentNodesMock).getComponents();
+     Mockito.doReturn(upperBoundPerComponent).when(componentNodesMock).getUpperBoundPerComponent();
+     Mockito.doReturn(nodesSorted).when(componentNodesMock).getNodesSorted();
+
+     // first component
+     // Iterator<Long> keeps the comparison below numeric; a raw iterator would compare a Long with an Integer
+     Mockito.doCallRealMethod().when(componentNodesMock).iterator(0L,0L);
+     Iterator<Long> iterator = componentNodesMock.iterator(0L, 0L);
+     for (int nodeId = 0; nodeId < 3; nodeId++) {
+         assertTrue(iterator.hasNext());
+         assertThat(iterator.next()).isEqualTo(nodeId);
+     }
+     assertFalse(iterator.hasNext());
+     // second component
+     Mockito.doCallRealMethod().when(componentNodesMock).iterator(1L,0L);
+     iterator = componentNodesMock.iterator(1L, 0L);
+     for (int nodeId = 3; nodeId < 8; nodeId++) {
+         assertTrue(iterator.hasNext());
+         assertThat(iterator.next()).isEqualTo(nodeId);
+     }
+     assertFalse(iterator.hasNext());
+ }
+
+ /**
+  * Checks the {@code offset} argument of {@code ComponentNodes#iterator}: all 20 nodes belong to component 0
+  * and are stored in shuffled order, and an offset of 11 must yield exactly the node ids 11 to 19.
+  */
+ @Test
+ void shouldRespectOffset() {
+     // every node is mapped to component 0
+     LongUnaryOperator components = nodeId -> 0L;
+
+     var nodesSorted = HugeLongArray.newArray(20);
+     nodesSorted.setAll(x -> x);
+     // fixed seed keeps the shuffled order reproducible across runs
+     ShuffleUtil.shuffleArray(nodesSorted, new SplittableRandom(92));
+
+     var upperBoundPerComponent = HugeAtomicLongArray.of(1, ParalleLongPageCreator.passThrough(4));
+     upperBoundPerComponent.set(0, 19);
+
+     // only iterator(..) runs real code; the getters it is stubbed alongside return the data prepared above
+     ComponentNodes componentNodesMock = Mockito.mock(ComponentNodes.class);
+     Mockito.doReturn(components).when(componentNodesMock).getComponents();
+     Mockito.doReturn(upperBoundPerComponent).when(componentNodesMock).getUpperBoundPerComponent();
+     Mockito.doReturn(nodesSorted).when(componentNodesMock).getNodesSorted();
+     Mockito.doCallRealMethod().when(componentNodesMock).iterator(0L,11L);
+
+     Set<Long> resultingNodes = new HashSet<>();
+     Iterator<Long> iterator = componentNodesMock.iterator(0L, 11L);
+     iterator.forEachRemaining(resultingNodes::add);
+     assertThat(resultingNodes).containsExactlyInAnyOrder(11L, 12L, 13L, 14L, 15L, 16L, 17L, 18L, 19L);
+ }
+}
diff --git a/algo/src/test/java/org/neo4j/gds/similarity/nodesim/ComponentPropertyNodeSimilarityTest.java b/algo/src/test/java/org/neo4j/gds/similarity/nodesim/ComponentPropertyNodeSimilarityTest.java
new file mode 100644
index 0000000000..ff71e618dc
--- /dev/null
+++ b/algo/src/test/java/org/neo4j/gds/similarity/nodesim/ComponentPropertyNodeSimilarityTest.java
@@ -0,0 +1,216 @@
+/*
+ * Copyright (c) "Neo4j"
+ * Neo4j Sweden AB [http://neo4j.com]
+ *
+ * This file is part of Neo4j.
+ *
+ * Neo4j is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.neo4j.gds.similarity.nodesim;
+
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.neo4j.gds.Orientation;
+import org.neo4j.gds.api.Graph;
+import org.neo4j.gds.core.GraphDimensions;
+import org.neo4j.gds.core.ImmutableGraphDimensions;
+import org.neo4j.gds.core.concurrency.DefaultPool;
+import org.neo4j.gds.core.utils.mem.MemoryEstimations;
+import org.neo4j.gds.core.utils.mem.MemoryRange;
+import org.neo4j.gds.core.utils.mem.MemoryTree;
+import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
+import org.neo4j.gds.extension.GdlExtension;
+import org.neo4j.gds.extension.GdlGraph;
+import org.neo4j.gds.extension.Inject;
+import org.neo4j.gds.extension.TestGraph;
+import org.neo4j.gds.similarity.SimilarityResult;
+
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.params.provider.Arguments.arguments;
+import static org.neo4j.gds.Orientation.NATURAL;
+import static org.neo4j.gds.Orientation.REVERSE;
+import static org.neo4j.gds.TestSupport.crossArguments;
+import static org.neo4j.gds.TestSupport.toArguments;
+import static org.neo4j.gds.similarity.nodesim.NodeSimilarityBaseConfig.TOP_K_DEFAULT;
+import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
+
+ /**
+  * Tests NodeSimilarity with {@code considerComponents} enabled and the components taken from the
+  * pre-computed node property {@code compid}, covering both the memory estimation and the computed results.
+  */
+ @GdlExtension
+ public class ComponentPropertyNodeSimilarityTest {
+
+     // idOffset is fixed to 0 because the expected results below hard-code node ids (a=0 ... e=4, i1=5 ... i5=9)
+     @GdlGraph(graphNamePrefix = "natural", orientation = NATURAL, idOffset = 0)
+     @GdlGraph(graphNamePrefix = "reverse", orientation = REVERSE, idOffset = 0)
+     private static final String DB_CYPHER =
+         "CREATE" +
+         " (a:Person {compid: 0})" +
+         ", (b:Person {compid: 0})" +
+         ", (c:Person {compid: 0})" +
+         ", (d:Person {compid: 0})" +
+         ", (e:Person {compid: 1})" +
+         ", (i1:Item {compid: 0})" +
+         ", (i2:Item {compid: 0})" +
+         ", (i3:Item {compid: 0})" +
+         ", (i4:Item {compid: 1})" +
+         ", (i5:Item {compid: 1})" +
+         ", (a)-[:LIKES {prop: 1.0}]->(i1)" +
+         ", (a)-[:LIKES {prop: 1.0}]->(i2)" +
+         ", (a)-[:LIKES {prop: 2.0}]->(i3)" +
+         ", (b)-[:LIKES {prop: 1.0}]->(i1)" +
+         ", (b)-[:LIKES {prop: 1.0}]->(i2)" +
+         ", (c)-[:LIKES {prop: 1.0}]->(i3)" +
+         ", (d)-[:LIKES {prop: 0.5}]->(i1)" +
+         ", (d)-[:LIKES {prop: 1.0}]->(i2)" +
+         ", (d)-[:LIKES {prop: 1.0}]->(i3)" +
+         ", (e)-[:LIKES {prop: 1.0}]->(i4)" +
+         ", (e)-[:LIKES {prop: 1.0}]->(i5)";
+     // expected results, rendered with resultString(..), for the NATURAL and the REVERSE graph respectively
+     private static final Collection<String> EXPECTED_OUTGOING_COMP_OPT = new HashSet<>();
+     private static final Collection<String> EXPECTED_INCOMING_COMP_OPT = new HashSet<>();
+
+     static {
+         EXPECTED_OUTGOING_COMP_OPT.add(resultString(0, 1, 2 / 3.0));
+         EXPECTED_OUTGOING_COMP_OPT.add(resultString(0, 2, 1 / 3.0));
+         EXPECTED_OUTGOING_COMP_OPT.add(resultString(0, 3, 1.0));
+         EXPECTED_OUTGOING_COMP_OPT.add(resultString(1, 2, 0.0));
+         EXPECTED_OUTGOING_COMP_OPT.add(resultString(1, 3, 2 / 3.0));
+         EXPECTED_OUTGOING_COMP_OPT.add(resultString(2, 3, 1 / 3.0));
+         // Add results in reverse direction because topK
+         EXPECTED_OUTGOING_COMP_OPT.add(resultString(1, 0, 2 / 3.0));
+         EXPECTED_OUTGOING_COMP_OPT.add(resultString(2, 0, 1 / 3.0));
+         EXPECTED_OUTGOING_COMP_OPT.add(resultString(3, 0, 1.0));
+         EXPECTED_OUTGOING_COMP_OPT.add(resultString(2, 1, 0.0));
+         EXPECTED_OUTGOING_COMP_OPT.add(resultString(3, 1, 2 / 3.0));
+         EXPECTED_OUTGOING_COMP_OPT.add(resultString(3, 2, 1 / 3.0));
+
+         EXPECTED_INCOMING_COMP_OPT.add(resultString(9, 8, 1.0));
+         EXPECTED_INCOMING_COMP_OPT.add(resultString(5, 6, 1.0));
+         EXPECTED_INCOMING_COMP_OPT.add(resultString(5, 7, 1 / 2.0));
+         EXPECTED_INCOMING_COMP_OPT.add(resultString(6, 7, 1 / 2.0));
+         // Add results in reverse direction because topK
+         EXPECTED_INCOMING_COMP_OPT.add(resultString(8, 9, 1.0));
+         EXPECTED_INCOMING_COMP_OPT.add(resultString(6, 5, 1.0));
+         EXPECTED_INCOMING_COMP_OPT.add(resultString(7, 5, 1 / 2.0));
+         EXPECTED_INCOMING_COMP_OPT.add(resultString(7, 6, 1 / 2.0));
+     }
+
+     @Inject
+     private TestGraph naturalGraph;
+     @Inject
+     private TestGraph reverseGraph;
+
+     /** Renders a node pair and its similarity as {@code "node1,node2 similarity"} for set-based comparison. */
+     private static String resultString(long node1, long node2, double similarity) {
+         return formatWithLocale("%d,%d %f", node1, node2, similarity);
+     }
+
+     /** Renders a {@link SimilarityResult} in the same format as the expected-result entries. */
+     private static String resultString(SimilarityResult result) {
+         return resultString(result.node1, result.node2, result.similarity);
+     }
+
+     /** Concurrency values every orientation is tested with. */
+     private static Stream<Integer> concurrencies() {
+         return Stream.of(1, 4);
+     }
+
+     /** Cross product of the orientations NATURAL and REVERSE with {@link #concurrencies()}. */
+     static Stream<Arguments> supportedLoadAndComputeDirections() {
+         Stream<Arguments> directions = Stream.of(
+             arguments(NATURAL),
+             arguments(REVERSE)
+         );
+         return crossArguments(() -> directions, toArguments(ComponentPropertyNodeSimilarityTest::concurrencies));
+     }
+
+     /**
+      * Compares the memory estimation for {@code considerComponents = true} against hand-computed values;
+      * the "component mapping" entry is only expected when a component property is configured.
+      */
+     @ParameterizedTest(name = "componentProperty = {0}")
+     @ValueSource(booleans = {true, false})
+     void shouldComputeMemrecWithOrWithoutComponentMapping(boolean componentPropertySet) {
+         GraphDimensions dimensions = ImmutableGraphDimensions.builder()
+             .nodeCount(1_000_000)
+             .relCountUpperBound(5_000_000)
+             .build();
+
+         NodeSimilarityWriteConfig config = ImmutableNodeSimilarityWriteConfig
+             .builder()
+             .similarityCutoff(0.0)
+             .topK(TOP_K_DEFAULT)
+             .writeProperty("writeProperty")
+             .writeRelationshipType("writeRelationshipType")
+             .considerComponents(true)
+             .componentProperty(componentPropertySet ? "compid" : null)
+             .build();
+
+         MemoryTree actual = new NodeSimilarityFactory<>().memoryEstimation(config).estimate(dimensions, 1);
+
+         long nodeFilterRangeMin = 125_016L;
+         long nodeFilterRangeMax = 125_016L;
+         MemoryRange nodeFilterRange = MemoryRange.of(nodeFilterRangeMin, nodeFilterRangeMax);
+
+         long vectorsRangeMin = 56_000_016L;
+         long vectorsRangeMax = 56_000_016L;
+         MemoryRange vectorsRange = MemoryRange.of(vectorsRangeMin, vectorsRangeMax);
+
+         long weightsRangeMin = 16L;
+         long weightsRangeMax = 56_000_016L;
+         MemoryRange weightsRange = MemoryRange.of(weightsRangeMin, weightsRangeMax);
+
+         MemoryEstimations.Builder builder = MemoryEstimations.builder()
+             .fixed("upper bound per component", 8000040)
+             .fixed("nodes sorted by component", 8000040)
+             .fixed("node filter", nodeFilterRange)
+             .fixed("vectors", vectorsRange)
+             .fixed("weights", weightsRange)
+             .fixed("similarityComputer", 8);
+         if (componentPropertySet) {
+             builder.fixed("component mapping", 8000040);
+         }
+
+         long topKMapRangeMin = 248_000_016L;
+         long topKMapRangeMax = 248_000_016L;
+         builder.fixed("topK map", MemoryRange.of(topKMapRangeMin, topKMapRangeMax));
+
+         MemoryTree expected = builder.build().estimate(dimensions, 1);
+
+         assertEquals(expected.memoryUsage(), actual.memoryUsage());
+     }
+
+     /**
+      * Runs NodeSimilarity with components read from the {@code compid} property and compares the streamed
+      * results with the expected set for the given orientation.
+      */
+     @ParameterizedTest(name = "orientation: {0}, concurrency: {1}")
+     @MethodSource("supportedLoadAndComputeDirections")
+     void shouldOptimizeForDistinctComponentsProperty(Orientation orientation, int concurrency) {
+         Graph graph = orientation == NATURAL ? naturalGraph : reverseGraph;
+         var config = ImmutableNodeSimilarityStreamConfig.builder()
+             .similarityCutoff(0.0)
+             .considerComponents(true)
+             .componentProperty("compid")
+             .concurrency(concurrency)
+             .build();
+
+         var nodeSimilarity = NodeSimilarity.create(
+             graph,
+             config,
+             DefaultPool.INSTANCE,
+             ProgressTracker.NULL_TRACKER
+         );
+
+         Set<String> result = nodeSimilarity
+             .compute()
+             .streamResult()
+             .map(ComponentPropertyNodeSimilarityTest::resultString)
+             .collect(Collectors.toSet());
+
+         assertEquals(orientation == REVERSE ? EXPECTED_INCOMING_COMP_OPT : EXPECTED_OUTGOING_COMP_OPT, result);
+     }
+ }
diff --git a/algo/src/test/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityTest.java b/algo/src/test/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityTest.java
index 21f630ccfc..030c7aaacd 100644
--- a/algo/src/test/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityTest.java
+++ b/algo/src/test/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityTest.java
@@ -74,7 +74,7 @@
@GdlExtension
final class NodeSimilarityTest {
- // fixing idOffset to 0 as the expecatations hard-code ids
+ // fixing idOffset to 0 as the expectations hard-code ids
@GdlGraph(graphNamePrefix = "natural", orientation = NATURAL, idOffset = 0)
@GdlGraph(graphNamePrefix = "reverse", orientation = REVERSE, idOffset = 0)
@GdlGraph(graphNamePrefix = "undirected", orientation = UNDIRECTED, idOffset = 0)
@@ -685,7 +685,6 @@ void shouldComputeMemrec(int topK) {
MemoryRange weightsRange = MemoryRange.of(weightsRangeMin, weightsRangeMax);
MemoryEstimations.Builder builder = MemoryEstimations.builder()
- .fixed("components", 8000040)
.fixed("node filter", nodeFilterRange)
.fixed("vectors", vectorsRange)
.fixed("weights", weightsRange)
@@ -743,7 +742,6 @@ void shouldComputeMemrecWithTop(int topK) {
MemoryRange topNListRange = MemoryRange.of(topNListMin, topNListMax);
MemoryEstimations.Builder builder = MemoryEstimations.builder()
- .fixed("components", 8000040)
.fixed("node filter", nodeFilterRange)
.fixed("vectors", vectorsRange)
.fixed("weights", weightsRange)
@@ -784,8 +782,8 @@ void shouldComputeMemrecWithTopKAndTopNGreaterThanNodeCount() {
MemoryTree actual = new NodeSimilarityFactory<>().memoryEstimation(config).estimate(dimensions, 1);
- assertEquals(571432, actual.memoryUsage().min);
- assertEquals(733032, actual.memoryUsage().max);
+ assertEquals(570592, actual.memoryUsage().min);
+ assertEquals(732192, actual.memoryUsage().max);
}
@@ -881,6 +879,46 @@ void shouldLogProgress(int concurrency) {
);
}
+ /**
+  * Runs NodeSimilarity with {@code considerComponents} enabled and without setting a component property,
+  * then checks the number of recorded progress entries and that the WCC, prepare and compare tasks are logged.
+  */
+ @Test
+ void shouldLogProgressForWccOptimization() {
+     var graph = naturalGraph;
+     var config = ImmutableNodeSimilarityStreamConfig.builder()
+         .considerComponents(true)
+         .concurrency(4)
+         .build();
+     var progressTask = new NodeSimilarityFactory<>().progressTask(graph, config);
+     TestLog log = Neo4jProxy.testLog();
+     var progressTracker = new TestProgressTracker(
+         progressTask,
+         log,
+         4,
+         EmptyTaskRegistryFactory.INSTANCE
+     );
+
+     // the stream has to be consumed, otherwise nothing is computed and nothing is logged
+     NodeSimilarity.create(
+         graph,
+         config,
+         DefaultPool.INSTANCE,
+         progressTracker
+     ).compute().streamResult().count();
+
+     // var keeps the element type declared by getProgresses() instead of degrading it to a raw List
+     var progresses = progressTracker.getProgresses();
+
+     // Should log progress for prepare and actual comparisons
+     assertThat(progresses).hasSize(6);
+
+     assertThat(log.getMessages(INFO))
+         .extracting(removingThreadId())
+         .contains(
+             "NodeSimilarity :: prepare :: WCC :: Start",
+             "NodeSimilarity :: prepare :: WCC :: Finished",
+             "NodeSimilarity :: prepare :: Start",
+             "NodeSimilarity :: prepare :: Finished",
+             "NodeSimilarity :: compare node pairs :: Start",
+             "NodeSimilarity :: compare node pairs :: Finished"
+         );
+ }
+
@Test
void shouldGiveCorrectResultsWithOverlap() {
var gdl =
@@ -981,42 +1019,4 @@ void shouldThrowIfUpperIsSmaller() {
assertThatThrownBy(streamConfigBuilder().upperDegreeCutoff(3).degreeCutoff(4)::build)
.hasMessageContaining("upperDegreeCutoff cannot be smaller than degreeCutoff");
}
-
-
- @Test
- void shouldOptimizeForDistinctComponents() {
- var graph = naturalGraph;
- var config = ImmutableNodeSimilarityStreamConfig.builder().isEnableComponentOptimization(true).degreeCutoff(0).concurrency(4).build();
- var progressTask = new NodeSimilarityFactory<>().progressTask(graph, config);
- TestLog log = Neo4jProxy.testLog();
- var progressTracker = new TestProgressTracker(
- progressTask,
- log,
- 4,
- EmptyTaskRegistryFactory.INSTANCE
- );
-
- NodeSimilarity.create(
- graph,
- config,
- DefaultPool.INSTANCE,
- progressTracker
- ).compute().streamResult().count();
-
- List progresses = progressTracker.getProgresses();
-
- // Should log progress for prepare and actual comparisons
- assertThat(progresses).hasSize(6);
-
- assertThat(log.getMessages(INFO))
- .extracting(removingThreadId())
- .contains(
- "NodeSimilarity :: prepare :: WCC :: Start",
- "NodeSimilarity :: prepare :: WCC :: Finished",
- "NodeSimilarity :: prepare :: Start",
- "NodeSimilarity :: prepare :: Finished",
- "NodeSimilarity :: compare node pairs :: Start",
- "NodeSimilarity :: compare node pairs :: Finished"
- );
- }
}
diff --git a/doc/modules/ROOT/pages/algorithms/filtered-node-similarity.adoc b/doc/modules/ROOT/pages/algorithms/filtered-node-similarity.adoc
index 0b39d0794e..078e55e454 100644
--- a/doc/modules/ROOT/pages/algorithms/filtered-node-similarity.adoc
+++ b/doc/modules/ROOT/pages/algorithms/filtered-node-similarity.adoc
@@ -327,7 +327,7 @@ YIELD nodeCount, relationshipCount, bytesMin, bytesMax, requiredMemory
[opts="header",cols="1,1,1,1,1"]
|===
| nodeCount | relationshipCount | bytesMin | bytesMax | requiredMemory
-| 9 | 9 | 2496 | 2712 | "[2496 Bytes \... 2712 Bytes]"
+| 9 | 9 | 2384 | 2600 | "[2384 Bytes \... 2600 Bytes]"
|===
--
[[algorithms-filtered-node-similarity-examples-stream]]
diff --git a/doc/modules/ROOT/pages/algorithms/node-similarity.adoc b/doc/modules/ROOT/pages/algorithms/node-similarity.adoc
index 72cc1f5bd7..a3e10434e7 100644
--- a/doc/modules/ROOT/pages/algorithms/node-similarity.adoc
+++ b/doc/modules/ROOT/pages/algorithms/node-similarity.adoc
@@ -333,7 +333,7 @@ YIELD nodeCount, relationshipCount, bytesMin, bytesMax, requiredMemory
[opts="header",cols="1,1,1,1,1"]
|===
| nodeCount | relationshipCount | bytesMin | bytesMax | requiredMemory
-| 9 | 9 | 2496 | 2712 | "[2496 Bytes \... 2712 Bytes]"
+| 9 | 9 | 2384 | 2600 | "[2384 Bytes \... 2600 Bytes]"
|===
--
diff --git a/doc/modules/ROOT/partials/algorithms/node-similarity/specific-configuration.adoc b/doc/modules/ROOT/partials/algorithms/node-similarity/specific-configuration.adoc
index 99ee46363a..b4e0ff5475 100644
--- a/doc/modules/ROOT/partials/algorithms/node-similarity/specific-configuration.adoc
+++ b/doc/modules/ROOT/partials/algorithms/node-similarity/specific-configuration.adoc
@@ -21,3 +21,5 @@ If unspecified, the algorithm runs unweighted.
| similarityMetric
| String | JACCARD | yes | The metric used to compute similarity.
Can be either `JACCARD`, `OVERLAP` or `COSINE`.
+| [[consider-components]] considerComponents | Boolean | false | yes | If enabled, applies an optimization that can increase performance for multi-component graphs. It makes use of the fact that nodes of distinct components always have a similarity of 0. If the components are not already provided through xref:#component-property[componentProperty], the algorithm internally runs xref:algorithms/wcc.adoc[WCC].
+| [[component-property]] componentProperty | String | null | yes | Name of the pre-computed node property to use for the xref:#consider-components[component optimization], if that optimization is enabled and pre-computed values are available.
diff --git a/proc/catalog/src/test/java/org/neo4j/gds/beta/generator/GraphGenerateProcTest.java b/proc/catalog/src/test/java/org/neo4j/gds/beta/generator/GraphGenerateProcTest.java
index 233ce61920..f232bed834 100644
--- a/proc/catalog/src/test/java/org/neo4j/gds/beta/generator/GraphGenerateProcTest.java
+++ b/proc/catalog/src/test/java/org/neo4j/gds/beta/generator/GraphGenerateProcTest.java
@@ -188,9 +188,9 @@ void shouldGenerateGraphWithRelationshipProperty() {
private static Stream estimations() {
return Stream.of(
- Arguments.of(100, 2, MemoryRange.of(28_928, 32_128)),
- Arguments.of(100, 4, MemoryRange.of(30_528, 35_328)),
- Arguments.of(200, 4, MemoryRange.of(60_944, 70_544))
+ Arguments.of(100, 2, MemoryRange.of(28_088, 31_288)),
+ Arguments.of(100, 4, MemoryRange.of(29_688, 34_488)),
+ Arguments.of(200, 4, MemoryRange.of(59_304, 68_904))
);
}
}