Skip to content

Commit

Permalink
Merge pull request #8471 from neo-technology/knn
Browse files Browse the repository at this point in the history
knn
  • Loading branch information
jjaderberg authored Dec 20, 2023
2 parents 64b7e39 + b9eaf97 commit 68bea47
Show file tree
Hide file tree
Showing 25 changed files with 1,489 additions and 723 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
import org.neo4j.gds.similarity.filtering.NodeFilter;
import org.neo4j.gds.similarity.knn.Knn;
import org.neo4j.gds.similarity.knn.KnnContext;
import org.neo4j.gds.similarity.knn.KnnNeighborFilterFactory;
import org.neo4j.gds.similarity.knn.KnnResult;
import org.neo4j.gds.similarity.knn.SimilarityFunction;
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;

import java.util.Optional;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -57,7 +59,7 @@ public static FilteredKnn createWithoutSeeding(Graph graph, FilteredKnnBaseConfi

// A bit speculative, but we imagine this being used as the entry point for seeding.
public static FilteredKnn createWithDefaultSeeding(Graph graph, FilteredKnnBaseConfig config, KnnContext context) {
var similarityFunction = Knn.defaultSimilarityFunction(graph, config.nodeProperties());
var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperties(graph, config.nodeProperties()));

return create(graph, config, context, Optional.of(similarityFunction));
}
Expand All @@ -71,9 +73,28 @@ public static FilteredKnn createWithDefaultSeeding(Graph graph, FilteredKnnBaseC
*/
static FilteredKnn create(Graph graph, FilteredKnnBaseConfig config, KnnContext context, Optional<SimilarityFunction> optionalSimilarityFunction) {
var targetNodeFilter = config.targetNodeFilter().toNodeFilter(graph);
var targetNodeFiltering = TargetNodeFiltering.create(graph.nodeCount(),config.boundedK(graph.nodeCount()), targetNodeFilter, graph, optionalSimilarityFunction, config.similarityCutoff());
var similarityFunction = optionalSimilarityFunction.orElse(Knn.defaultSimilarityFunction(graph, config.nodeProperties()));
var knn = Knn.createWithDefaultsAndInstrumentation(graph, config, context, targetNodeFiltering, similarityFunction);
var targetNodeFiltering = TargetNodeFiltering.create(graph.nodeCount(), config.k(graph.nodeCount()).value, targetNodeFilter, graph, optionalSimilarityFunction, config.similarityCutoff());
var similarityFunction = optionalSimilarityFunction.orElse(new SimilarityFunction(SimilarityComputer.ofProperties(
graph,
config.nodeProperties()
)));
var knn = new Knn(
graph,
context.progressTracker(),
context.executor(),
config.k(graph.nodeCount()),
config.concurrency(),
config.minBatchSize(),
config.maxIterations(),
config.similarityCutoff(),
config.perturbationRate(),
config.randomJoins(),
config.randomSeed(),
config.initialSampler(),
similarityFunction,
new KnnNeighborFilterFactory(graph.nodeCount()),
targetNodeFiltering
);
var sourceNodeFilter = config.sourceNodeFilter().toNodeFilter(graph);

return new FilteredKnn(context.progressTracker(), knn, targetNodeFiltering, sourceNodeFilter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
*/
package org.neo4j.gds.similarity.knn;

import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;

Expand All @@ -29,39 +29,78 @@
* Initial step in KNN calculation.
*/
final class GenerateRandomNeighbors implements Runnable {

/**
 * Creates {@link GenerateRandomNeighbors} tasks.
 *
 * The similarity function, neighbour consumers, bounded k and progress tracker
 * are captured once at construction time and handed unchanged to every task.
 * The random source is not handed over directly: each created task receives a
 * fresh generator obtained by splitting the one held here.
 */
static final class Factory {
    private final SimilarityFunction similarityFunction;
    private final NeighbourConsumers neighbourConsumers;
    private final int boundedK;
    private final SplittableRandom random;
    private final ProgressTracker progressTracker;

    /**
     * @param similarityFunction passed through to every created task
     * @param neighbourConsumers passed through to every created task
     * @param boundedK           passed through to every created task
     * @param random             parent generator; split once per created task
     * @param progressTracker    passed through to every created task
     */
    Factory(
        SimilarityFunction similarityFunction,
        NeighbourConsumers neighbourConsumers,
        int boundedK,
        SplittableRandom random,
        ProgressTracker progressTracker
    ) {
        this.progressTracker = progressTracker;
        this.random = random;
        this.boundedK = boundedK;
        this.neighbourConsumers = neighbourConsumers;
        this.similarityFunction = similarityFunction;
    }

    /**
     * Builds one task from the per-task arguments plus the shared state of this factory.
     *
     * @param partition      forwarded to the task
     * @param neighbors      forwarded to the task
     * @param sampler        forwarded to the task
     * @param neighborFilter forwarded to the task
     * @return a newly constructed task; never {@code null}
     */
    @NotNull GenerateRandomNeighbors create(
        Partition partition,
        Neighbors neighbors,
        KnnSampler sampler,
        NeighborFilter neighborFilter
    ) {
        // Every task gets its own generator derived from the factory's one.
        // NOTE(review): SplittableRandom instances are not thread-safe, so this
        // presumably runs on a single thread — confirm against the caller.
        SplittableRandom taskRandom = random.split();

        return new GenerateRandomNeighbors(
            partition,
            neighbors,
            sampler,
            neighborFilter,
            similarityFunction,
            neighbourConsumers,
            boundedK,
            taskRandom,
            progressTracker
        );
    }
}

private final Partition partition;
private final Neighbors neighbors;
private final KnnSampler sampler;
private final NeighborFilter neighborFilter;
private final SplittableRandom random;
private final SimilarityFunction similarityFunction;
private final NeighborFilter neighborFilter;
private final HugeObjectArray<NeighborList> neighbors;
private final NeighbourConsumers neighbourConsumers;
private final int boundedK;
private final ProgressTracker progressTracker;
private final Partition partition;
private final NeighbourConsumers neighbourConsumers;

private long neighborsFound;

GenerateRandomNeighbors(
Partition partition,
Neighbors neighbors,
KnnSampler sampler,
SplittableRandom random,
SimilarityFunction similarityFunction,
NeighborFilter neighborFilter,
HugeObjectArray<NeighborList> neighbors,
SimilarityFunction similarityFunction,
NeighbourConsumers neighbourConsumers,
int boundedK,
Partition partition,
ProgressTracker progressTracker,
NeighbourConsumers neighbourConsumers
SplittableRandom random,
ProgressTracker progressTracker
) {
this.partition = partition;
this.neighbors = neighbors;
this.sampler = sampler;
this.neighborFilter = neighborFilter;
this.random = random;
this.similarityFunction = similarityFunction;
this.neighborFilter = neighborFilter;
this.neighbors = neighbors;
this.neighbourConsumers = neighbourConsumers;
this.boundedK = boundedK;
this.progressTracker = progressTracker;
this.partition = partition;
this.neighborsFound = 0;
this.neighbourConsumers = neighbourConsumers;
}

@Override
Expand All @@ -88,12 +127,7 @@ public void run() {
assert neighbors.size() >= Math.min(neighborFilter.lowerBoundOfPotentialNeighbours(nodeId), boundedK);

this.neighbors.set(nodeId, neighbors);
neighborsFound += neighbors.size();
});
progressTracker.logProgress(partition.nodeCount());
}

long neighborsFound() {
return neighborsFound;
}
}
Loading

0 comments on commit 68bea47

Please sign in to comment.