Skip to content

Commit

Permalink
Remove test-only factories with defaults
Browse files Browse the repository at this point in the history
Co-Authored-By: Veselin Nikolov <[email protected]>
  • Loading branch information
jjaderberg and vnickolov committed Oct 31, 2023
1 parent 9cb8785 commit 28b8734
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 38 deletions.
30 changes: 0 additions & 30 deletions algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,36 +62,6 @@ public static MemoryEstimation memoryEstimation(int walksPerNode, int walkLength
.build();
}

static Node2Vec create(
Graph graph,
int concurrency,
WalkParameters walkParameters,
TrainParameters trainParameters,
ProgressTracker progressTracker
) {
return create(graph, concurrency, Optional.empty(), walkParameters, trainParameters, progressTracker);
}

static Node2Vec create(
Graph graph,
int concurrency,
Optional<Long> maybeRandomSeed,
WalkParameters walkParameters,
TrainParameters trainParameters,
ProgressTracker progressTracker
) {
return new Node2Vec(
graph,
concurrency,
List.of(),
maybeRandomSeed,
1000,
walkParameters,
trainParameters,
progressTracker
);
}

public Node2Vec(
Graph graph,
int concurrency,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@
@ExtendWith(SoftAssertionsExtension.class)
class Node2VecTest extends BaseTest {

private static final List<Long> NO_SOURCE_NODES = List.of();
private static final Optional<Long> NO_RANDOM_SEED = Optional.empty();

private static final String DB_CYPHER =
"CREATE" +
" (a:Node1)" +
Expand Down Expand Up @@ -109,9 +112,12 @@ void embeddingsShouldHaveTheConfiguredDimension(String msg, Iterable<String> nod
embeddingDimension,
EmbeddingInitializer.NORMALIZED
);
HugeObjectArray<FloatVector> node2Vec = Node2Vec.create(
HugeObjectArray<FloatVector> node2Vec = new Node2Vec(
graph,
4,
NO_SOURCE_NODES,
NO_RANDOM_SEED,
1000,
new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75),
trainParameters,
ProgressTracker.NULL_TRACKER
Expand Down Expand Up @@ -156,10 +162,12 @@ void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) {
);
var log = Neo4jProxy.testLog();
var progressTracker = new TestProgressTracker(progressTask, log, 4, EmptyTaskRegistryFactory.INSTANCE);
Node2Vec.create(
new Node2Vec(
graph,
4,
Optional.empty(),
NO_SOURCE_NODES,
NO_RANDOM_SEED,
1000,
walkParameters,
trainParameters,
progressTracker
Expand Down Expand Up @@ -223,9 +231,12 @@ void failOnNegativeWeights() {
var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75);
var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, 128, EmbeddingInitializer.NORMALIZED);

var node2Vec = Node2Vec.create(
var node2Vec = new Node2Vec(
graph,
4,
NO_SOURCE_NODES,
NO_RANDOM_SEED,
1000,
walkParameters,
trainParameters,
ProgressTracker.NULL_TRACKER
Expand All @@ -248,19 +259,23 @@ void randomSeed(SoftAssertions softly) {
var walkParameters = new WalkParameters(1, 20, 1.0, 1.0, 0.001, 0.75);
var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, embeddingDimension, EmbeddingInitializer.NORMALIZED);

var embeddings = Node2Vec.create(
var embeddings = new Node2Vec(
graph,
4,
NO_SOURCE_NODES,
Optional.of(1337L),
1000,
walkParameters,
trainParameters,
ProgressTracker.NULL_TRACKER
).compute().embeddings();

var otherEmbeddings = Node2Vec.create(
var otherEmbeddings = new Node2Vec(
graph,
4,
NO_SOURCE_NODES,
Optional.of(1337L),
1000,
walkParameters,
trainParameters,
ProgressTracker.NULL_TRACKER
Expand Down Expand Up @@ -348,19 +363,23 @@ void shouldBeFairlyConsistentUnderOriginalIds(EmbeddingInitializer embeddingInit
var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.01, 0.75);
var trainParameters = new TrainParameters(0.025, 0.0001, 1, 10, 5, embeddingDimension, embeddingInitializer);

var firstEmbeddings = Node2Vec.create(
var firstEmbeddings = new Node2Vec(
firstGraph,
4,
NO_SOURCE_NODES,
Optional.of(1337L),
1000,
walkParameters,
trainParameters,
ProgressTracker.NULL_TRACKER
).compute().embeddings();

var secondEmbeddings = Node2Vec.create(
var secondEmbeddings = new Node2Vec(
secondGraph,
4,
NO_SOURCE_NODES,
Optional.of(1337L),
1000,
walkParameters,
trainParameters,
ProgressTracker.NULL_TRACKER
Expand Down

0 comments on commit 28b8734

Please sign in to comment.