Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
Co-authored-by: Lasse Westh-Nielsen <[email protected]>
Co-authored-by: Ioannis Panagiotas <[email protected]>
  • Loading branch information
3 people committed Nov 15, 2023
1 parent 79aabee commit 3141b30
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,15 @@ public <A extends Algorithm<R>, R, C extends AlgoBaseConfig> AlgorithmComputatio
);

// run the algorithm
var algorithmMetric = algorithmMetricsService.create(algorithmFactory.taskName());
var algorithmResult = runAlgorithm(algorithm, algorithmFactory.taskName());
return AlgorithmComputationResult.of(algorithmResult, graph, graphStore, algorithm.getTerminationFlag());
}

<R> R runAlgorithm(Algorithm<R> algorithm, String algorithmName) {
var algorithmMetric = algorithmMetricsService.create(algorithmName);
try(algorithmMetric) {
algorithmMetric.start();
var algorithmResult = algorithm.compute();

return AlgorithmComputationResult.of(algorithmResult, graph, graphStore, algorithm.getTerminationFlag());
return algorithm.compute();
} catch (Exception e) {
log.warn("Computation failed", e);
algorithm.getProgressTracker().endSubTaskWithFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,14 @@
*/
package org.neo4j.gds.algorithms.community;

import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.Test;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.algorithms.AlgorithmMemoryValidationService;
import org.neo4j.gds.algorithms.metrics.AlgorithmMetric;
import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService;
import org.neo4j.gds.api.DatabaseId;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.User;
import org.neo4j.gds.compat.Neo4jProxy;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.loading.GraphStoreCatalogService;
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory;
import org.neo4j.gds.logging.Log;

import java.util.Optional;

import static org.assertj.core.api.Assertions.assertThatException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
Expand All @@ -53,43 +39,24 @@ class BasicAlgorithmRunnerTest {

@Test
void shouldRegisterAlgorithmMetricCountForSuccess() {
var graphMock = mock(Graph.class);
when(graphMock.isEmpty()).thenReturn(false);

var graphStoreCatalogServiceMock = mock(GraphStoreCatalogService.class);
when(graphStoreCatalogServiceMock.getGraphWithGraphStore(any(), any(), any(), any(), any()))
.thenReturn(Pair.of(graphMock, mock(GraphStore.class)));

var algorithmMetricMock = mock(AlgorithmMetric.class);
var algorithmMetricsServiceMock = mock(AlgorithmMetricsService.class);
when(algorithmMetricsServiceMock.create(anyString())).thenReturn(algorithmMetricMock);

var logMock = mock(Log.class);
when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog());

var runner = new BasicAlgorithmRunner(
graphStoreCatalogServiceMock,
TaskRegistryFactory.empty(),
EmptyUserLogRegistryFactory.INSTANCE,
mock(AlgorithmMemoryValidationService.class),
null,
null,
null,
null,
algorithmMetricsServiceMock,
logMock
null
);

var algorithmMock = mock(Algorithm.class, RETURNS_DEEP_STUBS);
var algorithmMock = mock(Algorithm.class);
when(algorithmMock.compute()).thenReturn("WooHoo");
var algorithmFactoryMock = mock(GraphAlgorithmFactory.class);
when(algorithmFactoryMock.taskName()).thenReturn("TestingMetrics");
when(algorithmFactoryMock.build(any(), any(), any(), any(), any())).thenReturn(algorithmMock);

runner.run(
"foo",
mock(AlgoBaseConfig.class),
Optional.empty(),
algorithmFactoryMock,
mock(User.class),
DatabaseId.EMPTY
);


runner.runAlgorithm(algorithmMock, "TestingMetrics");

verify(algorithmMetricsServiceMock, times(1)).create("TestingMetrics");
verify(algorithmMetricMock, times(1)).start();
Expand All @@ -104,13 +71,6 @@ void shouldRegisterAlgorithmMetricCountForSuccess() {

@Test
void shouldRegisterAlgorithmMetricCountForFailure() {
var graphMock = mock(Graph.class);
when(graphMock.isEmpty()).thenReturn(false);

var graphStoreCatalogServiceMock = mock(GraphStoreCatalogService.class);
when(graphStoreCatalogServiceMock.getGraphWithGraphStore(any(), any(), any(), any(), any()))
.thenReturn(Pair.of(graphMock, mock(GraphStore.class)));

var algorithmMetricMock = mock(AlgorithmMetric.class);
var algorithmMetricsServiceMock = mock(AlgorithmMetricsService.class);
when(algorithmMetricsServiceMock.create(anyString())).thenReturn(algorithmMetricMock);
Expand All @@ -119,29 +79,21 @@ void shouldRegisterAlgorithmMetricCountForFailure() {
when(logMock.getNeo4jLog()).thenReturn(Neo4jProxy.testLog());

var runner = new BasicAlgorithmRunner(
graphStoreCatalogServiceMock,
TaskRegistryFactory.empty(),
EmptyUserLogRegistryFactory.INSTANCE,
mock(AlgorithmMemoryValidationService.class),
null,
null,
null,
null,
algorithmMetricsServiceMock,
logMock
);

var algorithmMock = mock(Algorithm.class, RETURNS_DEEP_STUBS);
when(algorithmMock.compute()).thenThrow(new RuntimeException("Ooops"));

var algorithmFactoryMock = mock(GraphAlgorithmFactory.class);
when(algorithmFactoryMock.taskName()).thenReturn("TestingMetrics");
when(algorithmFactoryMock.build(any(), any(), any(), any(), any())).thenReturn(algorithmMock);

assertThatException().isThrownBy(
() -> runner.run(
"foo",
mock(AlgoBaseConfig.class),
Optional.empty(),
algorithmFactoryMock,
mock(User.class),
DatabaseId.EMPTY
() -> runner.runAlgorithm(
algorithmMock,
"TestingMetrics"
)
).withMessage("Ooops");

Expand Down

0 comments on commit 3141b30

Please sign in to comment.