Skip to content

Commit

Permalink
Expose algorithm-metrics-api as transitive of proc-common
Browse files Browse the repository at this point in the history
Co-authored-by: Ioannis Panagiotas <[email protected]>
  • Loading branch information
vnickolov and IoannisPanagiotas committed Nov 15, 2023
1 parent 759750d commit 79aabee
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.neo4j.gds.NodeProjections;
import org.neo4j.gds.ProcedureCallContextReturnColumns;
import org.neo4j.gds.RelationshipProjections;
import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService;
import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar;
import org.neo4j.gds.api.AlgorithmMetaDataSetter;
import org.neo4j.gds.api.CloseableResourceRegistry;
import org.neo4j.gds.api.DatabaseId;
Expand Down Expand Up @@ -85,6 +87,7 @@ void setup() throws Exception {
.nodeLookup(NodeLookup.EMPTY)
.modelCatalog(ModelCatalog.EMPTY)
.isGdsAdmin(false)
.algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()))
.build();

memoryEstimationExecutor = new MemoryEstimationExecutor<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.junit.jupiter.api.Test;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService;
import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar;
import org.neo4j.gds.api.AlgorithmMetaDataSetter;
import org.neo4j.gds.api.CloseableResourceRegistry;
import org.neo4j.gds.api.DatabaseId;
Expand Down Expand Up @@ -195,6 +197,7 @@ void deriveRelationshipWeightProperty() {
.taskRegistryFactory(EmptyTaskRegistryFactory.INSTANCE)
.userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE)
.isGdsAdmin(false)
.algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()))
.build();

var pipeline = new LinkPredictionTrainingPipeline();
Expand Down Expand Up @@ -239,6 +242,7 @@ void deriveRelationshipWeightPropertyFromTrainedModel() {
.taskRegistryFactory(EmptyTaskRegistryFactory.INSTANCE)
.userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE)
.isGdsAdmin(false)
.algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()))
.build();

var pipeline = new LinkPredictionTrainingPipeline();
Expand Down Expand Up @@ -283,6 +287,7 @@ void notDerivePropertyFromUnweightedTrainedModel() {
.taskRegistryFactory(EmptyTaskRegistryFactory.INSTANCE)
.userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE)
.isGdsAdmin(false)
.algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()))
.build();

var pipeline = new LinkPredictionTrainingPipeline();
Expand Down
2 changes: 1 addition & 1 deletion proc/common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ dependencies {
annotationProcessor group: 'org.immutables', name: 'value', version: ver.'immutables'

api(project(':algo'))
api project(':algorithm-metrics-api')
api(project(':model-catalog-api'))

implementation project(':annotations')
implementation project(':algo-common')
implementation project(':algorithm-metrics-api')
implementation project(':config-api')
implementation project(':core')
implementation project(':core-write')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService;
import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar;
import org.neo4j.gds.api.AlgorithmMetaDataSetter;
import org.neo4j.gds.api.CSRGraph;
import org.neo4j.gds.api.CloseableResourceRegistry;
Expand Down Expand Up @@ -88,6 +90,7 @@ class MutatePropertyComputationResultConsumerTest {
.nodeLookup(NodeLookup.EMPTY)
.modelCatalog(ModelCatalog.EMPTY)
.isGdsAdmin(false)
.algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()))
.build();

@BeforeEach
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
package org.neo4j.gds;

import org.junit.jupiter.api.Test;
import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService;
import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar;
import org.neo4j.gds.api.AlgorithmMetaDataSetter;
import org.neo4j.gds.api.CloseableResourceRegistry;
import org.neo4j.gds.api.DatabaseId;
Expand Down Expand Up @@ -96,6 +98,7 @@ class WriteNodePropertiesComputationResultConsumerTest extends BaseTest {
.modelCatalog(ModelCatalog.EMPTY)
.isGdsAdmin(false)
.nodePropertyExporterBuilder(new NativeNodePropertiesExporterBuilder(EmptyTransactionContext.INSTANCE))
.algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()))
.build();

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
package org.neo4j.gds;

import org.junit.jupiter.api.Test;
import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService;
import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar;
import org.neo4j.gds.api.AlgorithmMetaDataSetter;
import org.neo4j.gds.api.CloseableResourceRegistry;
import org.neo4j.gds.api.DatabaseId;
Expand Down Expand Up @@ -124,6 +126,7 @@ public long nodeCount() {
.modelCatalog(ModelCatalog.EMPTY)
.isGdsAdmin(false)
.nodePropertyExporterBuilder(new NativeNodePropertiesExporterBuilder(DatabaseTransactionContext.of(db, tx)))
.algorithmMetricsService(new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()))
.build();

