Skip to content

Commit

Permalink
Merge pull request #9550 from neo-technology/knn-task-label
Browse files Browse the repository at this point in the history
Use typed task label for filtered KNN
  • Loading branch information
jjaderberg authored Aug 28, 2024
2 parents 12a20d6 + d1bb82e commit 1ee4d03
Showing 1 changed file with 17 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import org.neo4j.gds.similarity.knn.Knn;
import org.neo4j.gds.similarity.knn.KnnBaseConfig;
import org.neo4j.gds.similarity.knn.KnnContext;
import org.neo4j.gds.similarity.knn.KnnFactory;
import org.neo4j.gds.similarity.knn.KnnNeighborFilterFactory;
import org.neo4j.gds.similarity.knn.KnnResult;
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;
Expand All @@ -47,6 +46,7 @@

import java.util.List;

import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.FilteredKNN;
import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.FilteredNodeSimilarity;
import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.KNN;

Expand All @@ -65,10 +65,24 @@ public SimilarityAlgorithms(
}

FilteredKnnResult filteredKnn(Graph graph, FilteredKnnBaseConfig configuration) {
var taskTree = KnnFactory.knnTaskTree(graph.nodeCount(), configuration.maxIterations());
long nodeCount = graph.nodeCount();

Task task = Tasks.task(
FilteredKNN.value,
Tasks.leaf("Initialize random neighbors", nodeCount),
Tasks.iterativeDynamic(
"Iteration",
() -> List.of(
Tasks.leaf("Split old and new neighbors", nodeCount),
Tasks.leaf("Reverse old and new neighbors", nodeCount),
Tasks.leaf("Join neighbors", nodeCount)
),
configuration.maxIterations()
)
);
var progressTracker = progressTrackerCreator.createProgressTracker(
configuration,
taskTree
task
);
var knnContext = ImmutableKnnContext
.builder()
Expand Down

0 comments on commit 1ee4d03

Please sign in to comment.