diff --git a/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecAlgorithmFactory.java b/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecAlgorithmFactory.java deleted file mode 100644 index f70b3d4f47..0000000000 --- a/algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecAlgorithmFactory.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.embeddings.node2vec; - -import org.neo4j.gds.GraphAlgorithmFactory; -import org.neo4j.gds.api.Graph; -import org.neo4j.gds.mem.MemoryEstimation; -import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; -import org.neo4j.gds.core.utils.progress.tasks.Task; -import org.neo4j.gds.core.utils.progress.tasks.Tasks; -import org.neo4j.gds.degree.DegreeCentralityFactory; -import org.neo4j.gds.termination.TerminationFlag; - -import java.util.ArrayList; -import java.util.List; - -import static java.lang.Math.multiplyExact; -import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; - -public class Node2VecAlgorithmFactory extends GraphAlgorithmFactory { - - @Override - public String taskName() { - return "Node2Vec"; - } - - @Override - public Node2Vec build( - Graph graph, - CONFIG configuration, - ProgressTracker progressTracker - ) { - validateConfig(configuration, graph); - return new Node2Vec( - graph, - configuration.concurrency(), - configuration.sourceNodes(), - configuration.randomSeed(), - configuration.walkBufferSize(), - Node2VecConfigTransformer.node2VecParameters(configuration), - progressTracker, - TerminationFlag.RUNNING_TRUE - ); - } - - @Override - public MemoryEstimation memoryEstimation(CONFIG configuration) { - return new Node2VecMemoryEstimateDefinition(Node2VecConfigTransformer.node2VecParameters(configuration)).memoryEstimation(); - } - - @Override - public Task progressTask(Graph graph, CONFIG config) { - var randomWalkTasks = new ArrayList(); - if (graph.hasRelationshipProperty()) { - randomWalkTasks.add(DegreeCentralityFactory.degreeCentralityProgressTask(graph)); - } - randomWalkTasks.add(Tasks.leaf("create walks", graph.nodeCount())); - - return Tasks.task( - taskName(), - Tasks.task("RandomWalk", randomWalkTasks), - Tasks.iterativeFixed( - "train", - () -> List.of(Tasks.leaf("iteration")), - config.iterations() - ) - ); - } - - private void validateConfig(CONFIG config, Graph graph) { - try { - var ignored = multiplyExact( - multiplyExact(graph.nodeCount(), config.walksPerNode()), - config.walkLength() - ); - } catch (ArithmeticException ex) { - throw new IllegalArgumentException( - formatWithLocale( - "Aborting execution, running with the configured parameters is likely to overflow: node count: %d, walks per node: %d, walkLength: %d." + - " Try reducing these parameters or run on a smaller graph.", - graph.nodeCount(), - config.walksPerNode(), - config.walkLength() - )); - } - } -} diff --git a/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecAlgorithmFactoryTest.java b/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecAlgorithmFactoryTest.java deleted file mode 100644 index c876c76859..0000000000 --- a/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecAlgorithmFactoryTest.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.embeddings.node2vec; - -import org.junit.jupiter.api.Test; -import org.neo4j.gds.api.Graph; -import org.neo4j.gds.core.CypherMapWrapper; -import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; -import org.neo4j.gds.extension.GdlExtension; -import org.neo4j.gds.extension.GdlGraph; -import org.neo4j.gds.extension.Inject; - -import java.util.Map; - -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; - -@GdlExtension -class Node2VecAlgorithmFactoryTest { - - @GdlGraph - private static final String CYPHER = - "CREATE" + - " (:N)" + - ", (:N)" + - ", (:N)" + - ", (:N)" + - ", (:N)"; - - - @Inject - private Graph graph; - - @Test - void shouldThrowIfRunningWouldOverflow() { - - var config = Node2VecStreamConfig.of(CypherMapWrapper.create( - Map.of( - "writeProperty", "embedding", - "walksPerNode", Integer.MAX_VALUE, - "walkLength", Integer.MAX_VALUE, - "sudo", true - ) - )); - - var factory = new Node2VecAlgorithmFactory<>(); - - String expectedMessage = formatWithLocale( - "Aborting execution, running with the configured parameters is likely to overflow: node count: %d, walks per node: %d, walkLength: %d." + - " Try reducing these parameters or run on a smaller graph.", - graph.nodeCount(), - Integer.MAX_VALUE, - Integer.MAX_VALUE - ); - - assertThatIllegalArgumentException() - .isThrownBy(() -> factory.build(graph, config, ProgressTracker.NULL_TRACKER)) - .withMessage(expectedMessage); - } -} diff --git a/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java b/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java index 18bea4e04d..989c5079c2 100644 --- a/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java +++ b/algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java @@ -25,20 +25,17 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.neo4j.gds.NodeLabel; import org.neo4j.gds.Orientation; import org.neo4j.gds.RelationshipType; -import org.neo4j.gds.TestProgressTracker; import org.neo4j.gds.api.Graph; import org.neo4j.gds.api.GraphStore; import org.neo4j.gds.collections.ha.HugeLongArray; import org.neo4j.gds.collections.ha.HugeObjectArray; import org.neo4j.gds.collections.hsa.HugeSparseLongArray; -import org.neo4j.gds.compat.TestLog; import org.neo4j.gds.core.concurrency.Concurrency; import org.neo4j.gds.core.concurrency.DefaultPool; import org.neo4j.gds.core.loading.ArrayIdMap; @@ -46,14 +43,12 @@ import org.neo4j.gds.core.loading.construction.GraphFactory; import org.neo4j.gds.core.loading.construction.RelationshipsBuilder; import org.neo4j.gds.core.utils.Intersections; -import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.core.utils.shuffle.ShuffleUtil; import org.neo4j.gds.extension.GdlExtension; import org.neo4j.gds.extension.GdlGraph; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.gdl.GdlFactory; -import org.neo4j.gds.logging.GdsTestLog; import org.neo4j.gds.ml.core.tensor.FloatVector; import org.neo4j.gds.termination.TerminationFlag; @@ -65,7 +60,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.neo4j.gds.assertj.Extractors.removingThreadId; @ExtendWith(SoftAssertionsExtension.class) @GdlExtension @@ -133,77 +127,6 @@ void embeddingsShouldHaveTheConfiguredDimension(String msg, List node ); } - @ParameterizedTest - @CsvSource(value = { - "true,4", - "false,3" - }) - void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) { - Graph currentGraph; - if (relationshipWeights) { - currentGraph = graph; - } else { - currentGraph = graphStore.getGraph(RelationshipType.of("REL"), Optional.empty()); - } - - int embeddingDimension = 128; - Node2VecStreamConfig config = Node2VecStreamConfigImpl - .builder() - .embeddingDimension(embeddingDimension) - .build(); - var progressTask = new Node2VecAlgorithmFactory<>().progressTask(currentGraph, config); - - var walkParameters = new SamplingWalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75); - var trainParameters = new TrainParameters( - 0.025, - 0.0001, - 1, - 10, - 5, - embeddingDimension, - EmbeddingInitializer.NORMALIZED - ); - var concurrency = new Concurrency(4); - var log = new GdsTestLog(); - var progressTracker = new TestProgressTracker(progressTask, log, concurrency, EmptyTaskRegistryFactory.INSTANCE); - new Node2Vec( - currentGraph, - concurrency, - NO_SOURCE_NODES, - NO_RANDOM_SEED, - 1000, - new Node2VecParameters(walkParameters, trainParameters), - progressTracker, - TerminationFlag.RUNNING_TRUE - ).compute(); - - assertThat(log.getMessages(TestLog.INFO)) - .extracting(removingThreadId()) - .contains( - "Node2Vec :: Start", - "Node2Vec :: RandomWalk :: Start", - "Node2Vec :: RandomWalk :: create walks :: Start", - "Node2Vec :: RandomWalk :: create walks 100%", - "Node2Vec :: RandomWalk :: create walks :: Finished", - "Node2Vec :: RandomWalk :: Finished", - "Node2Vec :: train :: Start", - "Node2Vec :: train :: iteration 1 of 1 :: Start", - "Node2Vec :: train :: iteration 1 of 1 100%", - "Node2Vec :: train :: iteration 1 of 1 :: Finished", - "Node2Vec :: train :: Finished", - "Node2Vec :: Finished" - ); - - if (relationshipWeights) { - assertThat(log.getMessages(TestLog.INFO)) - .extracting(removingThreadId()) - .contains( - "Node2Vec :: RandomWalk :: DegreeCentrality :: Start", - "Node2Vec :: RandomWalk :: DegreeCentrality :: Finished" - ); - } - } - @Test void failOnNegativeWeights() { var negativeGraph = GdlFactory.of("CREATE (a)-[:REL {weight: -1}]->(b)").build().getUnion(); diff --git a/algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/node2vec/Constants.java b/algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/node2vec/Constants.java deleted file mode 100644 index 41ace00e66..0000000000 --- a/algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/node2vec/Constants.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.embeddings.node2vec; - -final class Constants { - static final String NODE2VEC_DESCRIPTION = "The Node2Vec algorithm computes embeddings for nodes based on random walks."; - - private Constants() {} -} diff --git a/algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecMutateSpec.java b/algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecMutateSpec.java deleted file mode 100644 index 631cf0da3a..0000000000 --- a/algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecMutateSpec.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.embeddings.node2vec; - -import org.neo4j.gds.NullComputationResultConsumer; -import org.neo4j.gds.executor.AlgorithmSpec; -import org.neo4j.gds.executor.ComputationResultConsumer; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.GdsCallable; -import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction; -import org.neo4j.gds.procedures.algorithms.embeddings.Node2VecMutateResult; - -import java.util.stream.Stream; - -import static org.neo4j.gds.executor.ExecutionMode.MUTATE_NODE_PROPERTY; - -@GdsCallable( - name = "gds.node2vec.mutate", - aliases = "gds.beta.node2vec.mutate", - description = Constants.NODE2VEC_DESCRIPTION, - executionMode = MUTATE_NODE_PROPERTY -) -public class Node2VecMutateSpec implements AlgorithmSpec, Node2VecAlgorithmFactory> { - @Override - public String name() { - return "Node2VecMutate"; - } - - @Override - public Node2VecAlgorithmFactory algorithmFactory(ExecutionContext executionContext) { - return new Node2VecAlgorithmFactory<>(); - } - - @Override - public NewConfigFunction newConfigFunction() { - return (__, userInput) -> Node2VecMutateConfig.of(userInput); - } - - @Override - public ComputationResultConsumer> computationResultConsumer() { - return new NullComputationResultConsumer<>(); - } -} diff --git a/algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecStreamSpec.java b/algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecStreamSpec.java deleted file mode 100644 index 5e36df3365..0000000000 --- a/algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecStreamSpec.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.embeddings.node2vec; - -import org.neo4j.gds.NullComputationResultConsumer; -import org.neo4j.gds.executor.AlgorithmSpec; -import org.neo4j.gds.executor.ComputationResultConsumer; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.GdsCallable; -import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction; -import org.neo4j.gds.procedures.algorithms.embeddings.Node2VecStreamResult; - -import java.util.stream.Stream; - -import static org.neo4j.gds.executor.ExecutionMode.STREAM; - -@GdsCallable( - name = "gds.node2vec.stream", - aliases = "gds.beta.node2vec.stream", - description = Constants.NODE2VEC_DESCRIPTION, - executionMode = STREAM -) -public class Node2VecStreamSpec implements AlgorithmSpec, Node2VecAlgorithmFactory> { - @Override - public String name() { - return "Node2VecStream"; - } - - @Override - public Node2VecAlgorithmFactory algorithmFactory(ExecutionContext executionContext) { - return new Node2VecAlgorithmFactory<>(); - } - - @Override - public NewConfigFunction newConfigFunction() { - return (__, userInput) -> Node2VecStreamConfig.of(userInput); - } - - @Override - public ComputationResultConsumer> computationResultConsumer() { - return new NullComputationResultConsumer<>(); - } -} diff --git a/algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecWriteSpec.java b/algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecWriteSpec.java deleted file mode 100644 index 62c48d601e..0000000000 --- a/algorithm-specifications/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecWriteSpec.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.embeddings.node2vec; - -import org.neo4j.gds.NullComputationResultConsumer; -import org.neo4j.gds.executor.AlgorithmSpec; -import org.neo4j.gds.executor.ComputationResultConsumer; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.GdsCallable; -import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction; -import org.neo4j.gds.procedures.algorithms.embeddings.Node2VecWriteResult; - -import java.util.stream.Stream; - -import static org.neo4j.gds.executor.ExecutionMode.WRITE_NODE_PROPERTY; - -@GdsCallable( - name = "gds.node2vec.write", - aliases = "gds.beta.node2vec.write", - description = Constants.NODE2VEC_DESCRIPTION, - executionMode = WRITE_NODE_PROPERTY -) -public class Node2VecWriteSpec implements AlgorithmSpec, Node2VecAlgorithmFactory> { - @Override - public String name() { - return "Node2VecWrite"; - } - - @Override - public Node2VecAlgorithmFactory algorithmFactory(ExecutionContext executionContext) { - return new Node2VecAlgorithmFactory<>(); - } - - @Override - public NewConfigFunction newConfigFunction() { - return (__, userInput) -> Node2VecWriteConfig.of(userInput); - } - - @Override - public ComputationResultConsumer> computationResultConsumer() { - return new NullComputationResultConsumer<>(); - } -} diff --git a/applications/algorithms/node-embeddings/build.gradle b/applications/algorithms/node-embeddings/build.gradle index 58bc84b6a0..25f77f768e 100644 --- a/applications/algorithms/node-embeddings/build.gradle +++ b/applications/algorithms/node-embeddings/build.gradle @@ -18,9 +18,18 @@ dependencies { implementation project(":model-catalog-api") implementation project(":model-catalog-applications") implementation project(":node-embeddings-configs") - // TODO: the `:path-finding-configs` is needed because it's indirectly used by `node2vec` parameters. implementation project(':path-finding-configs') implementation project(":progress-tracking") implementation project(":string-formatting") implementation project(":termination") + + testImplementation platform(openGds.junit5bom) + testImplementation openGds.junit5.jupiter.api + testImplementation openGds.junit5.jupiter.params + testImplementation openGds.mockito.junit.jupiter + testImplementation openGds.assertj.core + + testRuntimeOnly openGds.junit5.jupiter.engine + + testImplementation project(':test-utils') } diff --git a/applications/algorithms/node-embeddings/src/main/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithmsWriteModeBusinessFacade.java b/applications/algorithms/node-embeddings/src/main/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithmsWriteModeBusinessFacade.java index c49f6b008c..456cde529c 100644 --- a/applications/algorithms/node-embeddings/src/main/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithmsWriteModeBusinessFacade.java +++ b/applications/algorithms/node-embeddings/src/main/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithmsWriteModeBusinessFacade.java @@ -48,7 +48,7 @@ public final class NodeEmbeddingAlgorithmsWriteModeBusinessFacade { private final WriteToDatabase writeToDatabase; private final GraphSageAlgorithmProcessing graphSageAlgorithmProcessing; - private NodeEmbeddingAlgorithmsWriteModeBusinessFacade( + NodeEmbeddingAlgorithmsWriteModeBusinessFacade( NodeEmbeddingAlgorithmsEstimationModeBusinessFacade estimationFacade, NodeEmbeddingAlgorithms algorithms, AlgorithmProcessingTemplateConvenience algorithmProcessingTemplateConvenience, diff --git a/applications/algorithms/node-embeddings/src/test/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithmsBusinessFacadeTest.java b/applications/algorithms/node-embeddings/src/test/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithmsBusinessFacadeTest.java new file mode 100644 index 0000000000..59487e9338 --- /dev/null +++ b/applications/algorithms/node-embeddings/src/test/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithmsBusinessFacadeTest.java @@ -0,0 +1,111 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.applications.algorithms.embeddings; + +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatcher; +import org.mockito.internal.progress.ThreadSafeMockingProgress; +import org.neo4j.gds.applications.algorithms.machinery.AlgorithmProcessingTemplateConvenience; +import org.neo4j.gds.core.loading.PostLoadValidationHook; + +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +class NodeEmbeddingAlgorithmsBusinessFacadeTest { + @Test + void shouldValidateNode2VecDataSizeInMutateMode() { + var processingTemplate = mock(AlgorithmProcessingTemplateConvenience.class); + var facade = new NodeEmbeddingAlgorithmsMutateModeBusinessFacade(null, null, processingTemplate, null, null); + + facade.node2Vec(null, null, null); + + verify(processingTemplate).processAlgorithmInMutateMode( + any(), + any(), + any(), + node2VecValidationHook(), + any(), + any(), + any(), + any(), + any(), + any() + ); + } + + @Test + void shouldValidateNode2VecDataSizeInStreamMode() { + var processingTemplate = mock(AlgorithmProcessingTemplateConvenience.class); + var facade = new NodeEmbeddingAlgorithmsStreamModeBusinessFacade(null, null, processingTemplate, null); + + facade.node2Vec(null, null, null); + + verify(processingTemplate).processAlgorithmInStreamMode( + any(), + any(), + any(), + any(), + any(), + any(), + node2VecValidationHook(), + any(), + any() + ); + } + + @Test + void shouldValidateNode2VecDataSizeInWriteMode() { + var processingTemplate = mock(AlgorithmProcessingTemplateConvenience.class); + var facade = new NodeEmbeddingAlgorithmsWriteModeBusinessFacade(null, null, processingTemplate, null, null); + + facade.node2Vec(null, null, null); + + verify(processingTemplate).processAlgorithmInWriteMode( + any(), + any(), + any(), + node2VecValidationHook(), + any(), + any(), + any(), + any(), + any(), + any() + ); + } + + /** + * Verify that you did indeed stack on the right validation hook + */ + private static Optional> node2VecValidationHook() { + ThreadSafeMockingProgress.mockingProgress() + .getArgumentMatcherStorage() + .reportMatcher((ArgumentMatcher>>) argument -> { + assertThat(argument.orElseThrow()).singleElement().isInstanceOf(Node2VecValidationHook.class); + return true; + }); + //noinspection DataFlowIssue,OptionalAssignedToNull + return null; + } +} diff --git a/applications/algorithms/node-embeddings/src/test/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithmsTest.java b/applications/algorithms/node-embeddings/src/test/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithmsTest.java new file mode 100644 index 0000000000..2d7b5650d4 --- /dev/null +++ b/applications/algorithms/node-embeddings/src/test/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithmsTest.java @@ -0,0 +1,133 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.applications.algorithms.embeddings; + +import org.junit.jupiter.api.Test; +import org.neo4j.gds.RelationshipType; +import org.neo4j.gds.api.Graph; +import org.neo4j.gds.api.GraphStore; +import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator; +import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies; +import org.neo4j.gds.compat.TestLog; +import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory; +import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory; +import org.neo4j.gds.embeddings.node2vec.Node2VecStreamConfigImpl; +import org.neo4j.gds.extension.GdlExtension; +import org.neo4j.gds.extension.GdlGraph; +import org.neo4j.gds.extension.Inject; +import org.neo4j.gds.logging.GdsTestLog; + +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.neo4j.gds.assertj.Extractors.removingThreadId; + +@GdlExtension +class NodeEmbeddingAlgorithmsTest { + @SuppressWarnings("unused") + @GdlGraph + private static final String DB_CYPHER = + "CREATE" + + " (a:Node1)" + + ", (b:Node1)" + + ", (c:Node2)" + + ", (d:Isolated)" + + ", (e:Isolated)" + + ", (a)-[:REL {prop: 1.0}]->(b)" + + ", (b)-[:REL {prop: 1.0}]->(a)" + + ", (a)-[:REL {prop: 1.0}]->(c)" + + ", (c)-[:REL {prop: 1.0}]->(a)" + + ", (b)-[:REL {prop: 1.0}]->(c)" + + ", (c)-[:REL {prop: 1.0}]->(b)"; + + @SuppressWarnings("unused") + @Inject + private Graph graph; + + @SuppressWarnings("unused") + @Inject + private GraphStore graphStore; + + @Test + void shouldLogProgressForNode2Vec() { + var log = new GdsTestLog(); + var requestScopedDependencies = RequestScopedDependencies.builder() + .with(EmptyTaskRegistryFactory.INSTANCE) + .with(EmptyUserLogRegistryFactory.INSTANCE) + .build(); + var progressTrackerCreator = new ProgressTrackerCreator(log, requestScopedDependencies); + var nodeEmbeddingAlgorithms = new NodeEmbeddingAlgorithms(null, progressTrackerCreator, null); + + var configuration = Node2VecStreamConfigImpl.builder().embeddingDimension(128).build(); + + var graph = graphStore.getGraph(RelationshipType.of("REL"), Optional.empty()); + nodeEmbeddingAlgorithms.node2Vec(graph, configuration); + + assertThat(log.getMessages(TestLog.INFO)) + .extracting(removingThreadId()) + .contains( + "Node2Vec :: Start", + "Node2Vec :: RandomWalk :: Start", + "Node2Vec :: RandomWalk :: create walks :: Start", + "Node2Vec :: RandomWalk :: create walks 100%", + "Node2Vec :: RandomWalk :: create walks :: Finished", + "Node2Vec :: RandomWalk :: Finished", + "Node2Vec :: train :: Start", + "Node2Vec :: train :: iteration 1 of 1 :: Start", + "Node2Vec :: train :: iteration 1 of 1 100%", + "Node2Vec :: train :: iteration 1 of 1 :: Finished", + "Node2Vec :: train :: Finished", + "Node2Vec :: Finished" + ); + } + + @Test + void shouldLogProgressForNode2VecWithRelationshipWeights() { + var log = new GdsTestLog(); + var requestScopedDependencies = RequestScopedDependencies.builder() + .with(EmptyTaskRegistryFactory.INSTANCE) + .with(EmptyUserLogRegistryFactory.INSTANCE) + .build(); + var progressTrackerCreator = new ProgressTrackerCreator(log, requestScopedDependencies); + var nodeEmbeddingAlgorithms = new NodeEmbeddingAlgorithms(null, progressTrackerCreator, null); + + var configuration = Node2VecStreamConfigImpl.builder().embeddingDimension(128).build(); + nodeEmbeddingAlgorithms.node2Vec(graph, configuration); + + assertThat(log.getMessages(TestLog.INFO)) + .extracting(removingThreadId()) + .contains( + "Node2Vec :: Start", + "Node2Vec :: RandomWalk :: Start", + "Node2Vec :: RandomWalk :: DegreeCentrality :: Start", + "Node2Vec :: RandomWalk :: DegreeCentrality :: Finished", + "Node2Vec :: RandomWalk :: create walks :: Start", + "Node2Vec :: RandomWalk :: create walks 100%", + "Node2Vec :: RandomWalk :: create walks :: Finished", + "Node2Vec :: RandomWalk :: Finished", + "Node2Vec :: train :: Start", + "Node2Vec :: train :: iteration 1 of 1 :: Start", + "Node2Vec :: train :: iteration 1 of 1 100%", + "Node2Vec :: train :: iteration 1 of 1 :: Finished", + "Node2Vec :: train :: Finished", + "Node2Vec :: Finished" + ); + } +}