assertThatThrownBy(() -> resultConsumer.consume(computationResult, executionContext))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class DfsStreamComputationResultConsumerTest {
void shouldNotComputePath() {
when(graphMock.toOriginalNodeId(anyLong())).then(returnsFirstArg());


when(computationResultMock.graph()).thenReturn(graphMock);
when(computationResultMock.result()).thenReturn(Optional.of(HugeLongArray.of(1L, 2L)));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import org.neo4j.gds.GdsCypher;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.TestTaskStore;
import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService;
import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.nodeproperties.ValueType;
import org.neo4j.gds.assertj.ConditionFactory;
Expand Down Expand Up @@ -209,6 +211,9 @@ void cleanupTaskRegistryWhenTheAlgorithmFailsInStreamMode() {
proc.procedureTransaction = transactions.tx();
proc.log = NullLog.getInstance();
proc.callContext = ProcedureCallContext.EMPTY;

proc.algorithmMetricsService = new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar());

Map<String, Object> config = Map.of(
"maxIterations", 20,
"throwInCompute", true
Expand All @@ -234,6 +239,9 @@ void cleanupTaskRegistryWhenTheAlgorithmFailsInWriteMode() {
proc.procedureTransaction = transactions.tx();
proc.log = NullLog.getInstance();
proc.callContext = ProcedureCallContext.EMPTY;

proc.algorithmMetricsService = new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar());

Map<String, Object> config = Map.of(
"maxIterations", 20,
"throwInCompute", true
Expand All @@ -258,6 +266,9 @@ void cleanupTaskRegistryWhenTheAlgorithmFailsInMutateMode() {
proc.procedureTransaction = transactions.tx();
proc.log = NullLog.getInstance();
proc.callContext = ProcedureCallContext.EMPTY;

proc.algorithmMetricsService = new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar());

Map<String, Object> config = Map.of(
"maxIterations", 20,
"throwInCompute", true
Expand Down
10 changes: 8 additions & 2 deletions proc/test/src/main/java/org/neo4j/gds/ProcedureRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
*/
package org.neo4j.gds;

import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService;
import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar;
import org.neo4j.gds.compat.GraphDatabaseApiProxy;
import org.neo4j.gds.core.Username;
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
Expand All @@ -43,7 +45,8 @@ public static <P extends BaseProc> P instantiateProcedure(
TaskRegistryFactory taskRegistryFactory,
UserLogRegistryFactory userLogRegistryFactory,
Transaction tx,
Username username
Username username,
AlgorithmMetricsService algorithmMetricsService
) {
P proc;
try {
Expand All @@ -61,6 +64,8 @@ public static <P extends BaseProc> P instantiateProcedure(
proc.userLogRegistryFactory = userLogRegistryFactory;
proc.username = username;

proc.algorithmMetricsService = algorithmMetricsService;

return proc;
}

Expand All @@ -82,7 +87,8 @@ public static <P extends BaseProc> P applyOnProcedure(
taskRegistryFactory,
EmptyUserLogRegistryFactory.INSTANCE,
tx,
username
username,
new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar())
);
func.accept(proc);
return proc;
Expand Down
25 changes: 25 additions & 0 deletions test-utils/src/main/java/org/neo4j/gds/BaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.assertj.core.api.Assertions;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.Timeout;
import org.neo4j.gds.algorithms.metrics.AlgorithmMetricsService;
import org.neo4j.gds.algorithms.metrics.PassthroughAlgorithmMetricRegistrar;
import org.neo4j.gds.compat.Neo4jProxy;
import org.neo4j.gds.compat.TestLog;
import org.neo4j.gds.core.Settings;
Expand All @@ -35,6 +37,11 @@
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Result;
import org.neo4j.graphdb.Transaction;
import org.neo4j.kernel.api.procedure.GlobalProcedures;
import org.neo4j.kernel.extension.ExtensionFactory;
import org.neo4j.kernel.extension.context.ExtensionContext;
import org.neo4j.kernel.lifecycle.Lifecycle;
import org.neo4j.kernel.lifecycle.LifecycleAdapter;
import org.neo4j.test.TestDatabaseManagementServiceBuilder;
import org.neo4j.test.extension.ExtensionCallback;
import org.neo4j.test.extension.ImpermanentDbmsExtension;
Expand Down Expand Up @@ -86,6 +93,24 @@ protected void configuration(TestDatabaseManagementServiceBuilder builder) {
builder.setConfigRaw(Map.of("unsupported.dbms.debug.trace_cursors", "true"));
testLog = Neo4jProxy.testLog();
builder.setUserLogProvider(new TestLogProvider(testLog));

// Hacky as hell but will have to do until we make this BaseTest obsolete
builder.addExtension(new ExtensionFactory<Dependencies>("AlgorithmMetricsServiceExtensionFactory") {
@Override
public Lifecycle newInstance(ExtensionContext context, Dependencies dependencies) {
dependencies.globalProcedures().registerComponent(
AlgorithmMetricsService.class,
ctx -> new AlgorithmMetricsService(new PassthroughAlgorithmMetricRegistrar()),
false
);
return new LifecycleAdapter();
}

});
}

interface Dependencies {
GlobalProcedures globalProcedures();
}

protected long clearDb() {
Expand Down

0 comments on commit 79aabee

Please sign in to comment.