diff --git a/algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnn.java b/algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnn.java index 29f7695f00..b99fdc76e0 100644 --- a/algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnn.java +++ b/algo/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnn.java @@ -25,8 +25,10 @@ import org.neo4j.gds.similarity.filtering.NodeFilter; import org.neo4j.gds.similarity.knn.Knn; import org.neo4j.gds.similarity.knn.KnnContext; +import org.neo4j.gds.similarity.knn.KnnNeighborFilterFactory; import org.neo4j.gds.similarity.knn.KnnResult; import org.neo4j.gds.similarity.knn.SimilarityFunction; +import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer; import java.util.Optional; import java.util.concurrent.ExecutorService; @@ -57,7 +59,7 @@ public static FilteredKnn createWithoutSeeding(Graph graph, FilteredKnnBaseConfi // a bit speculative, but we imagine this being used as entrypoint for seeding public static FilteredKnn createWithDefaultSeeding(Graph graph, FilteredKnnBaseConfig config, KnnContext context) { - var similarityFunction = Knn.defaultSimilarityFunction(graph, config.nodeProperties()); + var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperties(graph, config.nodeProperties())); return create(graph, config, context, Optional.of(similarityFunction)); } @@ -71,9 +73,28 @@ public static FilteredKnn createWithDefaultSeeding(Graph graph, FilteredKnnBaseC */ static FilteredKnn create(Graph graph, FilteredKnnBaseConfig config, KnnContext context, Optional optionalSimilarityFunction) { var targetNodeFilter = config.targetNodeFilter().toNodeFilter(graph); - var targetNodeFiltering = TargetNodeFiltering.create(graph.nodeCount(),config.boundedK(graph.nodeCount()), targetNodeFilter, graph, optionalSimilarityFunction, config.similarityCutoff()); - var similarityFunction = optionalSimilarityFunction.orElse(Knn.defaultSimilarityFunction(graph, config.nodeProperties())); - var knn = Knn.createWithDefaultsAndInstrumentation(graph, config, context, targetNodeFiltering, similarityFunction); + var targetNodeFiltering = TargetNodeFiltering.create(graph.nodeCount(), config.k(graph.nodeCount()).value, targetNodeFilter, graph, optionalSimilarityFunction, config.similarityCutoff()); + var similarityFunction = optionalSimilarityFunction.orElse(new SimilarityFunction(SimilarityComputer.ofProperties( + graph, + config.nodeProperties() + ))); + var knn = new Knn( + graph, + context.progressTracker(), + context.executor(), + config.k(graph.nodeCount()), + config.concurrency(), + config.minBatchSize(), + config.maxIterations(), + config.similarityCutoff(), + config.perturbationRate(), + config.randomJoins(), + config.randomSeed(), + config.initialSampler(), + similarityFunction, + new KnnNeighborFilterFactory(graph.nodeCount()), + targetNodeFiltering + ); var sourceNodeFilter = config.sourceNodeFilter().toNodeFilter(graph); return new FilteredKnn(context.progressTracker(), knn, targetNodeFiltering, sourceNodeFilter); diff --git a/algo/src/main/java/org/neo4j/gds/similarity/knn/GenerateRandomNeighbors.java b/algo/src/main/java/org/neo4j/gds/similarity/knn/GenerateRandomNeighbors.java index 0c0adc4914..997005c847 100644 --- a/algo/src/main/java/org/neo4j/gds/similarity/knn/GenerateRandomNeighbors.java +++ b/algo/src/main/java/org/neo4j/gds/similarity/knn/GenerateRandomNeighbors.java @@ -19,7 +19,7 @@ */ package org.neo4j.gds.similarity.knn; -import org.neo4j.gds.collections.ha.HugeObjectArray; +import org.jetbrains.annotations.NotNull; import org.neo4j.gds.core.utils.partition.Partition; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; @@ -29,39 +29,78 @@ * Initial step in KNN calculation. */ final class GenerateRandomNeighbors implements Runnable { + + static final class Factory { + private final SimilarityFunction similarityFunction; + private final NeighbourConsumers neighbourConsumers; + private final int boundedK; + private final SplittableRandom random; + private final ProgressTracker progressTracker; + + Factory( + SimilarityFunction similarityFunction, + NeighbourConsumers neighbourConsumers, + int boundedK, + SplittableRandom random, + ProgressTracker progressTracker + ) { + this.similarityFunction = similarityFunction; + this.neighbourConsumers = neighbourConsumers; + this.boundedK = boundedK; + this.random = random; + this.progressTracker = progressTracker; + } + + @NotNull GenerateRandomNeighbors create( + Partition partition, + Neighbors neighbors, + KnnSampler sampler, + NeighborFilter neighborFilter + ) { + return new GenerateRandomNeighbors( + partition, + neighbors, + sampler, + neighborFilter, + similarityFunction, + neighbourConsumers, + boundedK, + random.split(), + progressTracker + ); + } + } + + private final Partition partition; + private final Neighbors neighbors; private final KnnSampler sampler; + private final NeighborFilter neighborFilter; private final SplittableRandom random; private final SimilarityFunction similarityFunction; - private final NeighborFilter neighborFilter; - private final HugeObjectArray neighbors; + private final NeighbourConsumers neighbourConsumers; private final int boundedK; private final ProgressTracker progressTracker; - private final Partition partition; - private final NeighbourConsumers neighbourConsumers; - - private long neighborsFound; GenerateRandomNeighbors( + Partition partition, + Neighbors neighbors, KnnSampler sampler, - SplittableRandom random, - SimilarityFunction similarityFunction, NeighborFilter neighborFilter, - HugeObjectArray neighbors, + SimilarityFunction similarityFunction, + NeighbourConsumers neighbourConsumers, int boundedK, - Partition partition, - ProgressTracker progressTracker, - NeighbourConsumers neighbourConsumers + SplittableRandom random, + ProgressTracker progressTracker ) { + this.partition = partition; + this.neighbors = neighbors; this.sampler = sampler; + this.neighborFilter = neighborFilter; this.random = random; this.similarityFunction = similarityFunction; - this.neighborFilter = neighborFilter; - this.neighbors = neighbors; + this.neighbourConsumers = neighbourConsumers; this.boundedK = boundedK; this.progressTracker = progressTracker; - this.partition = partition; - this.neighborsFound = 0; - this.neighbourConsumers = neighbourConsumers; } @Override @@ -88,12 +127,7 @@ public void run() { assert neighbors.size() >= Math.min(neighborFilter.lowerBoundOfPotentialNeighbours(nodeId), boundedK); this.neighbors.set(nodeId, neighbors); - neighborsFound += neighbors.size(); }); progressTracker.logProgress(partition.nodeCount()); } - - long neighborsFound() { - return neighborsFound; - } } diff --git a/algo/src/main/java/org/neo4j/gds/similarity/knn/JoinNeighbors.java b/algo/src/main/java/org/neo4j/gds/similarity/knn/JoinNeighbors.java new file mode 100644 index 0000000000..1f161b7278 --- /dev/null +++ b/algo/src/main/java/org/neo4j/gds/similarity/knn/JoinNeighbors.java @@ -0,0 +1,281 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.similarity.knn; + +import com.carrotsearch.hppc.LongArrayList; +import org.jetbrains.annotations.Nullable; +import org.neo4j.gds.collections.ha.HugeObjectArray; +import org.neo4j.gds.core.utils.partition.Partition; +import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; + +import java.util.SplittableRandom; + +final class JoinNeighbors implements Runnable { + + static class Factory { + private final SimilarityFunction similarityFunction; + private final int sampledK; + private final double perturbationRate; + private final int randomJoins; + private final SplittableRandom splittableRandom; + private final ProgressTracker progressTracker; + + Factory( + SimilarityFunction similarityFunction, + int sampledK, + double perturbationRate, + int randomJoins, + SplittableRandom splittableRandom, + ProgressTracker progressTracker + ) { + this.similarityFunction = similarityFunction; + this.sampledK = sampledK; + this.perturbationRate = perturbationRate; + this.randomJoins = randomJoins; + this.splittableRandom = splittableRandom; + this.progressTracker = progressTracker; + } + JoinNeighbors create( + Partition partition, + Neighbors allNeighbors, + HugeObjectArray allOldNeighbors, + HugeObjectArray allNewNeighbors, + HugeObjectArray allReverseOldNeighbors, + HugeObjectArray allReverseNewNeighbors, + NeighborFilter neighborFilter + ) { + return new JoinNeighbors( + partition, + allNeighbors, + allOldNeighbors, + allNewNeighbors, + allReverseOldNeighbors, + allReverseNewNeighbors, + neighborFilter, + similarityFunction, + sampledK, + perturbationRate, + randomJoins, + splittableRandom.split(), + progressTracker + ); + } + } + + private final SplittableRandom random; + private final SimilarityFunction similarityFunction; + private final NeighborFilter neighborFilter; + private final Neighbors allNeighbors; + private final HugeObjectArray allOldNeighbors; + private final HugeObjectArray allNewNeighbors; + private final HugeObjectArray allReverseOldNeighbors; + private final HugeObjectArray allReverseNewNeighbors; + private final int sampledK; + private final int randomJoins; + private final ProgressTracker progressTracker; + private final long nodeCount; + private final Partition partition; + private final double perturbationRate; + private long updateCount; + + JoinNeighbors( + Partition partition, + Neighbors allNeighbors, + HugeObjectArray allOldNeighbors, + HugeObjectArray allNewNeighbors, + HugeObjectArray allReverseOldNeighbors, + HugeObjectArray allReverseNewNeighbors, + NeighborFilter neighborFilter, + SimilarityFunction similarityFunction, + int sampledK, + double perturbationRate, + int randomJoins, + SplittableRandom random, + ProgressTracker progressTracker + ) { + this.random = random; + this.similarityFunction = similarityFunction; + this.neighborFilter = neighborFilter; + this.allNeighbors = allNeighbors; + this.nodeCount = allNewNeighbors.size(); + this.allOldNeighbors = allOldNeighbors; + this.allNewNeighbors = allNewNeighbors; + this.allReverseOldNeighbors = allReverseOldNeighbors; + this.allReverseNewNeighbors = allReverseNewNeighbors; + this.sampledK = sampledK; + this.randomJoins = randomJoins; + this.partition = partition; + this.progressTracker = progressTracker; + this.perturbationRate = perturbationRate; + this.updateCount = 0; + } + + @Override + public void run() { + var startNode = partition.startNode(); + long endNode = startNode + partition.nodeCount(); + + for (long nodeId = startNode; nodeId < endNode; nodeId++) { + // old[v] ∪ Sample(old′[v], ρK) + var oldNeighbors = allOldNeighbors.get(nodeId); + if (oldNeighbors != null) { + combineNeighbors(allReverseOldNeighbors.get(nodeId), oldNeighbors); + } + + + // new[v] ∪ Sample(new′[v], ρK) + var newNeighbors = allNewNeighbors.get(nodeId); + if (newNeighbors != null) { + combineNeighbors(allReverseNewNeighbors.get(nodeId), newNeighbors); + + this.updateCount += joinNewNeighbors(nodeId, oldNeighbors, newNeighbors); + } + + // this isn't in the paper + randomJoins(nodeCount, nodeId); + } + progressTracker.logProgress(partition.nodeCount()); + } + + private long joinNewNeighbors( + long nodeId, LongArrayList oldNeighbors, LongArrayList newNeighbors + ) { + long updateCount = 0; + + var newNeighborElements = newNeighbors.buffer; + var newNeighborsCount = newNeighbors.elementsCount; + boolean similarityIsSymmetric = similarityFunction.isSymmetric(); + + for (int i = 0; i < newNeighborsCount; i++) { + var elem1 = newNeighborElements[i]; + assert elem1 != nodeId; + + // join(u1, nodeId), this isn't in the paper + updateCount += join(elem1, nodeId); + + // try out using the new neighbors between themselves / join(new_nbd, new_ndb) + for (int j = i + 1; j < newNeighborsCount; j++) { + var elem2 = newNeighborElements[j]; + if (elem1 == elem2) { + continue; + } + + if (similarityIsSymmetric) { + updateCount += joinSymmetric(elem1, elem2); + } else { + updateCount += join(elem1, elem2); + updateCount += join(elem2, elem1); + } + } + + // try out joining the old neighbors with the new neighbor / join(new_nbd, old_ndb) + if (oldNeighbors != null) { + for (var oldElemCursor : oldNeighbors) { + var elem2 = oldElemCursor.value; + + if (elem1 == elem2) { + continue; + } + + if (similarityIsSymmetric) { + updateCount += joinSymmetric(elem1, elem2); + } else { + updateCount += join(elem1, elem2); + updateCount += join(elem2, elem1); + } + } + } + } + return updateCount; + } + + private void combineNeighbors(@Nullable LongArrayList reversedNeighbors, LongArrayList neighbors) { + if (reversedNeighbors != null) { + var numberOfReverseNeighbors = reversedNeighbors.size(); + for (var elem : reversedNeighbors) { + if (random.nextInt(numberOfReverseNeighbors) < sampledK) { + // TODO: this could add nodes twice, maybe? should this be a set? + neighbors.add(elem.value); + } + } + } + } + + private void randomJoins(long nodeCount, long nodeId) { + for (int i = 0; i < randomJoins; i++) { + var randomNodeId = random.nextLong(nodeCount - 1); + // shifting the randomNode as the randomNode was picked from [0, n-1) + if (randomNodeId >= nodeId) { + ++randomNodeId; + } + // random joins are not counted towards the actual update counter + join(nodeId, randomNodeId); + } + } + + private long joinSymmetric(long node1, long node2) { + assert node1 != node2; + + if (neighborFilter.excludeNodePair(node1, node2)) { + return 0; + } + + var similarity = similarityFunction.computeSimilarity(node1, node2); + + var neighbors1 = allNeighbors.getAndIncrementCounter(node1); + + var updates = 0L; + + synchronized (neighbors1) { + updates += neighbors1.add(node2, similarity, random, perturbationRate); + } + + var neighbors2 = allNeighbors.get(node2); + + synchronized (neighbors2) { + updates += neighbors2.add(node1, similarity, random, perturbationRate); + } + + return updates; + } + + private long join(long node1, long node2) { + assert node1 != node2; + + if (neighborFilter.excludeNodePair(node1, node2)) { + return 0; + } + + var similarity = similarityFunction.computeSimilarity(node1, node2); + var neighbors = allNeighbors.getAndIncrementCounter(node1); + + synchronized (neighbors) { + return neighbors.add(node2, similarity, random, perturbationRate); + } + } + + long nodePairsConsidered() { + return allNeighbors.joinCounter(); + } + + long updateCount() { + return updateCount; + } +} diff --git a/algo/src/main/java/org/neo4j/gds/similarity/knn/K.java b/algo/src/main/java/org/neo4j/gds/similarity/knn/K.java new file mode 100644 index 0000000000..f35f3ded66 --- /dev/null +++ b/algo/src/main/java/org/neo4j/gds/similarity/knn/K.java @@ -0,0 +1,55 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.similarity.knn; + +import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; + +public final class K { + public static K create(int k, long nodeCount, double sampleRate, double deltaThreshold) { + // user-provided k value must be at least 1 + if (k < 1) throw new IllegalArgumentException("K k must be 1 or more"); + // sampleRate -- value range (0.0;1.0] + if (Double.compare(sampleRate, 0.0) < 1 || Double.compare(sampleRate, 1.0) > 0) + throw new IllegalArgumentException("sampleRate must be more than 0.0 and less than or equal to 1.0"); + // deltaThreshold -- value range [0.0;1.0] + if (Double.compare(deltaThreshold, 0.0) < 0 || Double.compare(deltaThreshold, 1.0) > 0) + throw new IllegalArgumentException(formatWithLocale("deltaThreshold must be more than or equal to 0.0 and less than or equal to 1.0, was `%f`", deltaThreshold)); + + // (int) is safe because k is at most `topK`, which is an int + // upper bound for k is all other nodes in the graph + var boundedValue = Math.max(0, (int) Math.min(k, nodeCount - 1)); + var sampledValue = Math.max(0, (int) Math.min((long) Math.ceil(sampleRate * k), nodeCount - 1)); + + var maxUpdates = (long) Math.ceil(sampleRate * k * nodeCount); + var updateThreshold = (long) Math.floor(deltaThreshold * maxUpdates); + + return new K(boundedValue, sampledValue, updateThreshold); + } + + public final int value; + public final int sampledValue; + public final long updateThreshold; + + private K(int value, int sampledValue, long updateThreshold) { + this.value = value; + this.sampledValue = sampledValue; + this.updateThreshold = updateThreshold; + } +} diff --git a/algo/src/main/java/org/neo4j/gds/similarity/knn/Knn.java b/algo/src/main/java/org/neo4j/gds/similarity/knn/Knn.java index a1d95e70dd..42fddf413f 100644 --- a/algo/src/main/java/org/neo4j/gds/similarity/knn/Knn.java +++ b/algo/src/main/java/org/neo4j/gds/similarity/knn/Knn.java @@ -20,274 +20,215 @@ package org.neo4j.gds.similarity.knn; import com.carrotsearch.hppc.LongArrayList; -import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; import org.neo4j.gds.Algorithm; import org.neo4j.gds.api.Graph; import org.neo4j.gds.collections.ha.HugeObjectArray; import org.neo4j.gds.core.concurrency.ParallelUtil; import org.neo4j.gds.core.concurrency.RunWithConcurrency; -import org.neo4j.gds.core.utils.ProgressTimer; -import org.neo4j.gds.core.utils.partition.Partition; import org.neo4j.gds.core.utils.partition.PartitionUtils; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer; -import java.util.List; import java.util.Optional; import java.util.SplittableRandom; import java.util.concurrent.ExecutorService; import java.util.stream.LongStream; -import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; - public class Knn extends Algorithm { - private final Graph graph; - private final KnnBaseConfig config; - private final int concurrency; - private final NeighborFilterFactory neighborFilterFactory; - private final ExecutorService executorService; - private final SplittableRandom splittableRandom; - private final SimilarityFunction similarityFunction; - private final NeighbourConsumers neighborConsumers; - - private long nodePairsConsidered; - - public static Knn createWithDefaults(Graph graph, KnnBaseConfig config, KnnContext context) { - return createWithDefaultsAndInstrumentation(graph, config, context, NeighbourConsumers.no_op, defaultSimilarityFunction(graph, config.nodeProperties())); - } - - public static SimilarityFunction defaultSimilarityFunction(Graph graph, List nodeProperties) { - return defaultSimilarityFunction(SimilarityComputer.ofProperties(graph, nodeProperties)); - } - - private static SimilarityFunction defaultSimilarityFunction(SimilarityComputer similarityComputer) { - return new SimilarityFunction(similarityComputer); - } - - @NotNull - public static Knn createWithDefaultsAndInstrumentation( - Graph graph, - KnnBaseConfig config, - KnnContext context, - NeighbourConsumers neighborConsumers, - SimilarityFunction similarityFunction - ) { - return new Knn( - context.progressTracker(), - graph, - config, - similarityFunction, - new KnnNeighborFilterFactory(graph.nodeCount()), - context.executor(), - getSplittableRandom(config.randomSeed()), - neighborConsumers - ); - } public static Knn create( Graph graph, - KnnBaseConfig config, + KnnParameters parameters, SimilarityComputer similarityComputer, NeighborFilterFactory neighborFilterFactory, KnnContext context ) { - SplittableRandom splittableRandom = getSplittableRandom(config.randomSeed()); - SimilarityFunction similarityFunction = defaultSimilarityFunction(similarityComputer); + var similarityFunction = new SimilarityFunction(similarityComputer); return new Knn( - context.progressTracker(), graph, - config, + context.progressTracker(), + context.executor(), + parameters.kHolder(), + parameters.concurrency(), + parameters.minBatchSize(), + parameters.maxIterations(), + parameters.similarityCutoff(), + parameters.perturbationRate(), + parameters.randomJoins(), + parameters.randomSeed(), + parameters.samplerType(), similarityFunction, neighborFilterFactory, - context.executor(), - splittableRandom, NeighbourConsumers.no_op ); } - @NotNull - private static SplittableRandom getSplittableRandom(Optional randomSeed) { - return randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new); - } + private final Graph graph; + private final int concurrency; + private final int maxIterations; + private final double similarityCutoff; + private final int minBatchSize; + private final NeighborFilterFactory neighborFilterFactory; + private final ExecutorService executorService; + private final KnnSampler.Factory samplerFactory; + private final JoinNeighbors.Factory joinNeighborsFactory; + private final GenerateRandomNeighbors.Factory generateRandomNeighborsFactory; + private final SplitOldAndNewNeighbors.Factory splitOldAndNewNeighborsFactory; + private final long updateThreshold; - Knn( - ProgressTracker progressTracker, + public Knn( Graph graph, - KnnBaseConfig config, + ProgressTracker progressTracker, + ExecutorService executorService, + K k, + int concurrency, + int minBatchSize, + int maxIterations, + double similarityCutoff, + double perturbationRate, + int randomJoins, + Optional randomSeed, + KnnSampler.SamplerType initialSamplerType, SimilarityFunction similarityFunction, NeighborFilterFactory neighborFilterFactory, - ExecutorService executorService, - SplittableRandom splittableRandom, NeighbourConsumers neighborConsumers ) { super(progressTracker); this.graph = graph; - this.config = config; - this.concurrency = config.concurrency(); - this.similarityFunction = similarityFunction; + this.concurrency = concurrency; + this.maxIterations = maxIterations; + this.similarityCutoff = similarityCutoff; + this.minBatchSize = minBatchSize; this.neighborFilterFactory = neighborFilterFactory; this.executorService = executorService; - this.splittableRandom = splittableRandom; - this.neighborConsumers = neighborConsumers; + + this.updateThreshold = k.updateThreshold; + + var splittableRandom = randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new); + switch (initialSamplerType) { + case UNIFORM: + this.samplerFactory = new UniformKnnSampler.Factory(graph.nodeCount(), splittableRandom); + break; + case RANDOMWALK: + this.samplerFactory = new RandomWalkKnnSampler.Factory(graph, randomSeed, k.value, splittableRandom); + break; + default: + throw new IllegalStateException("Invalid KnnSampler"); + } + this.generateRandomNeighborsFactory = new GenerateRandomNeighbors.Factory( + similarityFunction, + neighborConsumers, + k.value, + splittableRandom, + progressTracker + ); + this.splitOldAndNewNeighborsFactory = new SplitOldAndNewNeighbors.Factory( + k.sampledValue, + splittableRandom, + progressTracker + ); + this.joinNeighborsFactory = new JoinNeighbors.Factory( + similarityFunction, + k.sampledValue, + perturbationRate, + randomJoins, + splittableRandom, + progressTracker + ); } public ExecutorService executorService() { - return this.executorService; + return executorService; } @Override public KnnResult compute() { - this.progressTracker.beginSubTask(); - HugeObjectArray neighbors; - try (var ignored1 = ProgressTimer.start(this::logOverallTime)) { - try (var ignored2 = ProgressTimer.start(this::logInitTime)) { - this.progressTracker.beginSubTask(); - neighbors = this.initializeRandomNeighbors(); - this.progressTracker.endSubTask(); - } - if (neighbors == null) { - return new EmptyResult(); - } - - var maxIterations = this.config.maxIterations(); - var maxUpdates = (long) Math.ceil(this.config.sampleRate() * this.config.topK() * graph.nodeCount()); - var updateThreshold = (long) Math.floor(this.config.deltaThreshold() * maxUpdates); + if (graph.nodeCount() < 2) { + return new EmptyResult(); + } + progressTracker.beginSubTask(); + progressTracker.beginSubTask(); + var neighbors = initializeRandomNeighbors(); + progressTracker.endSubTask(); - long updateCount; - int iteration = 0; - boolean didConverge = false; + long updateCount; + int iteration = 0; + boolean didConverge = false; - this.progressTracker.beginSubTask(); - for (; iteration < maxIterations; iteration++) { - int currentIteration = iteration; - try (var ignored3 = ProgressTimer.start(took -> this.logIterationTime(currentIteration + 1, took))) { - updateCount = iteration(neighbors); - } - if (updateCount <= updateThreshold) { - iteration++; - didConverge = true; - break; - } - } - if (config.similarityCutoff() > 0) { - var similarityCutoff = config.similarityCutoff(); - var neighborFilterTasks = PartitionUtils.rangePartition( - concurrency, - neighbors.size(), - partition -> (Runnable) () -> partition.consume( - nodeId -> neighbors.get(nodeId).filterHighSimilarityResults(similarityCutoff) - ), - Optional.of(config.minBatchSize()) - ); - RunWithConcurrency.builder() - .concurrency(concurrency) - .tasks(neighborFilterTasks) - .terminationFlag(terminationFlag) - .executor(this.executorService) - .run(); + progressTracker.beginSubTask(); + for (; iteration < maxIterations; iteration++) { + updateCount = iteration(neighbors); + if (updateCount <= updateThreshold) { + iteration++; + didConverge = true; + break; } - this.progressTracker.endSubTask(); - - this.progressTracker.endSubTask(); - return ImmutableKnnResult.of( - neighbors, - iteration, - didConverge, - this.nodePairsConsidered, - graph.nodeCount() + } + if (similarityCutoff > 0) { + var neighborFilterTasks = PartitionUtils.rangePartition( + concurrency, + neighbors.size(), + partition -> (Runnable) () -> partition.consume( + nodeId -> neighbors.filterHighSimilarityResult(nodeId, similarityCutoff) + ), + Optional.of(minBatchSize) ); + RunWithConcurrency.builder() + .concurrency(concurrency) + .tasks(neighborFilterTasks) + .terminationFlag(terminationFlag) + .executor(executorService) + .run(); } - } - - private @Nullable HugeObjectArray initializeRandomNeighbors() { - var k = this.config.topK(); - // (int) is safe since it is at most k, which is an int - var boundedK = (int) Math.min(graph.nodeCount() - 1, k); - - assert boundedK <= k && boundedK <= graph.nodeCount() - 1; + progressTracker.endSubTask(); - if (graph.nodeCount() < 2 || k == 0) { - return null; - } + progressTracker.endSubTask(); + return ImmutableKnnResult.of( + neighbors.data(), + iteration, + didConverge, + neighbors.neighborsFound() + neighbors.joinCounter(), + graph.nodeCount() + ); + } - var neighbors = HugeObjectArray.newArray(NeighborList.class, graph.nodeCount()); + private Neighbors initializeRandomNeighbors() { + var neighbors = new Neighbors(graph.nodeCount()); var randomNeighborGenerators = PartitionUtils.rangePartition( concurrency, graph.nodeCount(), - partition -> { - var localRandom = splittableRandom.split(); - return new GenerateRandomNeighbors( - initializeSampler(localRandom), - localRandom, - this.similarityFunction, - this.neighborFilterFactory.create(), - neighbors, - boundedK, - partition, - progressTracker, - neighborConsumers - ); - }, - Optional.of(config.minBatchSize()) + partition -> generateRandomNeighborsFactory.create( + partition, + neighbors, + samplerFactory.create(), + neighborFilterFactory.create() + ), + Optional.of(minBatchSize) ); RunWithConcurrency.builder() .concurrency(concurrency) .tasks(randomNeighborGenerators) .terminationFlag(terminationFlag) - .executor(this.executorService) + .executor(executorService) .run(); - this.nodePairsConsidered += randomNeighborGenerators.stream().mapToLong(GenerateRandomNeighbors::neighborsFound).sum(); - return neighbors; } - private KnnSampler initializeSampler(SplittableRandom random) { - switch(config.initialSampler()) { - case UNIFORM: { - return new UniformKnnSampler(random, graph.nodeCount()); - } - case RANDOMWALK: { - return new RandomWalkKnnSampler( - graph.concurrentCopy(), - random, - config.randomSeed(), - config.boundedK(graph.nodeCount()) - ); - } - default: - throw new IllegalStateException("Invalid KnnSampler"); - } - } - - private long iteration(HugeObjectArray neighbors) { - // this is a sanity check - // we check for this before any iteration and return - // and just make sure that this invariant holds on every iteration + private long iteration(Neighbors neighbors) { var nodeCount = graph.nodeCount(); - if (nodeCount < 2 || this.config.topK() == 0) { - return NeighborList.NOT_INSERTED; - } - - var minBatchSize = this.config.minBatchSize(); - - var sampledK = this.config.sampledK(nodeCount); // TODO: init in ctor and reuse - benchmark against new allocations var allOldNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount); var allNewNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount); progressTracker.beginSubTask(); - ParallelUtil.readParallel(concurrency, nodeCount, this.executorService, new SplitOldAndNewNeighbors( - this.splittableRandom, + ParallelUtil.readParallel(concurrency, nodeCount, executorService, splitOldAndNewNeighborsFactory.create( neighbors, allOldNeighbors, - allNewNeighbors, - sampledK, - progressTracker + allNewNeighbors )); progressTracker.endSubTask(); @@ -310,20 +251,14 @@ private long iteration(HugeObjectArray neighbors) { var neighborsJoiners = PartitionUtils.rangePartition( concurrency, nodeCount, - partition -> new JoinNeighbors( - this.splittableRandom.split(), - this.similarityFunction, - this.neighborFilterFactory.create(), + partition -> joinNeighborsFactory.create( + partition, neighbors, allOldNeighbors, allNewNeighbors, reverseOldNeighbors, reverseNewNeighbors, - sampledK, - this.config.perturbationRate(), - this.config.randomJoins(), - partition, - progressTracker + neighborFilterFactory.create() ), Optional.of(minBatchSize) ); @@ -333,13 +268,11 @@ private long iteration(HugeObjectArray neighbors) { .concurrency(concurrency) .tasks(neighborsJoiners) .terminationFlag(terminationFlag) - .executor(this.executorService) + .executor(executorService) .run(); progressTracker.endSubTask(); - this.nodePairsConsidered += neighborsJoiners.stream().mapToLong(JoinNeighbors::nodePairsConsidered).sum(); - - return neighborsJoiners.stream().mapToLong(joiner -> joiner.updateCount).sum(); + return neighborsJoiners.stream().mapToLong(JoinNeighbors::updateCount).sum(); } private static void reverseOldAndNewNeighbors( @@ -385,219 +318,6 @@ static void reverseNeighbors( } } - static final class JoinNeighbors implements Runnable { - private final SplittableRandom random; - private final SimilarityFunction similarityFunction; - private final NeighborFilter neighborFilter; - private final HugeObjectArray allNeighbors; - private final HugeObjectArray allOldNeighbors; - private final HugeObjectArray allNewNeighbors; - private final HugeObjectArray allReverseOldNeighbors; - private final HugeObjectArray allReverseNewNeighbors; - private final int sampledK; - private final int randomJoins; - private final ProgressTracker progressTracker; - private final long nodeCount; - private long updateCount; - private final Partition partition; - private long nodePairsConsidered; - private final double perturbationRate; - - JoinNeighbors( - SplittableRandom random, - SimilarityFunction similarityFunction, - NeighborFilter neighborFilter, - HugeObjectArray allNeighbors, - HugeObjectArray allOldNeighbors, - HugeObjectArray allNewNeighbors, - HugeObjectArray allReverseOldNeighbors, - HugeObjectArray allReverseNewNeighbors, - int sampledK, - double perturbationRate, - int randomJoins, - Partition partition, - ProgressTracker progressTracker - ) { - this.random = random; - this.similarityFunction = similarityFunction; - this.neighborFilter = neighborFilter; - this.allNeighbors = allNeighbors; - this.nodeCount = allNewNeighbors.size(); - this.allOldNeighbors = allOldNeighbors; - this.allNewNeighbors = allNewNeighbors; - this.allReverseOldNeighbors = allReverseOldNeighbors; - this.allReverseNewNeighbors = allReverseNewNeighbors; - this.sampledK = sampledK; - this.randomJoins = randomJoins; - this.partition = partition; - this.progressTracker = progressTracker; - this.perturbationRate = perturbationRate; - this.updateCount = 0; - this.nodePairsConsidered = 0; - } - - @Override - public void run() { - var startNode = partition.startNode(); - long endNode = startNode + partition.nodeCount(); - - for (long nodeId = startNode; nodeId < endNode; nodeId++) { - // old[v] ∪ Sample(old′[v], ρK) - var oldNeighbors = allOldNeighbors.get(nodeId); - if (oldNeighbors != null) { - combineNeighbors(allReverseOldNeighbors.get(nodeId), oldNeighbors); - } - - - // new[v] ∪ Sample(new′[v], ρK) - var newNeighbors = allNewNeighbors.get(nodeId); - if (newNeighbors != null) { - combineNeighbors(allReverseNewNeighbors.get(nodeId), newNeighbors); - - this.updateCount += joinNewNeighbors(nodeId, oldNeighbors, newNeighbors); - } - - // this isn't in the paper - randomJoins(nodeCount, nodeId); - } - progressTracker.logProgress(partition.nodeCount()); - } - - private long joinNewNeighbors(long nodeId, LongArrayList oldNeighbors, LongArrayList newNeighbors - ) { - long updateCount = 0; - - var newNeighborElements = newNeighbors.buffer; - var newNeighborsCount = newNeighbors.elementsCount; - boolean similarityIsSymmetric = similarityFunction.isSymmetric(); - - for (int i = 0; i < newNeighborsCount; i++) { - var elem1 = newNeighborElements[i]; - assert elem1 != nodeId; - - // join(u1, nodeId), this isn't in the paper - updateCount += join(elem1, nodeId); - - // try out using the new neighbors between themselves / join(new_nbd, new_ndb) - for (int j = i + 1; j < newNeighborsCount; j++) { - var elem2 = newNeighborElements[j]; - if (elem1 == elem2) { - continue; - } - - if (similarityIsSymmetric) { - updateCount += joinSymmetric(elem1, elem2); - } else { - updateCount += join(elem1, elem2); - updateCount += join(elem2, elem1); - } - } - - // try out joining the old neighbors with the new neighbor / join(new_nbd, old_ndb) - if (oldNeighbors != null) { - for (var oldElemCursor : oldNeighbors) { - var elem2 = oldElemCursor.value; - - if (elem1 == elem2) { - continue; - } - - if (similarityIsSymmetric) { - updateCount += joinSymmetric(elem1, elem2); - } else { - updateCount += join(elem1, elem2); - updateCount += join(elem2, elem1); - } - } - } - } - return updateCount; - } - - private void combineNeighbors(@Nullable LongArrayList reversedNeighbors, LongArrayList neighbors) { - if (reversedNeighbors != null) { - var numberOfReverseNeighbors = reversedNeighbors.size(); - for (var elem : reversedNeighbors) { - if (random.nextInt(numberOfReverseNeighbors) < sampledK) { - // TODO: this could add nodes twice, maybe? should this be a set? - neighbors.add(elem.value); - } - } - } - } - - private void randomJoins(long nodeCount, long nodeId) { - for (int i = 0; i < randomJoins; i++) { - var randomNodeId = random.nextLong(nodeCount - 1); - // shifting the randomNode as the randomNode was picked from [0, n-1) - if (randomNodeId >= nodeId) { - ++randomNodeId; - } - // random joins are not counted towards the actual update counter - join(nodeId, randomNodeId); - } - } - - private long joinSymmetric(long node1, long node2) { - assert node1 != node2; - - if (neighborFilter.excludeNodePair(node1, node2)) { - return 0; - } - - nodePairsConsidered++; - var similarity = similarityFunction.computeSimilarity(node1, node2); - - var neighbors1 = allNeighbors.get(node1); - - var updates = 0L; - - synchronized (neighbors1) { - updates += neighbors1.add(node2, similarity, random, perturbationRate); - } - - var neighbors2 = allNeighbors.get(node2); - - synchronized (neighbors2) { - updates += neighbors2.add(node1, similarity, random, perturbationRate); - } - - return updates; - } - - private long join(long node1, long node2) { - assert node1 != node2; - - if (neighborFilter.excludeNodePair(node1, node2)) { - return 0; - } - - var similarity = similarityFunction.computeSimilarity(node1, node2); - nodePairsConsidered++; - var neighbors = allNeighbors.get(node1); - - synchronized (neighbors) { - return neighbors.add(node2, similarity, random, perturbationRate); - } - } - - long nodePairsConsidered() { - return nodePairsConsidered; - } - } - - private void logInitTime(long ms) { - progressTracker.logInfo(formatWithLocale("Graph init took %d ms", ms)); - } - - private void logIterationTime(int iteration, long ms) { - progressTracker.logInfo(formatWithLocale("Graph iteration %d took %d ms", iteration, ms)); - } - - private void logOverallTime(long ms) { - progressTracker.logInfo(formatWithLocale("Graph execution took %d ms", ms)); - } - private static final class EmptyResult extends KnnResult { @Override diff --git a/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnBaseConfig.java b/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnBaseConfig.java index 8f93b84c24..73a1653995 100644 --- a/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnBaseConfig.java +++ b/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnBaseConfig.java @@ -86,24 +86,36 @@ default int randomJoins() { return 10; } - @Configuration.Ignore - default int sampledK(long nodeCount) { - // (int) is safe because value is at most `topK`, which is an int - // This could be violated if a sampleRate outside of [0,1] is used - // which is only possible from our tests - return Math.max(0, (int) Math.min((long) Math.ceil(this.sampleRate() * this.topK()), nodeCount - 1)); + @Value.Default + @Configuration.ConvertWith(method = "org.neo4j.gds.similarity.knn.KnnSampler.SamplerType#parse") + @Configuration.ToMapValue("org.neo4j.gds.similarity.knn.KnnSampler.SamplerType#toString") + default KnnSampler.SamplerType initialSampler() { + return KnnSampler.SamplerType.UNIFORM; } + @Value.Default @Configuration.Ignore - default int boundedK(long nodeCount) { - // (int) is safe because value is at most `topK`, which is an int - return Math.max(0, (int) Math.min(this.topK(), nodeCount - 1)); + default K k(long nodeCount) { + return K.create(topK(), nodeCount, sampleRate(), deltaThreshold()); } @Value.Default - @Configuration.ConvertWith(method = "org.neo4j.gds.similarity.knn.KnnSampler.SamplerType#parse") - @Configuration.ToMapValue("org.neo4j.gds.similarity.knn.KnnSampler.SamplerType#toString") - default KnnSampler.SamplerType initialSampler() { - return KnnSampler.SamplerType.UNIFORM; + @Configuration.Ignore + default KnnParametersSansNodeCount toParameters() { + return KnnParametersSansNodeCount.create( + concurrency(), + maxIterations(), + similarityCutoff(), + deltaThreshold(), + sampleRate(), + topK(), + perturbationRate(), + randomJoins(), + minBatchSize(), + initialSampler(), + randomSeed(), + nodeProperties() + ); } + } diff --git a/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnFactory.java b/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnFactory.java index c41dd3525d..e5003451ff 100644 --- a/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnFactory.java +++ b/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnFactory.java @@ -30,6 +30,7 @@ import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.core.utils.progress.tasks.Task; import org.neo4j.gds.core.utils.progress.tasks.Tasks; +import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer; import java.util.List; import java.util.function.LongFunction; @@ -48,15 +49,16 @@ public String taskName() { return KNN_BASE_TASK_NAME; } - @Override public Knn build( Graph graph, - CONFIG configuration, + KnnParameters parameters, ProgressTracker progressTracker ) { - return Knn.createWithDefaults( + return Knn.create( graph, - configuration, + parameters, + SimilarityComputer.ofProperties(graph, parameters.nodePropertySpecs()), + new KnnNeighborFilterFactory(graph.nodeCount()), ImmutableKnnContext .builder() .progressTracker(progressTracker) @@ -65,6 +67,16 @@ public Knn build( ); } + @Override + public Knn build( + Graph graph, + CONFIG configuration, + ProgressTracker progressTracker + ) { + var parameters = configuration.toParameters().finalize(graph.nodeCount()); + return build(graph, parameters, progressTracker); + } + @Override public MemoryEstimation memoryEstimation(CONFIG configuration) { return KnnFactory.memoryEstimation(taskName(), Knn.class, configuration); @@ -112,18 +124,17 @@ public static MemoryEstimation memoryEstimation( return MemoryEstimations.setup( taskName, (dim, concurrency) -> { - var boundedK = configuration.boundedK(dim.nodeCount()); - var sampledK = configuration.sampledK(dim.nodeCount()); + var k = configuration.k(dim.nodeCount()); LongFunction tempListEstimation = nodeCount -> MemoryRange.of( HugeObjectArray.memoryEstimation(nodeCount, 0), HugeObjectArray.memoryEstimation( nodeCount, - sizeOfInstance(LongArrayList.class) + sizeOfLongArray(sampledK) + sizeOfInstance(LongArrayList.class) + sizeOfLongArray(k.sampledValue) ) ); - var neighborListEstimate = NeighborList.memoryEstimation(boundedK) + var neighborListEstimate = NeighborList.memoryEstimation(k.value) .estimate(dim, concurrency) .memoryUsage(); @@ -142,13 +153,13 @@ public static MemoryEstimation memoryEstimation( .fixed( "initial-random-neighbors (per thread)", KnnFactory - .initialSamplerMemoryEstimation(configuration.initialSampler(), boundedK) + .initialSamplerMemoryEstimation(configuration.initialSampler(), k.value) .times(concurrency) ) .fixed( "sampled-random-neighbors (per thread)", MemoryRange.of( - sizeOfIntArray(sizeOfOpenHashContainer(sampledK)) * concurrency + sizeOfIntArray(sizeOfOpenHashContainer(k.sampledValue)) * concurrency ) ) .build(); diff --git a/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnParameters.java b/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnParameters.java new file mode 100644 index 0000000000..135b4cbf50 --- /dev/null +++ b/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnParameters.java @@ -0,0 +1,150 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.similarity.knn; + +import java.util.List; +import java.util.Optional; + +public class KnnParameters { + + static KnnParameters create( + long nodeCount, + int concurrency, + int maxIterations, + double similarityCutoff, + double deltaThreshold, + double sampleRate, + int rawK, + double perturbationRate, + int randomJoins, + int minBatchSize, + KnnSampler.SamplerType samplerType, + Optional randomSeed, + List nodePropertySpecs + ) { + // concurrency -- no test atm, it probably shouldn't be here + // maxIterations -- must be 1 or more + if (maxIterations < 1) throw new IllegalArgumentException("maxIterations"); + // similarityCutoff -- value range [0.0;1.0] + if (Double.compare(similarityCutoff, 0.0) < 0 || Double.compare(similarityCutoff, 1.0) > 0) + throw new IllegalArgumentException("similarityCutoff must be more than or equal to 0.0 and less than or equal to 1.0"); + // sampleRate -- value range (0.0;1.0] + if (Double.compare(sampleRate, 0.0) < 1 || Double.compare(sampleRate, 1.0) > 0) + throw new IllegalArgumentException("sampleRate must be more than 0.0 and less than or equal to 1.0"); + // deltaThreshold -- value range [0.0;1.0] + if (Double.compare(deltaThreshold, 0.0) < 0 || Double.compare(deltaThreshold, 1.0) > 0) + throw new IllegalArgumentException("deltaThreshold must be more than or equal to 0.0 and less than or equal to 1.0"); + // rawK -- user provided k value must be at least 1 + if (rawK < 1) throw new IllegalArgumentException("K k must be 1 or more"); + // perturbationRate -- value range [0.0;1.0] + if (Double.compare(perturbationRate, 0.0) < 0 || Double.compare(perturbationRate, 1.0) > 0) + throw new IllegalArgumentException("perturbationRate must be more than or equal to 0.0 and less than or equal to 1.0"); + // randomJoins -- 0 or more + if (randomJoins < 0) throw new IllegalArgumentException("randomJoins must be 0 or more"); + + return new KnnParameters( + concurrency, + maxIterations, + similarityCutoff, + K.create(rawK, nodeCount, sampleRate, deltaThreshold), + perturbationRate, + randomJoins, + minBatchSize, + samplerType, + randomSeed, + nodePropertySpecs + ); + } + + private final int concurrency; + private final int maxIterations; + private final double similarityCutoff; + private final K kHolder; + private final double perturbationRate; + private final int randomJoins; + private final int minBatchSize; + private final KnnSampler.SamplerType samplerType; + private final Optional randomSeed; + private final List nodePropertySpecs; + + private KnnParameters( + int concurrency, + int maxIterations, + double similarityCutoff, + K kHolder, + double perturbationRate, + int randomJoins, + int minBatchSize, + KnnSampler.SamplerType samplerType, + Optional randomSeed, + List nodePropertySpecs + ) { + this.concurrency = concurrency; + this.maxIterations = maxIterations; + this.similarityCutoff = similarityCutoff; + this.kHolder = kHolder; + this.perturbationRate = perturbationRate; + this.randomJoins = randomJoins; + this.minBatchSize = minBatchSize; + this.samplerType = samplerType; + this.randomSeed = randomSeed; + this.nodePropertySpecs = nodePropertySpecs; + } + + int concurrency() { + return concurrency; + } + + int maxIterations() { + return maxIterations; + } + + double similarityCutoff() { + return similarityCutoff; + } + + K kHolder() { + return kHolder; + } + + double perturbationRate() { + return perturbationRate; + } + + int randomJoins() { + return randomJoins; + } + + int minBatchSize() { + return minBatchSize; + } + + KnnSampler.SamplerType samplerType() { + return samplerType; + } + + Optional randomSeed() { + return randomSeed; + } + + List nodePropertySpecs() { + return nodePropertySpecs; + } +} diff --git a/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnParametersSansNodeCount.java b/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnParametersSansNodeCount.java new file mode 100644 index 0000000000..74c89356bd --- /dev/null +++ b/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnParametersSansNodeCount.java @@ -0,0 +1,135 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.similarity.knn; + +import java.util.List; +import java.util.Optional; + +public class KnnParametersSansNodeCount { + + static KnnParametersSansNodeCount create( + int concurrency, + int maxIterations, + double similarityCutoff, + double deltaThreshold, + double sampleRate, + int rawK, + double perturbationRate, + int randomJoins, + int minBatchSize, + KnnSampler.SamplerType samplerType, + Optional randomSeed, + List nodePropertySpecs + ) { + // concurrency -- no test atm, it probably shouldn't be here + // maxIterations -- must be 1 or more + if (maxIterations < 1) throw new IllegalArgumentException("maxIterations"); + // similarityCutoff -- value range [0.0;1.0] + if (Double.compare(similarityCutoff, 0.0) < 0 || Double.compare(similarityCutoff, 1.0) > 0) + throw new IllegalArgumentException("similarityCutoff must be more than or equal to 0.0 and less than or equal to 1.0"); + // sampleRate -- value range (0.0;1.0] + if (Double.compare(sampleRate, 0.0) < 1 || Double.compare(sampleRate, 1.0) > 0) + throw new IllegalArgumentException("sampleRate must be more than 0.0 and less than or equal to 1.0"); + // deltaThreshold -- value range [0.0;1.0] + if (Double.compare(deltaThreshold, 0.0) < 0 || Double.compare(deltaThreshold, 1.0) > 0) + throw new IllegalArgumentException("deltaThreshold must be more than or equal to 0.0 and less than or equal to 1.0"); + // rawK -- user provided k value must be at least 1 + if (rawK < 1) throw new IllegalArgumentException("K k must be 1 or more"); + // perturbationRate -- value range [0.0;1.0] + if (Double.compare(perturbationRate, 0.0) < 0 || Double.compare(perturbationRate, 1.0) > 0) + throw new IllegalArgumentException("perturbationRate must be more than or equal to 0.0 and less than or equal to 1.0"); + // randomJoins -- 0 or more + if (randomJoins < 0) throw new IllegalArgumentException("randomJoins must be 0 or more"); + + return new KnnParametersSansNodeCount( + concurrency, + maxIterations, + similarityCutoff, + deltaThreshold, + sampleRate, + rawK, + perturbationRate, + randomJoins, + minBatchSize, + samplerType, + randomSeed, + nodePropertySpecs + ); + } + + private final int concurrency; + private final int maxIterations; + private final double similarityCutoff; + private final double deltaThreshold; + private final double sampleRate; + private final int rawK; + private final double perturbationRate; + private final int randomJoins; + private final int minBatchSize; + private final KnnSampler.SamplerType samplerType; + private final Optional randomSeed; + private final List nodePropertySpecs; + + public KnnParametersSansNodeCount( + int concurrency, + int maxIterations, + double similarityCutoff, + double deltaThreshold, + double sampleRate, + int k, + double perturbationRate, + int randomJoins, + int minBatchSize, + KnnSampler.SamplerType samplerType, + Optional randomSeed, + List nodePropertySpecs + ) { + this.concurrency = concurrency; + this.maxIterations = maxIterations; + this.similarityCutoff = similarityCutoff; + this.deltaThreshold = deltaThreshold; + this.sampleRate = sampleRate; + this.rawK = k; + this.perturbationRate = perturbationRate; + this.randomJoins = randomJoins; + this.minBatchSize = minBatchSize; + this.samplerType = samplerType; + this.randomSeed = randomSeed; + this.nodePropertySpecs = nodePropertySpecs; + } + + public KnnParameters finalize(long nodeCount) { + return KnnParameters.create( + nodeCount, + concurrency, + maxIterations, + similarityCutoff, + deltaThreshold, + sampleRate, + rawK, + perturbationRate, + randomJoins, + minBatchSize, + samplerType, + randomSeed, + nodePropertySpecs + ); + } +} diff --git a/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnSampler.java b/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnSampler.java index 68973b660d..7cc53c62e6 100644 --- a/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnSampler.java +++ b/algo/src/main/java/org/neo4j/gds/similarity/knn/KnnSampler.java @@ -26,6 +26,11 @@ import java.util.stream.Collectors; public interface KnnSampler { + + interface Factory { + KnnSampler create(); + } + long[] sample( long nodeId, long lowerBoundOnValidSamplesInRange, diff --git a/algo/src/main/java/org/neo4j/gds/similarity/knn/Neighbors.java b/algo/src/main/java/org/neo4j/gds/similarity/knn/Neighbors.java new file mode 100644 index 0000000000..5bd55d498e --- /dev/null +++ b/algo/src/main/java/org/neo4j/gds/similarity/knn/Neighbors.java @@ -0,0 +1,73 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.similarity.knn; + +import org.neo4j.gds.collections.ha.HugeObjectArray; + +final class Neighbors { + private final HugeObjectArray neighbors; + private long neighborsFound; + private long joinCounter; + + Neighbors(long nodeCount) { + this.neighbors = HugeObjectArray.newArray(NeighborList.class, nodeCount); + } + + Neighbors(HugeObjectArray neighbors) { + this.neighbors = neighbors; + } + + NeighborList get(long nodeId) { + return neighbors.get(nodeId); + } + + NeighborList getAndIncrementCounter(long nodeId) { + incrementJoinCounter(); + return get(nodeId); + } + void set(long nodeId, NeighborList neighborList) { + neighbors.set(nodeId, neighborList); + neighborsFound += neighborList.size(); + } + + long size() { + return neighbors.size(); + } + + long neighborsFound() { + return neighborsFound; + } + + void incrementJoinCounter() { + joinCounter++; + } + + long joinCounter() { + return joinCounter; + } + + void filterHighSimilarityResult(long nodeId, double similarityCutoff) { + neighbors.get(nodeId).filterHighSimilarityResults(similarityCutoff); + } + + HugeObjectArray data() { + return neighbors; + } +} diff --git a/algo/src/main/java/org/neo4j/gds/similarity/knn/RandomWalkKnnSampler.java b/algo/src/main/java/org/neo4j/gds/similarity/knn/RandomWalkKnnSampler.java index 9b9f5e0dd5..c817e5b308 100644 --- a/algo/src/main/java/org/neo4j/gds/similarity/knn/RandomWalkKnnSampler.java +++ b/algo/src/main/java/org/neo4j/gds/similarity/knn/RandomWalkKnnSampler.java @@ -126,4 +126,22 @@ public long[] sample( return samples; } + + public static class Factory implements KnnSampler.Factory { + private final Graph graph; + private final Optional randomSeed; + private final int boundedK; + private final SplittableRandom random; + + Factory(Graph graph, Optional randomSeed, int boundedK, SplittableRandom random) { + this.graph = graph; + this.randomSeed = randomSeed; + this.boundedK = boundedK; + this.random = random; + } + + public KnnSampler create() { + return new RandomWalkKnnSampler(graph.concurrentCopy(), random.split(), randomSeed, boundedK); + } + } } diff --git a/algo/src/main/java/org/neo4j/gds/similarity/knn/SplitOldAndNewNeighbors.java b/algo/src/main/java/org/neo4j/gds/similarity/knn/SplitOldAndNewNeighbors.java index 945dd510b1..28e308ec07 100644 --- a/algo/src/main/java/org/neo4j/gds/similarity/knn/SplitOldAndNewNeighbors.java +++ b/algo/src/main/java/org/neo4j/gds/similarity/knn/SplitOldAndNewNeighbors.java @@ -21,8 +21,8 @@ import com.carrotsearch.hppc.IntArrayList; import com.carrotsearch.hppc.LongArrayList; -import org.neo4j.gds.core.concurrency.BiLongConsumer; import org.neo4j.gds.collections.ha.HugeObjectArray; +import org.neo4j.gds.core.concurrency.BiLongConsumer; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import java.util.SplittableRandom; @@ -36,26 +36,58 @@ * Mark sampled items in B[v] as false; */ final class SplitOldAndNewNeighbors implements BiLongConsumer { + + static final class Factory { + private final int sampledK; + private final SplittableRandom random; + private final ProgressTracker progressTracker; + + Factory( + int sampledK, + SplittableRandom random, + ProgressTracker progressTracker + ) { + this.sampledK = sampledK; + this.random = random; + this.progressTracker = progressTracker; + } + + SplitOldAndNewNeighbors create( + Neighbors neighbors, + HugeObjectArray allOldNeighbors, + HugeObjectArray allNewNeighbors + ) { + return new SplitOldAndNewNeighbors( + neighbors, + allOldNeighbors, + allNewNeighbors, + sampledK, + random, + progressTracker + ); + } + } + private final SplittableRandom random; - private final HugeObjectArray neighbors; + private final Neighbors neighbors; private final HugeObjectArray allOldNeighbors; private final HugeObjectArray allNewNeighbors; private final int sampledK; private final ProgressTracker progressTracker; SplitOldAndNewNeighbors( - SplittableRandom random, - HugeObjectArray neighbors, + Neighbors neighbors, HugeObjectArray allOldNeighbors, HugeObjectArray allNewNeighbors, int sampledK, + SplittableRandom random, ProgressTracker progressTracker ) { - this.random = random; this.neighbors = neighbors; this.allOldNeighbors = allOldNeighbors; this.allNewNeighbors = allNewNeighbors; this.sampledK = sampledK; + this.random = random; this.progressTracker = progressTracker; } diff --git a/algo/src/main/java/org/neo4j/gds/similarity/knn/UniformKnnSampler.java b/algo/src/main/java/org/neo4j/gds/similarity/knn/UniformKnnSampler.java index ccecc3aaa5..59eef7ab10 100644 --- a/algo/src/main/java/org/neo4j/gds/similarity/knn/UniformKnnSampler.java +++ b/algo/src/main/java/org/neo4j/gds/similarity/knn/UniformKnnSampler.java @@ -29,6 +29,7 @@ class UniformKnnSampler implements KnnSampler { + private final LongUniformSamplerFromRange uniformSamplerFromRange; private final long exclusiveMax; @@ -57,4 +58,18 @@ public long[] sample( isInvalidSample ); } + + static class Factory implements KnnSampler.Factory { + private final long nodeCount; + private final SplittableRandom random; + + Factory(long nodeCount, SplittableRandom random) { + this.nodeCount = nodeCount; + this.random = random; + } + + public KnnSampler create() { + return new UniformKnnSampler(random.split(), nodeCount); + } + } } diff --git a/algo/src/test/java/org/neo4j/gds/similarity/knn/GenerateRandomNeighborsTest.java b/algo/src/test/java/org/neo4j/gds/similarity/knn/GenerateRandomNeighborsTest.java index 47f4dbc47d..833d47f5e4 100644 --- a/algo/src/test/java/org/neo4j/gds/similarity/knn/GenerateRandomNeighborsTest.java +++ b/algo/src/test/java/org/neo4j/gds/similarity/knn/GenerateRandomNeighborsTest.java @@ -64,15 +64,15 @@ void neighborsForKEqualsNMinus1startWithEachOtherAsNeighbors( var random = new SplittableRandom(); var generateRandomNeighbors = new GenerateRandomNeighbors( + Partition.of(0, nodeCount), + new Neighbors(allNeighbors), new UniformKnnSampler(random, nodeCount), - random, - similarityFunction, new KnnNeighborFilter(nodeCount), - allNeighbors, + similarityFunction, + NeighbourConsumers.no_op, k, - Partition.of(0, nodeCount), - ProgressTracker.NULL_TRACKER, - NeighbourConsumers.no_op + random, + ProgressTracker.NULL_TRACKER ); generateRandomNeighbors.run(); diff --git a/algo/src/test/java/org/neo4j/gds/similarity/knn/JoinNeighborsTest.java b/algo/src/test/java/org/neo4j/gds/similarity/knn/JoinNeighborsTest.java new file mode 100644 index 0000000000..2e941a541a --- /dev/null +++ b/algo/src/test/java/org/neo4j/gds/similarity/knn/JoinNeighborsTest.java @@ -0,0 +1,111 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.similarity.knn; + +import com.carrotsearch.hppc.LongArrayList; +import org.junit.jupiter.api.Test; +import org.neo4j.gds.collections.ha.HugeObjectArray; +import org.neo4j.gds.core.utils.partition.Partition; +import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; +import org.neo4j.gds.extension.GdlExtension; +import org.neo4j.gds.extension.GdlGraph; +import org.neo4j.gds.extension.Inject; +import org.neo4j.gds.extension.TestGraph; +import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer; + +import java.util.SplittableRandom; + +import static org.assertj.core.api.Assertions.assertThat; + +@GdlExtension +class JoinNeighborsTest { + + @GdlGraph + private static final String DB_CYPHER = + "CREATE" + + " (a { knn: 1.2, prop: 1.0 } )" + + ", (b { knn: 1.1, prop: 5.0 } )" + + ", (c { knn: 42.0, prop: 10.0 } )"; + @Inject + private TestGraph graph; + + @Test + void joinNeighbors() { + NeighbourConsumer neighbourConsumer = NeighbourConsumer.devNull; + SplittableRandom random = new SplittableRandom(42); + double perturbationRate = 0.0; + var allNeighbors = HugeObjectArray.of( + new NeighborList(1, neighbourConsumer), + new NeighborList(1, neighbourConsumer), + new NeighborList(1, neighbourConsumer) + ); + // setting an artificial priority to assure they will be replaced + allNeighbors.get(0).add(1, 0.0, random, perturbationRate); + allNeighbors.get(1).add(2, 0.0, random, perturbationRate); + allNeighbors.get(2).add(0, 0.0, random, perturbationRate); + + var allNewNeighbors = HugeObjectArray.of( + LongArrayList.from(1, 2), + null, + null + ); + + var allOldNeighbors = HugeObjectArray.newArray(LongArrayList.class, graph.nodeCount()); + + SimilarityFunction similarityFunction = new SimilarityFunction(new SimilarityComputer() { + @Override + public double similarity(long firstNodeId, long secondNodeId) { + return ((double) secondNodeId) / (firstNodeId + secondNodeId); + } + + @Override + public boolean isSymmetric() { + return true; + } + }); + + var joinNeighbors = new JoinNeighbors( + Partition.of(0, 1), + new Neighbors(allNeighbors), + allOldNeighbors, + allNewNeighbors, + HugeObjectArray.newArray(LongArrayList.class, graph.nodeCount()), + HugeObjectArray.newArray(LongArrayList.class, graph.nodeCount()), + new KnnNeighborFilter(graph.nodeCount()), + similarityFunction, + 1, + perturbationRate, + 0, + random, + // simplifying the test by only running over a single node + ProgressTracker.NULL_TRACKER + ); + + joinNeighbors.run(); + + // 1-0, 2-0, 1-2/2-1 + assertThat(joinNeighbors.nodePairsConsidered()).isEqualTo(3); + + assertThat(allNeighbors.get(0).elements()).containsExactly(1L); + assertThat(allNeighbors.get(1).elements()).containsExactly(2L); + // this gets updated due to joining the new neighbors together + assertThat(allNeighbors.get(2).elements()).containsExactly(1L); + } +} diff --git a/algo/src/test/java/org/neo4j/gds/similarity/knn/KnnConfigTest.java b/algo/src/test/java/org/neo4j/gds/similarity/knn/KnnConfigTest.java new file mode 100644 index 0000000000..7e9c02feaf --- /dev/null +++ b/algo/src/test/java/org/neo4j/gds/similarity/knn/KnnConfigTest.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.similarity.knn; + +import org.junit.jupiter.api.Test; +import org.neo4j.gds.core.CypherMapWrapper; +import org.neo4j.gds.extension.GdlExtension; +import org.neo4j.gds.extension.GdlGraph; +import org.neo4j.gds.extension.Inject; +import org.neo4j.gds.extension.TestGraph; +import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer; +import org.neo4j.gds.similarity.knn.metrics.SimilarityMetric; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +@GdlExtension +class KnnConfigTest { + + @GdlGraph + private static final String DB_CYPHER = + "CREATE" + + " (a { knn: 1.2, prop: 1.0 } )" + + ", (b { knn: 1.1, prop: 5.0 } )" + + ", (c { knn: 42.0, prop: 10.0 } )"; + @Inject + private TestGraph graph; + + @Test + void shouldRenderNodePropertiesWithResolvedDefaultMetrics() { + var userInput = CypherMapWrapper.create( + Map.of( + "nodeProperties", List.of("knn") + ) + ); + var knnConfig = new KnnBaseConfigImpl(userInput); + + // Initializing the similarity computer causes the default metric to be resolved + SimilarityComputer.ofProperties(graph, knnConfig.nodeProperties()); + + assertThat(knnConfig.toMap().get("nodeProperties")).isEqualTo( + Map.of( + "knn", SimilarityMetric.DOUBLE_PROPERTY_METRIC.name() + ) + ); + } + + @Test + void invalidRandomParameters() { + var configBuilder = ImmutableKnnBaseConfig.builder() + .nodeProperties(List.of(new KnnNodePropertySpec("dummy"))) + .concurrency(4) + .randomSeed(1337L); + assertThrows(IllegalArgumentException.class, configBuilder::build); + } +} diff --git a/algo/src/test/java/org/neo4j/gds/similarity/knn/KnnTest.java b/algo/src/test/java/org/neo4j/gds/similarity/knn/KnnTest.java index b4b6aba62f..f6d3b039bd 100644 --- a/algo/src/test/java/org/neo4j/gds/similarity/knn/KnnTest.java +++ b/algo/src/test/java/org/neo4j/gds/similarity/knn/KnnTest.java @@ -32,9 +32,8 @@ import org.neo4j.gds.collections.ha.HugeObjectArray; import org.neo4j.gds.compat.Neo4jProxy; import org.neo4j.gds.compat.TestLog; -import org.neo4j.gds.core.CypherMapWrapper; +import org.neo4j.gds.core.concurrency.DefaultPool; import org.neo4j.gds.core.loading.NullPropertyMap; -import org.neo4j.gds.core.utils.partition.Partition; import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker; @@ -48,12 +47,10 @@ import org.neo4j.gds.nodeproperties.DoubleTestPropertyValues; import org.neo4j.gds.nodeproperties.FloatArrayTestPropertyValues; import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer; -import org.neo4j.gds.similarity.knn.metrics.SimilarityMetric; import java.util.Comparator; import java.util.List; -import java.util.Map; -import java.util.SplittableRandom; +import java.util.Optional; import java.util.stream.LongStream; import java.util.stream.Stream; @@ -61,7 +58,6 @@ import static org.assertj.core.api.Assertions.withPrecision; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.neo4j.gds.assertj.Extractors.removingThreadId; import static org.neo4j.gds.assertj.Extractors.replaceTimings; @@ -103,15 +99,25 @@ class KnnTest { void shouldRun() { IdFunction idFunction = graph::toMappedNodeId; - var knnConfig = ImmutableKnnBaseConfig.builder() - .nodeProperties(List.of(new KnnNodePropertySpec("knn"))) - .concurrency(1) - .randomSeed(19L) - .topK(1) - .build(); - var knnContext = ImmutableKnnContext.builder().build(); - - var knn = Knn.createWithDefaults(graph, knnConfig, knnContext); + var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperty(graph, new KnnNodePropertySpec("knn"))); + var k = K.create(1, graph.nodeCount(), 0.5, 0.001); + var knn = new Knn( + graph, + ProgressTracker.NULL_TRACKER, + DefaultPool.INSTANCE, + k, + 1, + 1000, + 100, + 0.0, + 0.0, + 10, + Optional.of(19L), + KnnSampler.SamplerType.UNIFORM, + similarityFunction, + new KnnNeighborFilterFactory(graph.nodeCount()), + NeighbourConsumers.no_op + ); var result = knn.compute(); assertThat(result).isNotNull(); @@ -130,13 +136,25 @@ void shouldRun() { void shouldHaveEachNodeConnected() { IdFunction idFunction = graph::toMappedNodeId; - var knnConfig = ImmutableKnnBaseConfig.builder() - .nodeProperties(List.of(new KnnNodePropertySpec("knn"))) - .topK(2) - .build(); - var knnContext = ImmutableKnnContext.builder().build(); - - var knn = Knn.createWithDefaults(graph, knnConfig, knnContext); + var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperty(graph, new KnnNodePropertySpec("knn"))); + var k = K.create(2, graph.nodeCount(), 0.5, 0.001); + var knn = new Knn( + graph, + ProgressTracker.NULL_TRACKER, + DefaultPool.INSTANCE, + k, + 4, + 1000, + 100, + 0.0, + 0.0, + 10, + Optional.empty(), + KnnSampler.SamplerType.UNIFORM, + similarityFunction, + new KnnNeighborFilterFactory(graph.nodeCount()), + NeighbourConsumers.no_op + ); var result = knn.compute(); assertThat(result).isNotNull(); @@ -150,6 +168,7 @@ void shouldHaveEachNodeConnected() { assertCorrectNeighborList(result, nodeBId, nodeAId, nodeCId); assertCorrectNeighborList(result, nodeCId, nodeAId, nodeBId); } + private void assertCorrectNeighborList( KnnResult result, long nodeId, @@ -170,15 +189,25 @@ private void assertCorrectNeighborList( void shouldWorkWithMultipleProperties() { IdFunction idFunction = graph::toMappedNodeId; - var knnConfig = ImmutableKnnBaseConfig.builder() - .nodeProperties(List.of(new KnnNodePropertySpec("knn"), new KnnNodePropertySpec("prop"))) - .concurrency(1) - .randomSeed(19L) - .topK(1) - .build(); - var knnContext = ImmutableKnnContext.builder().build(); - - var knn = Knn.createWithDefaults(graph, knnConfig, knnContext); + var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperties(graph, List.of(new KnnNodePropertySpec("knn"), new KnnNodePropertySpec("prop")))); + var k = K.create(1, graph.nodeCount(), 0.5, 0.001); + var knn = new Knn( + graph, + ProgressTracker.NULL_TRACKER, + DefaultPool.INSTANCE, + k, + 1, + 1000, + 100, + 0.0, + 0.0, + 10, + Optional.of(19L), + KnnSampler.SamplerType.UNIFORM, + similarityFunction, + new KnnNeighborFilterFactory(graph.nodeCount()), + NeighbourConsumers.no_op + ); var result = knn.compute(); assertThat(result).isNotNull(); @@ -206,16 +235,25 @@ void shouldWorkWithMultipleProperties() { @Test void shouldWorkWithMultiplePropertiesEvenIfSomeAreMissing() { - - var knnConfig = ImmutableKnnBaseConfig.builder() - .nodeProperties(List.of(new KnnNodePropertySpec("prop1"), new KnnNodePropertySpec("prop2"))) - .concurrency(1) - .randomSeed(19L) - .topK(2) - .build(); - var knnContext = ImmutableKnnContext.builder().build(); - - var knn = Knn.createWithDefaults(multPropMissingGraph, knnConfig, knnContext); + var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperties(multPropMissingGraph, List.of(new KnnNodePropertySpec("prop1"), new KnnNodePropertySpec("prop2")))); + var k = K.create(2, multPropMissingGraph.nodeCount(), 0.5, 0.001); + var knn = new Knn( + multPropMissingGraph, + ProgressTracker.NULL_TRACKER, + DefaultPool.INSTANCE, + k, + 1, + 1000, + 100, + 0.0, + 0.0, + 10, + Optional.of(19L), + KnnSampler.SamplerType.UNIFORM, + similarityFunction, + new KnnNeighborFilterFactory(multPropMissingGraph.nodeCount()), + NeighbourConsumers.no_op + ); var result = knn.compute(); assertThat(result).isNotNull(); @@ -240,17 +278,25 @@ void shouldWorkWithMultiplePropertiesEvenIfSomeAreMissing() { @Test void shouldFilterResultsOfLowSimilarity() { - - var knnConfig = ImmutableKnnBaseConfig.builder() - .nodeProperties(List.of(new KnnNodePropertySpec("age"))) - .concurrency(1) - .randomSeed(19L) - .similarityCutoff(0.14) - .topK(2) - .build(); - var knnContext = ImmutableKnnContext.builder().build(); - - var knn = Knn.createWithDefaults(simThresholdGraph, knnConfig, knnContext); + var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperty(simThresholdGraph, new KnnNodePropertySpec("age"))); + var k = K.create(2, simThresholdGraph.nodeCount(), 0.5, 0.001); + var knn = new Knn( + simThresholdGraph, + ProgressTracker.NULL_TRACKER, + DefaultPool.INSTANCE, + k, + 1, + 1000, + 100, + 0.14, + 0.0, + 10, + Optional.of(19L), + KnnSampler.SamplerType.UNIFORM, + similarityFunction, + new KnnNeighborFilterFactory(simThresholdGraph.nodeCount()), + NeighbourConsumers.no_op + ); var result = knn.compute(); assertThat(result).isNotNull(); @@ -267,25 +313,27 @@ void shouldFilterResultsOfLowSimilarity() { assertCorrectNeighborList(result, nodeEveId, nodeBobId); } - private void assertEmptyNeighborList(KnnResult result, long nodeId) { - var actualNeighbors = result.neighborsOf(nodeId).toArray(); - assertThat(actualNeighbors).isEmpty(); - } - @ParameterizedTest @MethodSource("emptyProperties") void testNonExistingProperties(NodePropertyValues nodePropertyValues) { - var knnConfig = ImmutableKnnBaseConfig.builder() - .nodeProperties(List.of(new KnnNodePropertySpec("knn"))) - .topK(2) - .build(); - var knnContext = ImmutableKnnContext.builder().build(); - var knn = Knn.create( + var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperty(graph, "knn", nodePropertyValues)); + var k = K.create(2, graph.nodeCount(), 0.5, 0.001); + var knn = new Knn( graph, - knnConfig, - SimilarityComputer.ofProperty(graph, "knn", nodePropertyValues), + ProgressTracker.NULL_TRACKER, + DefaultPool.INSTANCE, + k, + 4, + 1000, + 100, + 0.0, + 0.0, + 10, + Optional.empty(), + KnnSampler.SamplerType.UNIFORM, + similarityFunction, new KnnNeighborFilterFactory(graph.nodeCount()), - knnContext + NeighbourConsumers.no_op ); var result = knn.compute(); assertThat(result) @@ -308,20 +356,26 @@ void testMixedExistingAndNonExistingProperties(SoftAssertions softly) { IdFunction idFunction = graph::toMappedNodeId; var nodeProperties = new DoubleTestPropertyValues(nodeId -> nodeId == 0 ? Double.NaN : 42.1337); - var knn = Knn.create( + + var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperty(graph, "{knn}", nodeProperties)); + var k = K.create(1, graph.nodeCount(), 0.5, 0.001); + var knn = new Knn( graph, - ImmutableKnnBaseConfig - .builder() - .nodeProperties(List.of(new KnnNodePropertySpec("knn"))) - .topK(1) - .concurrency(1) - .randomSeed(42L) - .build(), - SimilarityComputer.ofProperty(graph, "{knn}", nodeProperties), + ProgressTracker.NULL_TRACKER, + DefaultPool.INSTANCE, + k, + 1, + 1000, + 100, + 0.0, + 0.0, + 10, + Optional.of(42L), + KnnSampler.SamplerType.UNIFORM, + similarityFunction, new KnnNeighborFilterFactory(graph.nodeCount()), - ImmutableKnnContext.builder().build() + NeighbourConsumers.no_op ); - var result = knn.compute(); softly.assertThat(result) @@ -405,80 +459,29 @@ void testReverseSingleNeighbors() { } } - @Test - void joinNeighbors() { - NeighbourConsumer neighbourConsumer = NeighbourConsumer.devNull; - SplittableRandom random = new SplittableRandom(42); - double perturbationRate = 0.0; - var allNeighbors = HugeObjectArray.of( - new NeighborList(1, neighbourConsumer), - new NeighborList(1, neighbourConsumer), - new NeighborList(1, neighbourConsumer) - ); - // setting an artificial priority to assure they will be replaced - allNeighbors.get(0).add(1, 0.0, random, perturbationRate); - allNeighbors.get(1).add(2, 0.0, random, perturbationRate); - allNeighbors.get(2).add(0, 0.0, random, perturbationRate); - - var allNewNeighbors = HugeObjectArray.of( - LongArrayList.from(1, 2), - null, - null - ); - - var allOldNeighbors = HugeObjectArray.newArray(LongArrayList.class, graph.nodeCount()); - - SimilarityFunction similarityFunction = new SimilarityFunction(new SimilarityComputer() { - @Override - public double similarity(long firstNodeId, long secondNodeId) { - return ((double) secondNodeId) / (firstNodeId + secondNodeId); - } - - @Override - public boolean isSymmetric() { - return true; - } - }); - - var joinNeighbors = new Knn.JoinNeighbors( - random, - similarityFunction, - new KnnNeighborFilter(graph.nodeCount()), - allNeighbors, - allOldNeighbors, - allNewNeighbors, - HugeObjectArray.newArray(LongArrayList.class, graph.nodeCount()), - HugeObjectArray.newArray(LongArrayList.class, graph.nodeCount()), - 1, - perturbationRate, - 0, - // simplifying the test by only running over a single node - Partition.of(0, 1), - ProgressTracker.NULL_TRACKER - ); - - joinNeighbors.run(); - - // 1-0, 2-0, 1-2/2-1 - assertThat(joinNeighbors.nodePairsConsidered()).isEqualTo(3); - - assertThat(allNeighbors.get(0).elements()).containsExactly(1L); - assertThat(allNeighbors.get(1).elements()).containsExactly(2L); - // this gets updated due to joining the new neighbors together - assertThat(allNeighbors.get(2).elements()).containsExactly(1L); - } - @Test void testNegativeFloatArrays() { var graph = GdlFactory.of("({weight: [1.0, 2.0]}), ({weight: [3.0, -10.0]})").build().getUnion(); - var knnConfig = ImmutableKnnBaseConfig.builder() - .nodeProperties(List.of(new KnnNodePropertySpec("weight"))) - .topK(1) - .build(); - var knnContext = ImmutableKnnContext.builder().build(); - - var knn = Knn.createWithDefaults(graph, knnConfig, knnContext); + var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperty(graph, new KnnNodePropertySpec("weight"))); + var k = K.create(1, graph.nodeCount(), 0.5, 0.001); + var knn = new Knn( + graph, + ProgressTracker.NULL_TRACKER, + DefaultPool.INSTANCE, + k, + 4, + 1000, + 100, + 0.0, + 0.0, + 10, + Optional.empty(), + KnnSampler.SamplerType.UNIFORM, + similarityFunction, + new KnnNeighborFilterFactory(graph.nodeCount()), + NeighbourConsumers.no_op + ); var result = knn.compute(); @@ -488,22 +491,32 @@ void testNegativeFloatArrays() { @Test void shouldLogProgress() { - var config = ImmutableKnnBaseConfig.builder() - .nodeProperties(List.of(new KnnNodePropertySpec("knn"))) - .randomSeed(42L) - .topK(1) - .concurrency(1) - .build(); - - var factory = new KnnFactory<>(); + var maxIterations = 100; - var progressTask = factory.progressTask(graph, config); + var progressTask = KnnFactory.knnTaskTree(graph.nodeCount(), maxIterations); var log = Neo4jProxy.testLog(); var progressTracker = new TaskProgressTracker(progressTask, log, 4, EmptyTaskRegistryFactory.INSTANCE); - factory - .build(graph, config, progressTracker) - .compute(); + var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperty(graph, new KnnNodePropertySpec("knn"))); + var k = K.create(1, graph.nodeCount(), 0.5, 0.001); + var knn = new Knn( + graph, + progressTracker, + DefaultPool.INSTANCE, + k, + 1, + 1000, + maxIterations, + 0.0, + 0.0, + 10, + Optional.of(42L), + KnnSampler.SamplerType.UNIFORM, + similarityFunction, + new KnnNeighborFilterFactory(graph.nodeCount()), + NeighbourConsumers.no_op + ); + knn.compute(); assertThat(log.getMessages(TestLog.INFO)) .extracting(removingThreadId()) @@ -513,7 +526,6 @@ void shouldLogProgress() { "Knn :: Initialize random neighbors :: Start", "Knn :: Initialize random neighbors 100%", "Knn :: Initialize random neighbors :: Finished", - "Knn :: Graph init took `some time`", "Knn :: Iteration :: Start", "Knn :: Iteration :: Split old and new neighbors 1 of 100 :: Start", "Knn :: Iteration :: Split old and new neighbors 1 of 100 100%", @@ -524,54 +536,35 @@ void shouldLogProgress() { "Knn :: Iteration :: Join neighbors 1 of 100 :: Start", "Knn :: Iteration :: Join neighbors 1 of 100 100%", "Knn :: Iteration :: Join neighbors 1 of 100 :: Finished", - "Knn :: Iteration :: Graph iteration 1 took `some time`", "Knn :: Iteration :: Finished", - "Knn :: Finished", - "Knn :: Graph execution took `some time`" + "Knn :: Finished" ); } - @Test - void shouldRenderNodePropertiesWithResolvedDefaultMetrics() { - var userInput = CypherMapWrapper.create( - Map.of( - "nodeProperties", List.of("knn") - ) - ); - var knnConfig = new KnnBaseConfigImpl(userInput); - var knnContext = ImmutableKnnContext.builder().build(); - - // Initializing KNN will cause the default metric to be resolved - Knn.createWithDefaults(graph, knnConfig, knnContext); - - assertThat(knnConfig.toMap().get("nodeProperties")).isEqualTo( - Map.of( - "knn", SimilarityMetric.DOUBLE_PROPERTY_METRIC.name() - ) - ); - } - - @Test - void invalidRandomParameters() { - var configBuilder = ImmutableKnnBaseConfig.builder() - .nodeProperties(List.of(new KnnNodePropertySpec("dummy"))) - .concurrency(4) - .randomSeed(1337L); - assertThrows(IllegalArgumentException.class, configBuilder::build); - } - @ParameterizedTest(name = "{1}") @MethodSource("negativeGraphs") void supportNegativeArrays(String graphCreateQuery, String desc) { var graphWithNegativeNodePropertyValues = GdlFactory.of(graphCreateQuery).build().getUnion(); - var config = ImmutableKnnBaseConfig.builder() - .nodeProperties(List.of(new KnnNodePropertySpec("weight"))) - .randomSeed(42L) - .concurrency(1) - .build(); - var knnContext = KnnContext.empty(); - var knn = Knn.createWithDefaults(graphWithNegativeNodePropertyValues, config, knnContext); + var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperty(graphWithNegativeNodePropertyValues, new KnnNodePropertySpec("weight"))); + var k = K.create(10, graphWithNegativeNodePropertyValues.nodeCount(), 0.5, 0.001); + var knn = new Knn( + graphWithNegativeNodePropertyValues, + ProgressTracker.NULL_TRACKER, + DefaultPool.INSTANCE, + k, + 1, + 1000, + 100, + 0.0, + 0.0, + 10, + Optional.of(42L), + KnnSampler.SamplerType.UNIFORM, + similarityFunction, + new KnnNeighborFilterFactory(graphWithNegativeNodePropertyValues.nodeCount()), + NeighbourConsumers.no_op + ); var result = knn.compute(); assertThat(result.streamSimilarityResult()) .hasSize(2); @@ -603,22 +596,30 @@ class IterationsLimitTest { @Test void shouldRespectIterationLimit() { - var config = ImmutableKnnBaseConfig.builder() - .nodeProperties(List.of(new KnnNodePropertySpec("knn"))) - .deltaThreshold(0) - .topK(1) - .maxIterations(1) - .randomSeed(42L) - .concurrency(1) - .build(); - var knnContext = KnnContext.empty(); - var knn = Knn.createWithDefaults(graph, config, knnContext); + var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperty(graph, new KnnNodePropertySpec("knn"))); + var k = K.create(1, graph.nodeCount(), 0.5, 0.0); + var knn = new Knn( + graph, + ProgressTracker.NULL_TRACKER, + DefaultPool.INSTANCE, + k, + 1, + 1000, + 1, + 0.0, + 0.0, + 10, + Optional.of(42L), + KnnSampler.SamplerType.UNIFORM, + similarityFunction, + new KnnNeighborFilterFactory(graph.nodeCount()), + NeighbourConsumers.no_op + ); var result = knn.compute(); assertEquals(1, result.ranIterations()); assertFalse(result.didConverge()); } - } @Nested @@ -632,20 +633,30 @@ class DidConvergeTest { @Test void shouldReturnCorrectNumberIterationsWhenConverging() { - var config = ImmutableKnnBaseConfig.builder() - .nodeProperties(List.of(new KnnNodePropertySpec("knn"))) - .deltaThreshold(1.0) - .maxIterations(5) - .build(); - - var knnContext = KnnContext.empty(); - var knn = Knn.createWithDefaults(graph, config, knnContext); + var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperty(graph, new KnnNodePropertySpec("knn"))); + var k = K.create(10, graph.nodeCount(), 0.5, 1.0); + var knn = new Knn( + graph, + ProgressTracker.NULL_TRACKER, + DefaultPool.INSTANCE, + k, + 4, + 1000, + 5, + 0.0, + 0.0, + 10, + Optional.empty(), + KnnSampler.SamplerType.UNIFORM, + similarityFunction, + new KnnNeighborFilterFactory(graph.nodeCount()), + NeighbourConsumers.no_op + ); var result = knn.compute(); assertTrue(result.didConverge()); assertEquals(1, result.ranIterations()); } - } @Nested @@ -684,17 +695,25 @@ class RandomWalkInitialSamplerTest { void testReasonableTopKWithRandomWalk(SoftAssertions softly) { IdFunction idFunction = graph::toMappedNodeId; - var config = ImmutableKnnBaseConfig.builder() - .nodeProperties(List.of(new KnnNodePropertySpec("knn"))) - .topK(4) - .randomJoins(0) - .maxIterations(1) - .randomSeed(20L) - .concurrency(1) - .initialSampler(KnnSampler.SamplerType.RANDOMWALK) - .build(); - var knnContext = KnnContext.empty(); - var knn = Knn.createWithDefaults(graph, config, knnContext); + var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperty(graph, new KnnNodePropertySpec("knn"))); + var k = K.create(4, graph.nodeCount(), 0.5, 0.001); + var knn = new Knn( + graph, + ProgressTracker.NULL_TRACKER, + DefaultPool.INSTANCE, + k, + 1, + 1000, + 1, + 0.0, + 0.0, + 0, + Optional.of(20L), + KnnSampler.SamplerType.RANDOMWALK, + similarityFunction, + new KnnNeighborFilterFactory(graph.nodeCount()), + NeighbourConsumers.no_op + ); var result = knn.compute(); long nodeAId = idFunction.of("a"); diff --git a/algo/src/test/java/org/neo4j/gds/similarity/knn/SplitOldAndNewNeighborsTest.java b/algo/src/test/java/org/neo4j/gds/similarity/knn/SplitOldAndNewNeighborsTest.java index 9d3e5ffe4e..7d85630be3 100644 --- a/algo/src/test/java/org/neo4j/gds/similarity/knn/SplitOldAndNewNeighborsTest.java +++ b/algo/src/test/java/org/neo4j/gds/similarity/knn/SplitOldAndNewNeighborsTest.java @@ -76,11 +76,11 @@ void name( ); var splitNeighbors = new SplitOldAndNewNeighbors( - new SplittableRandom(), - allNeighbors, + new Neighbors(allNeighbors), allOldNeighbors, allNewNeighbors, sampledK, + new SplittableRandom(), ProgressTracker.NULL_TRACKER ); splitNeighbors.apply(0, nodeCount); diff --git a/doc/modules/ROOT/pages/algorithms/knn.adoc b/doc/modules/ROOT/pages/algorithms/knn.adoc index 0df6488aec..cc0156315d 100644 --- a/doc/modules/ROOT/pages/algorithms/knn.adoc +++ b/doc/modules/ROOT/pages/algorithms/knn.adoc @@ -489,7 +489,7 @@ YIELD nodeCount, bytesMin, bytesMax, requiredMemory [opts="header", cols="1,1,1,1"] |=== | nodeCount | bytesMin | bytesMax | requiredMemory -| 5 | 2208 | 3264 | "[2208 Bytes \... 3264 Bytes]" +| 5 | 2224 | 3280 | "[2224 Bytes \... 3280 Bytes]" |=== -- diff --git a/doc/modules/ROOT/pages/machine-learning/linkprediction-pipelines/predict.adoc b/doc/modules/ROOT/pages/machine-learning/linkprediction-pipelines/predict.adoc index d81a1c6368..fae0fb46b5 100644 --- a/doc/modules/ROOT/pages/machine-learning/linkprediction-pipelines/predict.adoc +++ b/doc/modules/ROOT/pages/machine-learning/linkprediction-pipelines/predict.adoc @@ -307,7 +307,7 @@ Because we are using the `UNDIRECTED` orientation, we will write twice as many r [opts="header",cols="3,7"] |=== | relationshipsWritten | samplingStats -| 16 | {didConverge=true, linksConsidered=48, ranIterations=2, strategy=approximate} +| 16 | {didConverge=true, linksConsidered=43, ranIterations=2, strategy=approximate} |=== -- diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/ApproximateLinkPrediction.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/ApproximateLinkPrediction.java index 60d88fd085..2c1b1ba7d5 100644 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/ApproximateLinkPrediction.java +++ b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/ApproximateLinkPrediction.java @@ -78,7 +78,7 @@ public static MemoryEstimation estimate(LinkPredictionPredictPipelineBaseConfig LinkPredictionResult predictLinks(LinkPredictionSimilarityComputer linkPredictionSimilarityComputer) { var knn = Knn.create( graph, - knnConfig, + knnConfig.toParameters().finalize(graph.nodeCount()), linkPredictionSimilarityComputer, new LinkPredictionSimilarityComputer.LinkFilterFactory( graph, diff --git a/proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/ApproximateLinkPredictionTest.java b/proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/ApproximateLinkPredictionTest.java index 334baea71d..60ae3c90a0 100644 --- a/proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/ApproximateLinkPredictionTest.java +++ b/proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/ApproximateLinkPredictionTest.java @@ -81,7 +81,7 @@ void setUp() { } @ParameterizedTest - @CsvSource(value = {"1, 59, 3", "2, 96, 2"}) + @CsvSource(value = {"1, 54, 3", "2, 90, 2"}) void shouldPredictWithTopK(int topK, long expectedLinksConsidered, int ranIterations) { var modelData = ImmutableLogisticRegressionData.of( 2, @@ -101,7 +101,7 @@ void shouldPredictWithTopK(int topK, long expectedLinksConsidered, int ranIterat LPNodeFilter.of(graphN, graphN), LPNodeFilter.of(graphN, graphN), ImmutableKnnBaseConfig.builder() - .randomSeed(42L) + .randomSeed(1337L) .concurrency(1) .randomJoins(2) .maxIterations(4) @@ -176,7 +176,7 @@ void shouldPredictTwice() { LPNodeFilter.of(graphN, graphN), LPNodeFilter.of(graphN, graphN), ImmutableKnnBaseConfig.builder() - .randomSeed(42L) + .randomSeed(1337L) .concurrency(1) .randomJoins(10) .maxIterations(10) diff --git a/proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineStreamProcTest.java b/proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineStreamProcTest.java index 66644bdd6f..250f9b59f1 100644 --- a/proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineStreamProcTest.java +++ b/proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineStreamProcTest.java @@ -132,7 +132,7 @@ void shouldPredictWithTopN(int concurrency, String nodeLabel) { } @ParameterizedTest - @CsvSource(value = {"N, [2432 Bytes ... 3776 Bytes]", "M, [2992 Bytes ... 5456 Bytes]"}) + @CsvSource(value = {"N, [2448 Bytes ... 3792 Bytes]", "M, [3008 Bytes ... 5472 Bytes]"}) void estimate(String targetNodeLabel, String expectedMemoryRange) { assertCypherResult( "CALL gds.beta.pipeline.linkPrediction.predict.stream.estimate('g', {" + diff --git a/proc/similarity/src/test/java/org/neo4j/gds/similarity/knn/KnnFactoryTest.java b/proc/similarity/src/test/java/org/neo4j/gds/similarity/knn/KnnFactoryTest.java index 6e84e93e5e..791b755809 100644 --- a/proc/similarity/src/test/java/org/neo4j/gds/similarity/knn/KnnFactoryTest.java +++ b/proc/similarity/src/test/java/org/neo4j/gds/similarity/knn/KnnFactoryTest.java @@ -55,8 +55,7 @@ static Stream smallParameters() { @MethodSource("smallParameters") void memoryEstimationWithNodeProperty(long nodeCount, KnnSampler.SamplerType initialSampler) { var config = knnConfig(initialSampler); - var boundedK = config.boundedK(nodeCount); - var sampledK = config.sampledK(nodeCount); + var k = config.k(nodeCount); MemoryEstimation estimation = new KnnFactory<>().memoryEstimation(config); GraphDimensions dimensions = ImmutableGraphDimensions.builder().nodeCount(nodeCount).build(); @@ -65,8 +64,8 @@ void memoryEstimationWithNodeProperty(long nodeCount, KnnSampler.SamplerType ini assertEstimation( nodeCount, - boundedK, - sampledK, + k.value, + k.sampledValue, initialSampler, actual ); @@ -86,15 +85,14 @@ static Stream largeParameters() { @MethodSource("largeParameters") void memoryEstimationLargePagesWithProperty(long nodeCount, KnnSampler.SamplerType initialSampler) { var config = knnConfig(initialSampler); - var boundedK = config.boundedK(nodeCount); - var sampledK = config.sampledK(nodeCount); + var k = config.k(nodeCount); MemoryEstimation estimation = new KnnFactory<>().memoryEstimation(config); GraphDimensions dimensions = ImmutableGraphDimensions.builder().nodeCount(nodeCount).build(); MemoryTree estimate = estimation.estimate(dimensions, 1); MemoryRange actual = estimate.memoryUsage(); - assertEstimation(nodeCount, boundedK, sampledK, initialSampler, actual); + assertEstimation(nodeCount, k.value, k.sampledValue, initialSampler, actual); } private void assertEstimation(