diff --git a/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java b/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java index 03cd9efa92..b19bccfe77 100644 --- a/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java +++ b/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java @@ -62,36 +62,6 @@ public static MemoryEstimation memoryEstimation(int walksPerNode, int walkLength .build(); } - static Node2Vec create( - Graph graph, - int concurrency, - WalkParameters walkParameters, - TrainParameters trainParameters, - ProgressTracker progressTracker - ) { - return create(graph, concurrency, Optional.empty(), walkParameters, trainParameters, progressTracker); - } - - static Node2Vec create( - Graph graph, - int concurrency, - Optional maybeRandomSeed, - WalkParameters walkParameters, - TrainParameters trainParameters, - ProgressTracker progressTracker - ) { - return new Node2Vec( - graph, - concurrency, - List.of(), - maybeRandomSeed, - 1000, - walkParameters, - trainParameters, - progressTracker - ); - } - public Node2Vec( Graph graph, int concurrency, diff --git a/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java b/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java index 7e9716e5b0..78df6d5465 100644 --- a/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java +++ b/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java @@ -71,6 +71,9 @@ @ExtendWith(SoftAssertionsExtension.class) class Node2VecTest extends BaseTest { + private static final List NO_SOURCE_NODES = List.of(); + private static final Optional NO_RANDOM_SEED = Optional.empty(); + private static final String DB_CYPHER = "CREATE" + " (a:Node1)" + @@ -109,9 +112,12 @@ void embeddingsShouldHaveTheConfiguredDimension(String msg, Iterable nod embeddingDimension, EmbeddingInitializer.NORMALIZED ); - HugeObjectArray node2Vec = Node2Vec.create( + HugeObjectArray node2Vec = new Node2Vec( graph, 4, + NO_SOURCE_NODES, + NO_RANDOM_SEED, + 1000, new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75), trainParameters, ProgressTracker.NULL_TRACKER @@ -156,10 +162,12 @@ void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) { ); var log = Neo4jProxy.testLog(); var progressTracker = new TestProgressTracker(progressTask, log, 4, EmptyTaskRegistryFactory.INSTANCE); - Node2Vec.create( + new Node2Vec( graph, 4, - Optional.empty(), + NO_SOURCE_NODES, + NO_RANDOM_SEED, + 1000, walkParameters, trainParameters, progressTracker @@ -223,9 +231,12 @@ void failOnNegativeWeights() { var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75); var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, 128, EmbeddingInitializer.NORMALIZED); - var node2Vec = Node2Vec.create( + var node2Vec = new Node2Vec( graph, 4, + NO_SOURCE_NODES, + NO_RANDOM_SEED, + 1000, walkParameters, trainParameters, ProgressTracker.NULL_TRACKER @@ -248,19 +259,23 @@ void randomSeed(SoftAssertions softly) { var walkParameters = new WalkParameters(1, 20, 1.0, 1.0, 0.001, 0.75); var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, embeddingDimension, EmbeddingInitializer.NORMALIZED); - var embeddings = Node2Vec.create( + var embeddings = new Node2Vec( graph, 4, + NO_SOURCE_NODES, Optional.of(1337L), + 1000, walkParameters, trainParameters, ProgressTracker.NULL_TRACKER ).compute().embeddings(); - var otherEmbeddings = Node2Vec.create( + var otherEmbeddings = new Node2Vec( graph, 4, + NO_SOURCE_NODES, Optional.of(1337L), + 1000, walkParameters, trainParameters, ProgressTracker.NULL_TRACKER @@ -348,19 +363,23 @@ void shouldBeFairlyConsistentUnderOriginalIds(EmbeddingInitializer embeddingInit var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.01, 0.75); var trainParameters = new TrainParameters(0.025, 0.0001, 1, 10, 5, embeddingDimension, embeddingInitializer); - var firstEmbeddings = Node2Vec.create( + var firstEmbeddings = new Node2Vec( firstGraph, 4, + NO_SOURCE_NODES, Optional.of(1337L), + 1000, walkParameters, trainParameters, ProgressTracker.NULL_TRACKER ).compute().embeddings(); - var secondEmbeddings = Node2Vec.create( + var secondEmbeddings = new Node2Vec( secondGraph, 4, + NO_SOURCE_NODES, Optional.of(1337L), + 1000, walkParameters, trainParameters, ProgressTracker.NULL_TRACKER