diff --git a/algo/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityFactory.java b/algo/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityFactory.java index 95c0b5bd4f..710de17fe9 100644 --- a/algo/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityFactory.java +++ b/algo/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityFactory.java @@ -24,6 +24,7 @@ import org.neo4j.gds.api.Graph; import org.neo4j.gds.collections.ha.HugeLongArray; import org.neo4j.gds.collections.ha.HugeObjectArray; +import org.neo4j.gds.collections.haa.HugeAtomicLongArray; import org.neo4j.gds.core.concurrency.DefaultPool; import org.neo4j.gds.core.utils.mem.MemoryEstimation; import org.neo4j.gds.core.utils.mem.MemoryEstimations; @@ -74,7 +75,6 @@ public MemoryEstimation memoryEstimation(CONFIG config) { int topK = Math.abs(config.normalizedK()); MemoryEstimations.Builder builder = MemoryEstimations.builder(NodeSimilarity.class.getSimpleName()) - .perNode("components", HugeLongArray::memoryEstimation) .perNode("node filter", nodeCount -> sizeOfLongArray(BitSet.bits2words(nodeCount))) .add( "vectors", @@ -97,6 +97,13 @@ public MemoryEstimation memoryEstimation(CONFIG config) { .rangePerNode("array", nodeCount -> MemoryRange.of(0, nodeCount * averageVectorSize)) .build(); })); + if (config.considerComponents()) { + builder.perNode("nodes sorted by component", HugeLongArray::memoryEstimation); + builder.perNode("upper bound per component", HugeAtomicLongArray::memoryEstimation); + } + if (config.considerComponents() && config.componentProperty() != null) { + builder.perNode("component mapping", HugeLongArray::memoryEstimation); + } if (config.computeToGraph() && !config.hasTopK()) { builder.add( "similarity graph", diff --git a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/ComponentNodes.java b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/ComponentNodes.java new file mode 100644 index 0000000000..8c8683464a --- /dev/null +++ b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/ComponentNodes.java @@ -0,0 +1,170 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.similarity.nodesim; + +import org.neo4j.gds.collections.ha.HugeLongArray; +import org.neo4j.gds.collections.haa.HugeAtomicLongArray; +import org.neo4j.gds.core.concurrency.ParallelUtil; +import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator; +import org.neo4j.gds.termination.TerminationFlag; + +import java.util.NoSuchElementException; +import java.util.PrimitiveIterator; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.LongUnaryOperator; + +/** + * Manages nodes sorted by component. Produces an iterator over all nodes in a given component. + */ +public final class ComponentNodes { + private final LongUnaryOperator components; + private final HugeAtomicLongArray upperBoundPerComponent; + private final HugeLongArray nodesSorted; + + private ComponentNodes(LongUnaryOperator components, HugeAtomicLongArray upperBoundPerComponent, + HugeLongArray nodesSorted) { + + this.components = components; + this.upperBoundPerComponent = upperBoundPerComponent; + this.nodesSorted = nodesSorted; + } + + public static ComponentNodes create(LongUnaryOperator components, long nodeCount, int concurrency) { + var upperBoundPerComponent = computeIndexUpperBoundPerComponent(components, nodeCount, concurrency); + var nodesSorted = computeNodesSortedByComponent(components, upperBoundPerComponent, concurrency); + return new ComponentNodes( + components, + upperBoundPerComponent, + nodesSorted + ); + } + + public PrimitiveIterator.OfLong iterator(long componentId, long offset) { + return new Iterator(componentId, offset); + } + + public Spliterator.OfLong spliterator(long componentId, long offset) { + return Spliterators.spliteratorUnknownSize( + iterator(componentId, offset), + Spliterator.ORDERED | Spliterator.SORTED | Spliterator.IMMUTABLE | Spliterator.NONNULL | Spliterator.DISTINCT + ); + } + + LongUnaryOperator getComponents() { + return components; + } + + HugeAtomicLongArray getUpperBoundPerComponent() { + return upperBoundPerComponent; + } + + HugeLongArray getNodesSorted() { + return nodesSorted; + } + + static HugeAtomicLongArray computeIndexUpperBoundPerComponent(LongUnaryOperator components, long nodeCount, + int concurrency) { + + var upperBoundPerComponent = HugeAtomicLongArray.of(nodeCount, ParalleLongPageCreator.passThrough(concurrency)); + + // init coordinate array to contain the nr of nodes per component + // i.e. comp1 containing 3 nodes, comp2 containing 20 nodes: {(comp1, 3), (comp2, 20)} + ParallelUtil.parallelForEachNode(nodeCount, concurrency, TerminationFlag.RUNNING_TRUE, nodeId -> { + { + long componentId = components.applyAsLong(nodeId); + upperBoundPerComponent.getAndAdd(componentId, 1); + } + }); + AtomicLong atomicNodeSum = new AtomicLong(); + // modify coordinate array to contain the upper bound of the global index for each component + // i.e. comp1 containing 3 nodes, comp2 containing 20 nodes, comp1 randomly accessed prior to comp2: + // {(comp1, 2), (comp2, 22)} + ParallelUtil.parallelForEachNode(nodeCount, concurrency, TerminationFlag.RUNNING_TRUE, componentId -> + { + if (upperBoundPerComponent.get(componentId) > 0) { + var nodeSum = atomicNodeSum.addAndGet(upperBoundPerComponent.get(componentId)); + upperBoundPerComponent.set(componentId, nodeSum - 1); + } + }); + + return upperBoundPerComponent; + } + + static HugeLongArray computeNodesSortedByComponent(LongUnaryOperator components, + HugeAtomicLongArray idxUpperBoundPerComponent, int concurrency) { + + // initialized to its max possible size of 1 node <=> 1 component in a disconnected graph + long nodeCount = idxUpperBoundPerComponent.size(); + var nodesSortedByComponent = HugeLongArray.newArray(nodeCount); + var nodeIdxProviderArray = HugeAtomicLongArray.of(nodeCount, ParalleLongPageCreator.passThrough(concurrency)); + idxUpperBoundPerComponent.copyTo(nodeIdxProviderArray, nodeCount); + + // fill nodesSortedByComponent with nodeId per component-sorted, unique index + // i.e. comp1 containing 3 nodes, comp2 containing 20 nodes, named in order of processing: + // {(0, n3), (1, n2), (2, n1), (3, n23), .., (22, n4)} + ParallelUtil.parallelForEachNode(nodeCount, concurrency, TerminationFlag.RUNNING_TRUE, indexId -> + { + long nodeId = nodeCount - indexId - 1; + long componentId = components.applyAsLong(nodeId); + long nodeIdx = nodeIdxProviderArray.getAndAdd(componentId, -1); + nodesSortedByComponent.set(nodeIdx, nodeId); + }); + + return nodesSortedByComponent; + } + + private final class Iterator implements PrimitiveIterator.OfLong { + private final long offset; + long runningIdx; + final long componentId; + + Iterator(long componentId, long offset) { + this.componentId = componentId; + this.runningIdx = getUpperBoundPerComponent().get(componentId); + this.offset = offset; + } + + @Override + public boolean hasNext() { + if (offset < 1L) { + return runningIdx > -1 && getComponents().applyAsLong(getNodesSorted().get(runningIdx)) == componentId; + } else { + while (runningIdx > -1 && getComponents().applyAsLong(getNodesSorted().get(runningIdx)) == componentId) { + if (getNodesSorted().get(runningIdx) < offset) { + runningIdx--; + } else { + return true; + } + } + return false; + } + } + @Override + public long nextLong() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return getNodesSorted().get(runningIdx--); + } + + } +} diff --git a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarity.java b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarity.java index a787ff5f91..4c1a0b43f3 100644 --- a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarity.java +++ b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarity.java @@ -28,6 +28,7 @@ import org.neo4j.gds.collections.ha.HugeObjectArray; import org.neo4j.gds.core.concurrency.ParallelUtil; import org.neo4j.gds.core.utils.SetBitsIterable; +import org.neo4j.gds.core.utils.paged.HugeLongLongMap; import org.neo4j.gds.core.utils.paged.dss.DisjointSetStruct; import org.neo4j.gds.core.utils.progress.BatchingProgressLogger; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; @@ -43,14 +44,20 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.LongUnaryOperator; import java.util.stream.LongStream; import java.util.stream.Stream; +import java.util.stream.StreamSupport; public class NodeSimilarity extends Algorithm { private final Graph graph; - private final boolean sortVectors; private final NodeSimilarityBaseConfig config; + private final boolean sortVectors; + private final boolean weighted; private final BitSet sourceNodes; private final BitSet targetNodes; @@ -60,12 +67,12 @@ public class NodeSimilarity extends Algorithm { private final ExecutorService executorService; private final int concurrency; private final MetricSimilarityComputer similarityComputer; + private HugeObjectArray neighbors; private HugeObjectArray weights; - private HugeLongArray components; - private SimilarityPairTriConsumer similarityConsumer; - - private final boolean weighted; + private LongUnaryOperator components; + private Function sourceNodesStream; + private BiFunction targetNodesStream; public static NodeSimilarity create( Graph graph, @@ -197,47 +204,16 @@ private SimilarityGraphResult computeToGraph() { private void prepare() { progressTracker.beginSubTask(); - initComponents(); + components = initComponents(); + sourceNodesStream = initSourceNodesStream(); + targetNodesStream = initTargetNodesStream(); if (config.runWCC()) { progressTracker.beginSubTask(); } - neighbors = HugeObjectArray.newArray(long[].class, graph.nodeCount()); - if (weighted) { - weights = HugeObjectArray.newArray(double[].class, graph.nodeCount()); - } - DegreeComputer degreeComputer = new DegreeComputer(); - VectorComputer vectorComputer = VectorComputer.of(graph, weighted); - DegreeFilter degreeFilter = new DegreeFilter(config.degreeCutoff(), config.upperDegreeCutoff()); - neighbors.setAll(node -> { - graph.forEachRelationship(node, degreeComputer); - int degree = degreeComputer.degree; - degreeComputer.reset(); - vectorComputer.reset(degree); + initNodeSpecificFields(); - progressTracker.logProgress(graph.degree(node)); - if (degreeFilter.apply(degree)) { - if (sourceNodeFilter.test(node)) { - sourceNodes.set(node); - } - if (targetNodeFilter.test(node)) { - targetNodes.set(node); - } - - // TODO: we don't need to do the rest of the prepare for a node that isn't going to be used in the computation - vectorComputer.forEachRelationship(node); - - if (sortVectors) { - vectorComputer.sortTargetIds(); - } - if (weighted) { - weights.set(node, vectorComputer.getWeights()); - } - return vectorComputer.targetIds.buffer; - } - return null; - }); if (config.runWCC()) { progressTracker.endSubTask(); } @@ -262,24 +238,16 @@ private Stream computeParallel() { } } - private void initComponents() { - components = HugeLongArray.newArray(graph.nodeCount()); - if (!config.isEnableComponentOptimization()) { + private LongUnaryOperator initComponents() { + if (!config.considerComponents()) { // considering everything as within the same component - similarityConsumer = this::computeSimilarityForSingleComponent; - return; + return n -> 0; } - similarityConsumer = this::computeSimilarityForComponents; - if (config.componentProperty() != null) { - NodePropertyValues nodeProperties = graph.nodeProperties(config.componentProperty()); // extract component info from property - graph.forEachNode(n -> { - components.set(n, nodeProperties.longValue(n)); - return true; - }); - return; + NodePropertyValues nodeProperties = graph.nodeProperties(config.componentProperty()); + return initComponentIdMapping(graph, nodeProperties::longValue); } // run WCC to determine components @@ -293,17 +261,53 @@ private void initComponents() { Wcc wcc = new WccAlgorithmFactory<>().build(graph, wccConfig, ProgressTracker.NULL_TRACKER); DisjointSetStruct disjointSets = wcc.compute(); - graph.forEachNode(n -> { - components.set(n, disjointSets.setIdOf(n)); - return true; - }); progressTracker.endSubTask(); + return disjointSets::setIdOf; + } + + private void initNodeSpecificFields() { + neighbors = HugeObjectArray.newArray(long[].class, graph.nodeCount()); + if (weighted) { + weights = HugeObjectArray.newArray(double[].class, graph.nodeCount()); + } + + DegreeComputer degreeComputer = new DegreeComputer(); + VectorComputer vectorComputer = VectorComputer.of(graph, weighted); + DegreeFilter degreeFilter = new DegreeFilter(config.degreeCutoff(), config.upperDegreeCutoff()); + neighbors.setAll(node -> { + graph.forEachRelationship(node, degreeComputer); + int degree = degreeComputer.degree; + degreeComputer.reset(); + vectorComputer.reset(degree); + + progressTracker.logProgress(graph.degree(node)); + if (degreeFilter.apply(degree)) { + if (sourceNodeFilter.test(node)) { + sourceNodes.set(node); + } + if (targetNodeFilter.test(node)) { + targetNodes.set(node); + } + + // TODO: we don't need to do the rest of the prepare for a node that isn't going to be used in the computation + vectorComputer.forEachRelationship(node); + + if (sortVectors) { + vectorComputer.sortTargetIds(); + } + if (weighted) { + weights.set(node, vectorComputer.getWeights()); + } + return vectorComputer.targetIds.buffer; + } + return null; + }); } private Stream computeAll() { progressTracker.beginSubTask(calculateWorkload()); - var similarityResultStream = loggableAndTerminatableSourceNodeStream() + var similarityResultStream = loggableAndTerminableSourceNodeStream() .boxed() .flatMap(this::computeSimilaritiesForNode); progressTracker.endSubTask(); @@ -312,7 +316,7 @@ private Stream computeAll() { private Stream computeAllParallel() { return ParallelUtil.parallelStream( - loggableAndTerminatableSourceNodeStream(), concurrency, stream -> stream + loggableAndTerminableSourceNodeStream(), concurrency, stream -> stream .boxed() .flatMap(this::computeSimilaritiesForNode) ); @@ -326,20 +330,20 @@ private TopKMap computeTopKMap() { : SimilarityResult.ASCENDING; var topKMap = new TopKMap(neighbors.size(), sourceNodes, Math.abs(config.normalizedK()), comparator); - loggableAndTerminatableSourceNodeStream() + loggableAndTerminableSourceNodeStream() .forEach(sourceNodeId -> { if (sourceNodeFilter.equals(NodeFilter.noOp)) { - targetNodesStream(sourceNodeId + 1) - .forEach(targetNodeId -> similarityConsumer.accept(sourceNodeId, targetNodeId, + targetNodesStream.apply(components.applyAsLong(sourceNodeId), sourceNodeId + 1) + .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, (source, target, similarity) -> { topKMap.put(source, target, similarity); topKMap.put(target, source, similarity); } )); } else { - targetNodesStream() + targetNodesStream.apply(components.applyAsLong(sourceNodeId), 0L) .filter(targetNodeId -> sourceNodeId != targetNodeId) - .forEach(targetNodeId -> similarityConsumer.accept(sourceNodeId, targetNodeId, topKMap::put)); + .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topKMap::put)); } }); progressTracker.endSubTask(); @@ -355,7 +359,7 @@ private TopKMap computeTopKMapParallel() { var topKMap = new TopKMap(neighbors.size(), sourceNodes, Math.abs(config.normalizedK()), comparator); ParallelUtil.parallelStreamConsume( - loggableAndTerminatableSourceNodeStream(), + loggableAndTerminableSourceNodeStream(), concurrency, terminationFlag, stream -> stream @@ -366,9 +370,9 @@ private TopKMap computeTopKMapParallel() { // into these queues is not considered to be thread-safe. // Hence, we need to ensure that down the stream, exactly one queue // within the TopKMap processes all pairs for a single node. - targetNodesStream() + targetNodesStream.apply(components.applyAsLong(sourceNodeId), 0L) .filter(targetNodeId -> sourceNodeId != targetNodeId) - .forEach(targetNodeId -> similarityConsumer.accept(sourceNodeId, targetNodeId, topKMap::put)) + .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topKMap::put)) ) ); @@ -380,15 +384,15 @@ private Stream computeTopN() { progressTracker.beginSubTask(calculateWorkload()); var topNList = new TopNList(config.normalizedN()); - loggableAndTerminatableSourceNodeStream() + loggableAndTerminableSourceNodeStream() .forEach(sourceNodeId -> { if (sourceNodeFilter.equals(NodeFilter.noOp)) { - targetNodesStream(sourceNodeId + 1) - .forEach(targetNodeId -> similarityConsumer.accept(sourceNodeId, targetNodeId, topNList::add)); + targetNodesStream.apply(components.applyAsLong(sourceNodeId), sourceNodeId + 1) + .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topNList::add)); } else { - targetNodesStream() + targetNodesStream.apply(components.applyAsLong(sourceNodeId), 0L) .filter(targetNodeId -> sourceNodeId != targetNodeId) - .forEach(targetNodeId -> similarityConsumer.accept(sourceNodeId, targetNodeId, topNList::add)); + .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topNList::add)); } }); @@ -402,31 +406,30 @@ private Stream computeTopN(TopKMap topKMap) { return topNList.stream(); } - private LongStream sourceNodesStream(long offset) { - return new SetBitsIterable(sourceNodes, offset).stream(); - } - - private LongStream sourceNodesStream() { - return sourceNodesStream(0); + private Function initSourceNodesStream() { + return offset -> new SetBitsIterable(sourceNodes, offset).stream(); } - private LongStream loggableAndTerminatableSourceNodeStream() { - return checkProgress(sourceNodesStream()); - } + private BiFunction initTargetNodesStream() { + if (!config.considerComponents()) { + return (componentId, offset) -> new SetBitsIterable(targetNodes, offset).stream(); + } - private LongStream targetNodesStream(long offset) { - return new SetBitsIterable(targetNodes, offset).stream(); + var componentNodes = ComponentNodes.create(components, graph.nodeCount(), concurrency); + return (componentId, offset) -> StreamSupport + .longStream(componentNodes.spliterator(componentId, offset), true) + .filter(targetNodes::get); } - private LongStream targetNodesStream() { - return targetNodesStream(0); + private LongStream loggableAndTerminableSourceNodeStream() { + return checkProgress(sourceNodesStream.apply(0L)); } private Stream computeSimilaritiesForNode(long sourceNodeId) { - return targetNodesStream(sourceNodeId + 1) + return targetNodesStream.apply(components.applyAsLong(sourceNodeId), sourceNodeId + 1) .mapToObj(targetNodeId -> { var resultHolder = new SimilarityResult[]{null}; - similarityConsumer.accept( + computeSimilarityFor( sourceNodeId, targetNodeId, (source, target, similarity) -> resultHolder[0] = new SimilarityResult(source, target, similarity) @@ -436,11 +439,30 @@ private Stream computeSimilaritiesForNode(long sourceNodeId) { .filter(Objects::nonNull); } + private static LongUnaryOperator initComponentIdMapping(Graph graph, LongUnaryOperator originComponentIdMapper) { + var componentIdMappings = new HugeLongLongMap(); + var mappedComponentId = new AtomicLong(0L); + var mappedComponentIdPerNode = HugeLongArray.newArray(graph.nodeCount()); + graph.forEachNode(n -> { + long originComponentIdForNode = originComponentIdMapper.applyAsLong(n); + long mappedComponentIdForNode = componentIdMappings.getOrDefault(originComponentIdMapper.applyAsLong(n), + mappedComponentId.getAndIncrement()); + + if (!componentIdMappings.containsKey(originComponentIdForNode)) { + componentIdMappings.put(originComponentIdForNode, mappedComponentIdForNode); + } + mappedComponentIdPerNode.set(n, mappedComponentIdForNode); + return true; + }); + + return mappedComponentIdPerNode::get; + } + interface SimilarityConsumer { void accept(long sourceNodeId, long targetNodeId, double similarity); } - private void computeSimilarityForSingleComponent(long sourceNodeId, long targetNodeId, SimilarityConsumer consumer) { + private void computeSimilarityFor(long sourceNodeId, long targetNodeId, SimilarityConsumer consumer) { double similarity; var sourceNodeNeighbors = neighbors.get(sourceNodeId); var targetNodeNeighbors = neighbors.get(targetNodeId); @@ -456,15 +478,6 @@ private void computeSimilarityForSingleComponent(long sourceNodeId, long targetN } } - private void computeSimilarityForComponents(long sourceNodeId, long targetNodeId, SimilarityConsumer consumer) { - if (components.get(sourceNodeId) != components.get(targetNodeId)) { - consumer.accept(sourceNodeId, targetNodeId, 0); - return; - } - - computeSimilarityForSingleComponent(sourceNodeId, targetNodeId, consumer); - } - private double computeWeightedSimilarity( long[] sourceNodeNeighbors, long[] targetNodeNeighbors, diff --git a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityBaseConfig.java b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityBaseConfig.java index dd18f23245..74d3df8383 100644 --- a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityBaseConfig.java +++ b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityBaseConfig.java @@ -52,8 +52,8 @@ public interface NodeSimilarityBaseConfig extends AlgoBaseConfig, RelationshipWe String COMPONENT_PROPERTY_KEY = "componentProperty"; - String ENABLE_COMPONENT_OPTIMIZATION_KEY = "enableComponentOptimization"; - boolean ENABLE_COMPONENT_OPTIMIZATION = false; + String CONSIDER_COMPONENTS_KEY = "considerComponents"; + boolean CONSIDER_COMPONENTS = false; @Value.Default @Configuration.DoubleRange(min = 0, max = 1) @@ -114,8 +114,8 @@ default int bottomN() { default @Nullable String componentProperty() { return null; } @Value.Default - @Configuration.Key(ENABLE_COMPONENT_OPTIMIZATION_KEY) - default boolean isEnableComponentOptimization() { return ENABLE_COMPONENT_OPTIMIZATION; } + @Configuration.Key(CONSIDER_COMPONENTS_KEY) + default boolean considerComponents() { return CONSIDER_COMPONENTS; } @Configuration.Ignore @Value.Derived @@ -203,10 +203,7 @@ default void validateComponentProperty( @Value.Derived default boolean runWCC() { - if (isEnableComponentOptimization() && componentProperty() == null) { - return true; - } - return false; + return considerComponents() && componentProperty() == null; } } diff --git a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityFactory.java b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityFactory.java index ca51f5bbff..eb5ff2aaae 100644 --- a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityFactory.java +++ b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityFactory.java @@ -24,6 +24,7 @@ import org.neo4j.gds.api.Graph; import org.neo4j.gds.collections.ha.HugeLongArray; import org.neo4j.gds.collections.ha.HugeObjectArray; +import org.neo4j.gds.collections.haa.HugeAtomicLongArray; import org.neo4j.gds.core.concurrency.DefaultPool; import org.neo4j.gds.core.utils.mem.MemoryEstimation; import org.neo4j.gds.core.utils.mem.MemoryEstimations; @@ -69,7 +70,6 @@ public MemoryEstimation memoryEstimation(CONFIG config) { int topK = Math.abs(config.normalizedK()); MemoryEstimations.Builder builder = MemoryEstimations.builder(NodeSimilarity.class.getSimpleName()) - .perNode("components", HugeLongArray::memoryEstimation) .perNode("node filter", nodeCount -> sizeOfLongArray(BitSet.bits2words(nodeCount))) .add( "vectors", @@ -92,6 +92,13 @@ public MemoryEstimation memoryEstimation(CONFIG config) { .rangePerNode("array", nodeCount -> MemoryRange.of(0, nodeCount * averageVectorSize)) .build(); })); + if (config.considerComponents()) { + builder.perNode("nodes sorted by component", HugeLongArray::memoryEstimation); + builder.perNode("upper bound per component", HugeAtomicLongArray::memoryEstimation); + } + if (config.considerComponents() && config.componentProperty() != null) { + builder.perNode("component mapping", HugeLongArray::memoryEstimation); + } if (config.computeToGraph() && !config.hasTopK()) { builder.add( "similarity graph", diff --git a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/SimilarityPairTriConsumer.java b/algo/src/main/java/org/neo4j/gds/similarity/nodesim/SimilarityPairTriConsumer.java deleted file mode 100644 index 52a7a1bbfb..0000000000 --- a/algo/src/main/java/org/neo4j/gds/similarity/nodesim/SimilarityPairTriConsumer.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.similarity.nodesim; - -@FunctionalInterface -public interface SimilarityPairTriConsumer { - - void accept(long k, long v, NodeSimilarity.SimilarityConsumer similarityConsumer); - -} diff --git a/algo/src/test/java/org/neo4j/gds/similarity/nodesim/ComponentNodesTest.java b/algo/src/test/java/org/neo4j/gds/similarity/nodesim/ComponentNodesTest.java new file mode 100644 index 0000000000..54f0a32c76 --- /dev/null +++ b/algo/src/test/java/org/neo4j/gds/similarity/nodesim/ComponentNodesTest.java @@ -0,0 +1,268 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.similarity.nodesim; + +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.neo4j.gds.collections.ha.HugeLongArray; +import org.neo4j.gds.collections.haa.HugeAtomicLongArray; +import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator; +import org.neo4j.gds.core.utils.shuffle.ShuffleUtil; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.SplittableRandom; +import java.util.function.LongUnaryOperator; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class ComponentNodesTest { + + private LongUnaryOperator prepare7DistinctSizeComponents() { + return nodeId -> { + if (nodeId < 3) { + return 0L; // size 3 + } else if (nodeId < 8) { + return 1L; // size 5 + } else if (nodeId < 12) { + return 2L; // size 4 + } else if (nodeId < 18) { + return 3L; // size 6 + } else if (nodeId < 19) { + return 4L; // size 1 + } else if (nodeId < 21) { + return 5L; // size 2 + } else { + return 6L; // size 7 + } + }; + } + + @Test + void shouldDetermineIndexUpperBound() { + // nodeId -> componentId + var components = prepare7DistinctSizeComponents(); + // componentId, upperBound + var idxUpperBoundPerComponent = ComponentNodes.computeIndexUpperBoundPerComponent(components, 28, 4); + // we cannot infer which component follows another, but the range must match in size for the component + Map componentPerIdxUpperBound = new HashMap<>(7); + for (int i = 0; i < 7; i++) { + componentPerIdxUpperBound.put(idxUpperBoundPerComponent.get(i), (long) i); + } + assertThat(idxUpperBoundPerComponent.get(7)).isEqualTo(0L); + int previousUpperBound = -1; + for (long key : componentPerIdxUpperBound.keySet().stream().sorted().collect(Collectors.toList())) { + switch ((int) (key - previousUpperBound)) { + case 1: assertThat(componentPerIdxUpperBound.get(key)).isEqualTo(4);break; + case 2: assertThat(componentPerIdxUpperBound.get(key)).isEqualTo(5);break; + case 3: assertThat(componentPerIdxUpperBound.get(key)).isEqualTo(0);break; + case 4: assertThat(componentPerIdxUpperBound.get(key)).isEqualTo(2);break; + case 5: assertThat(componentPerIdxUpperBound.get(key)).isEqualTo(1);break; + case 6: assertThat(componentPerIdxUpperBound.get(key)).isEqualTo(3);break; + case 7: assertThat(componentPerIdxUpperBound.get(key)).isEqualTo(6);break; + } + previousUpperBound = (int) (key); + } + } + + @Test + void shouldComputeNodesSortedByComponent() { + // nodeId -> componentId + var components = prepare7DistinctSizeComponents(); + // componentId, upperIdx of component + var upperBoundPerComponent = HugeAtomicLongArray.of(28, ParalleLongPageCreator.passThrough(4)); + upperBoundPerComponent.set(0, 2); + upperBoundPerComponent.set(1, 7); + upperBoundPerComponent.set(2, 11); + upperBoundPerComponent.set(3, 17); + upperBoundPerComponent.set(4, 18); + upperBoundPerComponent.set(5, 20); + upperBoundPerComponent.set(6, 27); + + var nodesSortedByComponent = ComponentNodes.computeNodesSortedByComponent(components, + upperBoundPerComponent, 4); + + // nodes may occur in arbitrary order within components, but with the given assignment, nodeIds must be within + // component index bounds + assertEquals(28, nodesSortedByComponent.size()); + for (int i = 0; i < 28; i++) { + var currentComp = components.applyAsLong(nodesSortedByComponent.get(i)); + + assertThat(nodesSortedByComponent.get(i)).isGreaterThan(currentComp == 0 ? + -1 : upperBoundPerComponent.get(currentComp - 1)); + assertThat(nodesSortedByComponent.get(i)).isLessThanOrEqualTo(upperBoundPerComponent.get(currentComp)); + } + } + + @Test + void shouldComputeNodesSortedByComponentsNotConsecutive() { + // nodeId -> componentId + LongUnaryOperator components = nodeId -> { + if (nodeId < 4) { + return 3; // size 4 + } else if (nodeId < 6) { + return 5; // size 2 + } else { + return 1; // size 5 + } + }; + // componentId, upperIdx of component + var upperBoundPerComponent = HugeAtomicLongArray.of(11, ParalleLongPageCreator.passThrough(4)); + upperBoundPerComponent.set(3, 3); + upperBoundPerComponent.set(5, 10); + upperBoundPerComponent.set(1, 8); + + var nodesSortedByComponent = ComponentNodes.computeNodesSortedByComponent(components, + upperBoundPerComponent, 4); + + // nodes may occur in arbitrary order within components, but with the given assignment, nodeIds must be within + // component index bounds + assertEquals(11, nodesSortedByComponent.size()); + Collection values = new ArrayList<>(); + int end = 0; + while (end < 11) { + int start = end; + if (nodesSortedByComponent.get(start) < 4) { + // next 4 nodes must be of component 3 + end += 4; + values.addAll(List.of(0L, 1L, 2L, 3L)); + } else if (nodesSortedByComponent.get(start) < 6) { + // next 2 nodes must be of component 5 + end += 2; + values.addAll(List.of(4L, 5L)); + } else { + // next 5 nodes must be of component 1 + end += 5; + values.addAll(List.of(6L, 7L, 8L, 9L, 10L)); + } + for (int i = start; i < end; i++) { + long nodeId = nodesSortedByComponent.get(i); + assertTrue(values.remove(nodeId)); + } + } + + ComponentNodes componentNodesMock = Mockito.mock(ComponentNodes.class); + Mockito.doReturn(components).when(componentNodesMock).getComponents(); + Mockito.doReturn(upperBoundPerComponent).when(componentNodesMock).getUpperBoundPerComponent(); + Mockito.doReturn(nodesSortedByComponent).when(componentNodesMock).getNodesSorted(); + + // no component with id 0 + Mockito.doCallRealMethod().when(componentNodesMock).iterator(0L,0L); + Iterator iterator = componentNodesMock.iterator(0L, 0L); + assertFalse(iterator.hasNext()); + + // 5 nodes for component with id 1 + Mockito.doCallRealMethod().when(componentNodesMock).iterator(1L,0L); + iterator = componentNodesMock.iterator(1L, 0L); + values.addAll(List.of(6L, 7L, 8L, 9L, 10L)); + for (int i = 0; i < 5; i++) { + assertTrue(iterator.hasNext()); + long nodeId = iterator.next(); + assertTrue(values.remove(nodeId)); + } + assertFalse(iterator.hasNext()); + } + + @Test + void shouldReturnNodesForComponent() { + // nodeId -> componentId + var components = new LongUnaryOperator() { + @Override + public long applyAsLong(long nodeId) { + if (nodeId < 3) { + return 0L; // size 3 + } else { + return 1L; // size 5 + } + } + }; + + var nodesSorted = HugeLongArray.newArray(8); + // uniqueIdx, nodeId + nodesSorted.set(0, 2); + nodesSorted.set(1, 1); + nodesSorted.set(2, 0); + nodesSorted.set(3, 7); + nodesSorted.set(4, 6); + nodesSorted.set(5, 5); + nodesSorted.set(6, 4); + nodesSorted.set(7, 3); + + var upperBoundPerComponent = HugeAtomicLongArray.of(8, ParalleLongPageCreator.passThrough(4)); + // componentId, upperBound + upperBoundPerComponent.set(0, 2); + upperBoundPerComponent.set(1, 7); + + ComponentNodes componentNodesMock = Mockito.mock(ComponentNodes.class); + Mockito.doReturn(components).when(componentNodesMock).getComponents(); + Mockito.doReturn(upperBoundPerComponent).when(componentNodesMock).getUpperBoundPerComponent(); + Mockito.doReturn(nodesSorted).when(componentNodesMock).getNodesSorted(); + + // first component + Mockito.doCallRealMethod().when(componentNodesMock).iterator(0L,0L); + Iterator iterator = componentNodesMock.iterator(0L, 0L); + for (int nodeId = 0; nodeId < 3; nodeId++) { + assertTrue(iterator.hasNext()); + assertThat(iterator.next()).isEqualTo(nodeId); + } + assertFalse(iterator.hasNext()); + // second component + Mockito.doCallRealMethod().when(componentNodesMock).iterator(1L,0L); + iterator = componentNodesMock.iterator(1L, 0L); + for (int nodeId = 3; nodeId < 8; nodeId++) { + assertTrue(iterator.hasNext()); + assertThat(iterator.next()).isEqualTo(nodeId); + } + assertFalse(iterator.hasNext()); + } + + @Test + void shouldRespectOffset() { + LongUnaryOperator components = nodeId -> 0L; + + var nodesSorted = HugeLongArray.newArray(20); + nodesSorted.setAll(x -> x); + ShuffleUtil.shuffleArray(nodesSorted, new SplittableRandom(92)); + + var upperBoundPerComponent = HugeAtomicLongArray.of(1, ParalleLongPageCreator.passThrough(4)); + upperBoundPerComponent.set(0, 19); + + ComponentNodes componentNodesMock = Mockito.mock(ComponentNodes.class); + Mockito.doReturn(components).when(componentNodesMock).getComponents(); + Mockito.doReturn(upperBoundPerComponent).when(componentNodesMock).getUpperBoundPerComponent(); + Mockito.doReturn(nodesSorted).when(componentNodesMock).getNodesSorted(); + Mockito.doCallRealMethod().when(componentNodesMock).iterator(0L,11L); + + Set resultingNodes = new HashSet<>(); + Iterator iterator = componentNodesMock.iterator(0L, 11L); + iterator.forEachRemaining(resultingNodes::add); + assertThat(resultingNodes).containsExactlyInAnyOrder(11L, 12L, 13L, 14L, 15L, 16L, 17L, 18L, 19L); + } +} diff --git a/algo/src/test/java/org/neo4j/gds/similarity/nodesim/ComponentPropertyNodeSimilarityTest.java b/algo/src/test/java/org/neo4j/gds/similarity/nodesim/ComponentPropertyNodeSimilarityTest.java new file mode 100644 index 0000000000..ff71e618dc --- /dev/null +++ b/algo/src/test/java/org/neo4j/gds/similarity/nodesim/ComponentPropertyNodeSimilarityTest.java @@ -0,0 +1,216 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.similarity.nodesim; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.neo4j.gds.Orientation; +import org.neo4j.gds.api.Graph; +import org.neo4j.gds.core.GraphDimensions; +import org.neo4j.gds.core.ImmutableGraphDimensions; +import org.neo4j.gds.core.concurrency.DefaultPool; +import org.neo4j.gds.core.utils.mem.MemoryEstimations; +import org.neo4j.gds.core.utils.mem.MemoryRange; +import org.neo4j.gds.core.utils.mem.MemoryTree; +import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; +import org.neo4j.gds.extension.GdlExtension; +import org.neo4j.gds.extension.GdlGraph; +import org.neo4j.gds.extension.Inject; +import org.neo4j.gds.extension.TestGraph; +import org.neo4j.gds.similarity.SimilarityResult; + +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.params.provider.Arguments.arguments; +import static org.neo4j.gds.Orientation.NATURAL; +import static org.neo4j.gds.Orientation.REVERSE; +import static org.neo4j.gds.TestSupport.crossArguments; +import static org.neo4j.gds.TestSupport.toArguments; +import static org.neo4j.gds.similarity.nodesim.NodeSimilarityBaseConfig.TOP_K_DEFAULT; +import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; + +@GdlExtension +public class ComponentPropertyNodeSimilarityTest { + + @GdlGraph(graphNamePrefix = "natural", orientation = NATURAL, idOffset = 0) + @GdlGraph(graphNamePrefix = "reverse", orientation = REVERSE, idOffset = 0) + private static final String DB_CYPHER = + "CREATE" + + " (a:Person {compid: 0})" + + ", (b:Person {compid: 0})" + + ", (c:Person {compid: 0})" + + ", (d:Person {compid: 0})" + + ", (e:Person {compid: 1})" + + ", (i1:Item {compid: 0})" + + ", (i2:Item {compid: 0})" + + ", (i3:Item {compid: 0})" + + ", (i4:Item {compid: 1})" + + ", (i5:Item {compid: 1})" + + ", (a)-[:LIKES {prop: 1.0}]->(i1)" + + ", (a)-[:LIKES {prop: 1.0}]->(i2)" + + ", (a)-[:LIKES {prop: 2.0}]->(i3)" + + ", (b)-[:LIKES {prop: 1.0}]->(i1)" + + ", (b)-[:LIKES {prop: 1.0}]->(i2)" + + ", (c)-[:LIKES {prop: 1.0}]->(i3)" + + ", (d)-[:LIKES {prop: 0.5}]->(i1)" + + ", (d)-[:LIKES {prop: 1.0}]->(i2)" + + ", (d)-[:LIKES {prop: 1.0}]->(i3)" + + ", (e)-[:LIKES {prop: 1.0}]->(i4)" + + ", (e)-[:LIKES {prop: 1.0}]->(i5)"; + private static final Collection EXPECTED_OUTGOING_COMP_OPT = new HashSet<>(); + private static final Collection EXPECTED_INCOMING_COMP_OPT = new HashSet<>(); + + static { + EXPECTED_OUTGOING_COMP_OPT.add(resultString(0, 1, 2 / 3.0)); + EXPECTED_OUTGOING_COMP_OPT.add(resultString(0, 2, 1 / 3.0)); + EXPECTED_OUTGOING_COMP_OPT.add(resultString(0, 3, 1.0)); + EXPECTED_OUTGOING_COMP_OPT.add(resultString(1, 2, 0.0)); + EXPECTED_OUTGOING_COMP_OPT.add(resultString(1, 3, 2 / 3.0)); + EXPECTED_OUTGOING_COMP_OPT.add(resultString(2, 3, 1 / 3.0)); + // Add results in reverse direction because topK + EXPECTED_OUTGOING_COMP_OPT.add(resultString(1, 0, 2 / 3.0)); + EXPECTED_OUTGOING_COMP_OPT.add(resultString(2, 0, 1 / 3.0)); + EXPECTED_OUTGOING_COMP_OPT.add(resultString(3, 0, 1.0)); + EXPECTED_OUTGOING_COMP_OPT.add(resultString(2, 1, 0.0)); + EXPECTED_OUTGOING_COMP_OPT.add(resultString(3, 1, 2 / 3.0)); + EXPECTED_OUTGOING_COMP_OPT.add(resultString(3, 2, 1 / 3.0)); + + EXPECTED_INCOMING_COMP_OPT.add(resultString(9, 8, 1.0)); + EXPECTED_INCOMING_COMP_OPT.add(resultString(5, 6, 1.0)); + EXPECTED_INCOMING_COMP_OPT.add(resultString(5, 7, 1 / 2.0)); + EXPECTED_INCOMING_COMP_OPT.add(resultString(6, 7, 1 / 2.0)); + // Add results in reverse direction because topK + EXPECTED_INCOMING_COMP_OPT.add(resultString(8, 9, 1.0)); + EXPECTED_INCOMING_COMP_OPT.add(resultString(6, 5, 1.0)); + EXPECTED_INCOMING_COMP_OPT.add(resultString(7, 5, 1 / 2.0)); + EXPECTED_INCOMING_COMP_OPT.add(resultString(7, 6, 1 / 2.0)); + } + + @Inject + private TestGraph naturalGraph; + @Inject + private TestGraph reverseGraph; + + private static String resultString(long node1, long node2, double similarity) { + return formatWithLocale("%d,%d %f", node1, node2, similarity); + } + + private static String resultString(SimilarityResult result) { + return resultString(result.node1, result.node2, result.similarity); + } + + private static Stream concurrencies() { + return Stream.of(1, 4); + } + + static Stream supportedLoadAndComputeDirections() { + Stream directions = Stream.of( + arguments(NATURAL), + arguments(REVERSE) + ); + return crossArguments(() -> directions, toArguments(ComponentPropertyNodeSimilarityTest::concurrencies)); + } + + @ParameterizedTest(name = "componentProperty = {0}") + @ValueSource(booleans = {true, false}) + void shouldComputeMemrecWithOrWithoutComponentMapping(boolean componentPropertySet) { + GraphDimensions dimensions = ImmutableGraphDimensions.builder() + .nodeCount(1_000_000) + .relCountUpperBound(5_000_000) + .build(); + + NodeSimilarityWriteConfig config = ImmutableNodeSimilarityWriteConfig + .builder() + .similarityCutoff(0.0) + .topK(TOP_K_DEFAULT) + .writeProperty("writeProperty") + .writeRelationshipType("writeRelationshipType") + .considerComponents(true) + .componentProperty(componentPropertySet ? "compid" : null) + .build(); + + MemoryTree actual = new NodeSimilarityFactory<>().memoryEstimation(config).estimate(dimensions, 1); + + long nodeFilterRangeMin = 125_016L; + long nodeFilterRangeMax = 125_016L; + MemoryRange nodeFilterRange = MemoryRange.of(nodeFilterRangeMin, nodeFilterRangeMax); + + long vectorsRangeMin = 56_000_016L; + long vectorsRangeMax = 56_000_016L; + MemoryRange vectorsRange = MemoryRange.of(vectorsRangeMin, vectorsRangeMax); + + long weightsRangeMin = 16L; + long weightsRangeMax = 56_000_016L; + MemoryRange weightsRange = MemoryRange.of(weightsRangeMin, weightsRangeMax); + + MemoryEstimations.Builder builder = MemoryEstimations.builder() + .fixed("upper bound per component", 8000040) + .fixed("nodes sorted by component", 8000040) + .fixed("node filter", nodeFilterRange) + .fixed("vectors", vectorsRange) + .fixed("weights", weightsRange) + .fixed("similarityComputer", 8); + if (componentPropertySet) { + builder.fixed("component mapping", 8000040); + } + + long topKMapRangeMin = 248_000_016L; + long topKMapRangeMax = 248_000_016L; + builder.fixed("topK map", MemoryRange.of(topKMapRangeMin, topKMapRangeMax)); + + MemoryTree expected = builder.build().estimate(dimensions, 1); + + assertEquals(expected.memoryUsage(), actual.memoryUsage()); + } + + @ParameterizedTest(name = "orientation: {0}, concurrency: {1}") + @MethodSource("supportedLoadAndComputeDirections") + void shouldOptimizeForDistinctComponentsProperty(Orientation orientation, int concurrency) { + Graph graph = orientation == NATURAL ? naturalGraph : reverseGraph; + var config = ImmutableNodeSimilarityStreamConfig.builder() + .similarityCutoff(0.0) + .considerComponents(true) + .componentProperty("compid") + .concurrency(concurrency) + .build(); + + var nodeSimilarity = NodeSimilarity.create( + graph, + config, + DefaultPool.INSTANCE, + ProgressTracker.NULL_TRACKER + ); + + Set result = nodeSimilarity + .compute() + .streamResult() + .map(ComponentPropertyNodeSimilarityTest::resultString) + .collect(Collectors.toSet()); + + assertEquals(orientation == REVERSE ? EXPECTED_INCOMING_COMP_OPT : EXPECTED_OUTGOING_COMP_OPT, result); + } +} diff --git a/algo/src/test/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityTest.java b/algo/src/test/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityTest.java index 21f630ccfc..030c7aaacd 100644 --- a/algo/src/test/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityTest.java +++ b/algo/src/test/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityTest.java @@ -74,7 +74,7 @@ @GdlExtension final class NodeSimilarityTest { - // fixing idOffset to 0 as the expecatations hard-code ids + // fixing idOffset to 0 as the expectations hard-code ids @GdlGraph(graphNamePrefix = "natural", orientation = NATURAL, idOffset = 0) @GdlGraph(graphNamePrefix = "reverse", orientation = REVERSE, idOffset = 0) @GdlGraph(graphNamePrefix = "undirected", orientation = UNDIRECTED, idOffset = 0) @@ -685,7 +685,6 @@ void shouldComputeMemrec(int topK) { MemoryRange weightsRange = MemoryRange.of(weightsRangeMin, weightsRangeMax); MemoryEstimations.Builder builder = MemoryEstimations.builder() - .fixed("components", 8000040) .fixed("node filter", nodeFilterRange) .fixed("vectors", vectorsRange) .fixed("weights", weightsRange) @@ -743,7 +742,6 @@ void shouldComputeMemrecWithTop(int topK) { MemoryRange topNListRange = MemoryRange.of(topNListMin, topNListMax); MemoryEstimations.Builder builder = MemoryEstimations.builder() - .fixed("components", 8000040) .fixed("node filter", nodeFilterRange) .fixed("vectors", vectorsRange) .fixed("weights", weightsRange) @@ -784,8 +782,8 @@ void shouldComputeMemrecWithTopKAndTopNGreaterThanNodeCount() { MemoryTree actual = new NodeSimilarityFactory<>().memoryEstimation(config).estimate(dimensions, 1); - assertEquals(571432, actual.memoryUsage().min); - assertEquals(733032, actual.memoryUsage().max); + assertEquals(570592, actual.memoryUsage().min); + assertEquals(732192, actual.memoryUsage().max); } @@ -881,6 +879,46 @@ void shouldLogProgress(int concurrency) { ); } + @Test + void shouldLogProgressForWccOptimization() { + var graph = naturalGraph; + var config = ImmutableNodeSimilarityStreamConfig.builder() + .considerComponents(true) + .concurrency(4) + .build(); + var progressTask = new NodeSimilarityFactory<>().progressTask(graph, config); + TestLog log = Neo4jProxy.testLog(); + var progressTracker = new TestProgressTracker( + progressTask, + log, + 4, + EmptyTaskRegistryFactory.INSTANCE + ); + + NodeSimilarity.create( + graph, + config, + DefaultPool.INSTANCE, + progressTracker + ).compute().streamResult().count(); + + List progresses = progressTracker.getProgresses(); + + // Should log progress for prepare and actual comparisons + assertThat(progresses).hasSize(6); + + assertThat(log.getMessages(INFO)) + .extracting(removingThreadId()) + .contains( + "NodeSimilarity :: prepare :: WCC :: Start", + "NodeSimilarity :: prepare :: WCC :: Finished", + "NodeSimilarity :: prepare :: Start", + "NodeSimilarity :: prepare :: Finished", + "NodeSimilarity :: compare node pairs :: Start", + "NodeSimilarity :: compare node pairs :: Finished" + ); + } + @Test void shouldGiveCorrectResultsWithOverlap() { var gdl = @@ -981,42 +1019,4 @@ void shouldThrowIfUpperIsSmaller() { assertThatThrownBy(streamConfigBuilder().upperDegreeCutoff(3).degreeCutoff(4)::build) .hasMessageContaining("upperDegreeCutoff cannot be smaller than degreeCutoff"); } - - - @Test - void shouldOptimizeForDistinctComponents() { - var graph = naturalGraph; - var config = ImmutableNodeSimilarityStreamConfig.builder().isEnableComponentOptimization(true).degreeCutoff(0).concurrency(4).build(); - var progressTask = new NodeSimilarityFactory<>().progressTask(graph, config); - TestLog log = Neo4jProxy.testLog(); - var progressTracker = new TestProgressTracker( - progressTask, - log, - 4, - EmptyTaskRegistryFactory.INSTANCE - ); - - NodeSimilarity.create( - graph, - config, - DefaultPool.INSTANCE, - progressTracker - ).compute().streamResult().count(); - - List progresses = progressTracker.getProgresses(); - - // Should log progress for prepare and actual comparisons - assertThat(progresses).hasSize(6); - - assertThat(log.getMessages(INFO)) - .extracting(removingThreadId()) - .contains( - "NodeSimilarity :: prepare :: WCC :: Start", - "NodeSimilarity :: prepare :: WCC :: Finished", - "NodeSimilarity :: prepare :: Start", - "NodeSimilarity :: prepare :: Finished", - "NodeSimilarity :: compare node pairs :: Start", - "NodeSimilarity :: compare node pairs :: Finished" - ); - } } diff --git a/doc/modules/ROOT/pages/algorithms/filtered-node-similarity.adoc b/doc/modules/ROOT/pages/algorithms/filtered-node-similarity.adoc index 0b39d0794e..078e55e454 100644 --- a/doc/modules/ROOT/pages/algorithms/filtered-node-similarity.adoc +++ b/doc/modules/ROOT/pages/algorithms/filtered-node-similarity.adoc @@ -327,7 +327,7 @@ YIELD nodeCount, relationshipCount, bytesMin, bytesMax, requiredMemory [opts="header",cols="1,1,1,1,1"] |=== | nodeCount | relationshipCount | bytesMin | bytesMax | requiredMemory -| 9 | 9 | 2496 | 2712 | "[2496 Bytes \... 2712 Bytes]" +| 9 | 9 | 2384 | 2600 | "[2384 Bytes \... 2600 Bytes]" |=== -- [[algorithms-filtered-node-similarity-examples-stream]] diff --git a/doc/modules/ROOT/pages/algorithms/node-similarity.adoc b/doc/modules/ROOT/pages/algorithms/node-similarity.adoc index 72cc1f5bd7..a3e10434e7 100644 --- a/doc/modules/ROOT/pages/algorithms/node-similarity.adoc +++ b/doc/modules/ROOT/pages/algorithms/node-similarity.adoc @@ -333,7 +333,7 @@ YIELD nodeCount, relationshipCount, bytesMin, bytesMax, requiredMemory [opts="header",cols="1,1,1,1,1"] |=== | nodeCount | relationshipCount | bytesMin | bytesMax | requiredMemory -| 9 | 9 | 2496 | 2712 | "[2496 Bytes \... 2712 Bytes]" +| 9 | 9 | 2384 | 2600 | "[2384 Bytes \... 2600 Bytes]" |=== -- diff --git a/doc/modules/ROOT/partials/algorithms/node-similarity/specific-configuration.adoc b/doc/modules/ROOT/partials/algorithms/node-similarity/specific-configuration.adoc index 99ee46363a..b4e0ff5475 100644 --- a/doc/modules/ROOT/partials/algorithms/node-similarity/specific-configuration.adoc +++ b/doc/modules/ROOT/partials/algorithms/node-similarity/specific-configuration.adoc @@ -21,3 +21,5 @@ If unspecified, the algorithm runs unweighted. | similarityMetric | String | JACCARD | yes | The metric used to compute similarity. Can be either `JACCARD`, `OVERLAP` or `COSINE`. +| [[consider-components]] considerComponents | Boolean | false | yes | If enabled applies an optimization which can increase performance for multi-component graphs. Makes use of the fact that nodes of distinct components always have a similarity of 0. If not already provided through xref:#component-property [componentProperty], internally runs xref:algorithms/wcc.adoc[WCC]. +| [[component-property]] componentProperty | String | null | yes | Name of the pre-computed node property to use for enabled xref:#consider-components [component optimization] in case pre-computed values are available. diff --git a/proc/catalog/src/test/java/org/neo4j/gds/beta/generator/GraphGenerateProcTest.java b/proc/catalog/src/test/java/org/neo4j/gds/beta/generator/GraphGenerateProcTest.java index 233ce61920..f232bed834 100644 --- a/proc/catalog/src/test/java/org/neo4j/gds/beta/generator/GraphGenerateProcTest.java +++ b/proc/catalog/src/test/java/org/neo4j/gds/beta/generator/GraphGenerateProcTest.java @@ -188,9 +188,9 @@ void shouldGenerateGraphWithRelationshipProperty() { private static Stream estimations() { return Stream.of( - Arguments.of(100, 2, MemoryRange.of(28_928, 32_128)), - Arguments.of(100, 4, MemoryRange.of(30_528, 35_328)), - Arguments.of(200, 4, MemoryRange.of(60_944, 70_544)) + Arguments.of(100, 2, MemoryRange.of(28_088, 31_288)), + Arguments.of(100, 4, MemoryRange.of(29_688, 34_488)), + Arguments.of(200, 4, MemoryRange.of(59_304, 68_904)) ); } }