Skip to content

Commit

Permalink
migrate node classification write
Browse files Browse the repository at this point in the history
  • Loading branch information
lassewesth committed Oct 8, 2024
1 parent b8aa611 commit 37afa78
Show file tree
Hide file tree
Showing 14 changed files with 414 additions and 154 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.applications.algorithms.machinery;

import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.PropertyState;
import org.neo4j.gds.api.ResultStore;
import org.neo4j.gds.api.schema.PropertySchema;
import org.neo4j.gds.applications.algorithms.metadata.NodePropertiesWritten;
import org.neo4j.gds.config.WriteConfig;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.loading.Capabilities;
import org.neo4j.gds.core.utils.progress.JobId;
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker;
import org.neo4j.gds.core.write.NodeProperty;
import org.neo4j.gds.core.write.NodePropertyExporter;
import org.neo4j.gds.core.write.NodePropertyExporterBuilder;
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;

import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;

/**
* Common bits of node property writes, squirrelled away in one place
*/
public class NodePropertyWriter {
    private final Log log;

    private final NodePropertyExporterBuilder nodePropertyExporterBuilder;
    private final TaskRegistryFactory taskRegistryFactory;
    private final TerminationFlag terminationFlag;

    public NodePropertyWriter(
        Log log,
        NodePropertyExporterBuilder nodePropertyExporterBuilder,
        TaskRegistryFactory taskRegistryFactory,
        TerminationFlag terminationFlag
    ) {
        this.log = log;
        this.nodePropertyExporterBuilder = nodePropertyExporterBuilder;
        this.taskRegistryFactory = taskRegistryFactory;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Runs the pre-flight check, builds an exporter from the configured builder and writes the
     * given node properties with it. The progress tracker created for the write is released
     * whether or not the write succeeds.
     *
     * @param graph          supplies the node count, the id map and the node schema used by the pre-flight check
     * @param graphStore     supplies the write mode via its capabilities
     * @param resultStore    handed to the exporter builder as-is
     * @param nodeProperties the properties to write
     * @param jobId          handed to the exporter builder as-is
     * @param label          its string form names the progress task
     * @param writeConfig    supplies the write concurrency
     * @return the number of properties written, as reported by the exporter
     * @throws IllegalArgumentException if the graph store's write mode is {@code REMOTE}
     * @throws IllegalStateException    if the write mode cannot write back to a database, or if a
     *                                  property that already exists in the node schema has a state
     *                                  that is not allowed for the write mode
     */
    public NodePropertiesWritten writeNodeProperties(
        Graph graph,
        GraphStore graphStore,
        Optional<ResultStore> resultStore,
        Collection<NodeProperty> nodeProperties,
        JobId jobId,
        Label label,
        WriteConfig writeConfig
    ) {
        preFlightCheck(
            graphStore.capabilities().writeMode(),
            graph.schema().nodeSchema().unionProperties(),
            nodeProperties
        );

        var progressTracker = createProgressTracker(graph.nodeCount(), writeConfig.writeConcurrency(), label);

        var nodePropertyExporter = nodePropertyExporterBuilder
            .parallel(DefaultPool.INSTANCE, writeConfig.writeConcurrency())
            .withIdMap(graph)
            .withJobId(jobId)
            .withProgressTracker(progressTracker)
            .withResultStore(resultStore)
            .withTerminationFlag(terminationFlag)
            .build();

        try {
            return writeNodeProperties(nodePropertyExporter, nodeProperties);
        } finally {
            progressTracker.release();
        }
    }

    /**
     * Creates a progress tracker around the exporter's base task, named after the given label.
     */
    private ProgressTracker createProgressTracker(
        long taskVolume,
        Concurrency writeConcurrency,
        Label label
    ) {
        var task = NodePropertyExporter.baseTask(label.asString(), taskVolume);

        return new TaskProgressTracker(
            task,
            log,
            writeConcurrency,
            taskRegistryFactory
        );
    }

    /**
     * Returns the predicate that decides which property states may be written for a write mode.
     * The returned predicate has a readable string form because it ends up in an error message.
     *
     * @throws IllegalStateException if the write mode is neither {@code LOCAL} nor {@code REMOTE}
     */
    private Predicate<PropertyState> expectedPropertyStateForWriteMode(Capabilities.WriteMode writeMode) {
        return switch (writeMode) {
            // We need to allow persistent and transient as for example algorithms that support seeding will reuse a
            // mutated (transient) property to write back properties that are in fact backed by a database
            case LOCAL -> new EitherPropertyState(PropertyState.PERSISTENT, PropertyState.TRANSIENT);
            // We allow transient properties for the same reason as above
            case REMOTE -> new EitherPropertyState(PropertyState.REMOTE, PropertyState.TRANSIENT);
            default -> throw new IllegalStateException(
                formatWithLocale(
                    "Graph with write mode `%s` cannot write back to a database",
                    writeMode
                )
            );
        };
    }

    /**
     * Rejects writes that cannot succeed before any exporter is built.
     * <p>
     * Properties that are not present in {@code propertySchemas} are always accepted; only
     * properties that already exist in the schema have their state checked.
     */
    private void preFlightCheck(
        Capabilities.WriteMode writeMode,
        Map<String, PropertySchema> propertySchemas,
        Collection<NodeProperty> nodeProperties
    ) {
        // REMOTE is rejected up front, so the state check below is never reached for it from here
        if (writeMode == Capabilities.WriteMode.REMOTE) {
            throw new IllegalArgumentException("Missing arrow connection information");
        }

        var expectedPropertyState = expectedPropertyStateForWriteMode(writeMode);

        var unexpectedProperties = nodeProperties.stream()
            .filter(nodeProperty -> {
                var propertySchema = propertySchemas.get(nodeProperty.key());
                if (propertySchema == null) {
                    // We are executing an algorithm write mode and the property we are writing is
                    // not in the GraphStore, therefore we do not perform any more checks
                    return false;
                }
                var propertyState = propertySchema.state();
                return !expectedPropertyState.test(propertyState);
            })
            .map(
                nodeProperty -> formatWithLocale(
                    "NodeProperty{propertyKey=%s, propertyState=%s}",
                    nodeProperty.key(),
                    propertySchemas.get(nodeProperty.key()).state()
                )
            )
            .toList();

        if (!unexpectedProperties.isEmpty()) {
            throw new IllegalStateException(
                formatWithLocale(
                    "Expected all properties to be of state `%s` but some properties differ: %s",
                    // rendered through EitherPropertyState#toString, e.g. "PERSISTENT or TRANSIENT"
                    expectedPropertyState,
                    unexpectedProperties
                )
            );
        }
    }

    /**
     * Hands the properties to the exporter and wraps the exporter's written-count in the result type.
     */
    private NodePropertiesWritten writeNodeProperties(
        NodePropertyExporter nodePropertyExporter,
        Collection<NodeProperty> nodeProperties
    ) {
        nodePropertyExporter.write(nodeProperties);

        return new NodePropertiesWritten(nodePropertyExporter.propertiesWritten());
    }

    /**
     * Predicate that accepts exactly two property states. A lambda would format as an opaque
     * synthetic class name; this type formats as the two accepted states instead.
     */
    private record EitherPropertyState(PropertyState first, PropertyState second)
        implements Predicate<PropertyState> {

        @Override
        public boolean test(PropertyState state) {
            return state == first || state == second;
        }

        @Override
        public String toString() {
            return first + " or " + second;
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,21 @@
*/
package org.neo4j.gds;

import org.neo4j.gds.api.PropertyState;
import org.neo4j.gds.api.schema.PropertySchema;
import org.neo4j.gds.applications.algorithms.machinery.NodePropertyWriter;
import org.neo4j.gds.applications.algorithms.machinery.StandardLabel;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.config.WritePropertyConfig;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.loading.Capabilities.WriteMode;
import org.neo4j.gds.core.utils.ProgressTimer;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker;
import org.neo4j.gds.core.write.NodeProperty;
import org.neo4j.gds.core.write.NodePropertyExporter;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.ComputationResultConsumer;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.result.AbstractResultBuilder;

import java.util.Collection;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Stream;

import static org.neo4j.gds.LoggingUtil.runWithExceptionLogging;
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;

public class WriteNodePropertiesComputationResultConsumer<ALGO extends Algorithm<ALGO_RESULT>, ALGO_RESULT, CONFIG extends WritePropertyConfig & AlgoBaseConfig, RESULT>
implements
ComputationResultConsumer<ALGO, ALGO_RESULT, CONFIG, Stream<RESULT>> {

public class WriteNodePropertiesComputationResultConsumer<ALGO extends Algorithm<ALGO_RESULT>, ALGO_RESULT, CONFIG extends WritePropertyConfig & AlgoBaseConfig, RESULT> implements ComputationResultConsumer<ALGO, ALGO_RESULT, CONFIG, Stream<RESULT>> {
private final ResultBuilderFunction<ALGO, ALGO_RESULT, CONFIG, RESULT> resultBuilderFunction;
private final WriteNodePropertyListFunction<ALGO, ALGO_RESULT, CONFIG> nodePropertyListFunction;
private final String procedureName;
Expand All @@ -62,81 +48,6 @@ public WriteNodePropertiesComputationResultConsumer(
this.procedureName = procedureName;
}

/**
 * Rejects writes that cannot succeed before any exporter is built.
 * <p>
 * Properties that are not present in {@code propertySchemas} are always accepted; only
 * properties that already exist in the schema have their state checked.
 *
 * @throws IllegalArgumentException if the write mode is {@code REMOTE}
 * @throws IllegalStateException    if the write mode cannot write back to a database, or if an
 *                                  existing property has a state not allowed for the write mode
 */
private static void validatePropertiesCanBeWritten(
    WriteMode writeMode,
    Map<String, PropertySchema> propertySchemas,
    Collection<NodeProperty> nodeProperties
) {
    // REMOTE is rejected up front, so the state check below is never reached for it from here
    if (writeMode == WriteMode.REMOTE) {
        throw new IllegalArgumentException("Missing arrow connection information");
    }

    var expectedPropertyState = expectedPropertyStateForWriteMode(writeMode);

    var unexpectedProperties = nodeProperties
        .stream()
        .filter(nodeProperty -> {
            var propertySchema = propertySchemas.get(nodeProperty.key());
            if (propertySchema == null) {
                // We are executing an algorithm write mode and the property we are writing is
                // not in the GraphStore, therefore we do not perform any more checks
                return false;
            }
            var propertyState = propertySchema.state();
            return !expectedPropertyState.test(propertyState);
        })
        .map(
            nodeProperty -> formatWithLocale(
                "NodeProperty{propertyKey=%s, propertyState=%s}",
                nodeProperty.key(),
                propertySchemas.get(nodeProperty.key()).state()
            )
        )
        .toList();

    if (!unexpectedProperties.isEmpty()) {
        throw new IllegalStateException(
            formatWithLocale(
                "Expected all properties to be of state `%s` but some properties differ: %s",
                // rendered through EitherPropertyState#toString, e.g. "PERSISTENT or TRANSIENT"
                expectedPropertyState,
                unexpectedProperties
            )
        );
    }
}

/**
 * Returns the predicate that decides which property states may be written for a write mode.
 * The returned predicate has a readable string form because it ends up in an error message.
 *
 * @throws IllegalStateException if the write mode is neither {@code LOCAL} nor {@code REMOTE}
 */
private static Predicate<PropertyState> expectedPropertyStateForWriteMode(WriteMode writeMode) {
    switch (writeMode) {
        case LOCAL:
            // We need to allow persistent and transient as for example algorithms that support seeding will reuse a
            // mutated (transient) property to write back properties that are in fact backed by a database
            return new EitherPropertyState(PropertyState.PERSISTENT, PropertyState.TRANSIENT);
        case REMOTE:
            // We allow transient properties for the same reason as above
            return new EitherPropertyState(PropertyState.REMOTE, PropertyState.TRANSIENT);
        default:
            throw new IllegalStateException(
                formatWithLocale(
                    "Graph with write mode `%s` cannot write back to a database",
                    writeMode
                )
            );
    }
}

/**
 * Predicate that accepts exactly two property states. A lambda would format as an opaque
 * synthetic class name; this type formats as the two accepted states instead.
 */
private record EitherPropertyState(PropertyState first, PropertyState second)
    implements Predicate<PropertyState> {

    @Override
    public boolean test(PropertyState state) {
        return state == first || state == second;
    }

    @Override
    public String toString() {
        return first + " or " + second;
    }
}

/**
 * Creates a progress tracker for a node property write. The tracked task is the exporter's
 * base task, named after this consumer's procedure name; the log and task registry factory
 * are taken from the execution context.
 */
ProgressTracker createProgressTracker(
    long taskVolume,
    Concurrency writeConcurrency,
    ExecutionContext executionContext
) {
    var exportTask = NodePropertyExporter.baseTask(this.procedureName, taskVolume);
    var contextLog = executionContext.log();
    var registryFactory = executionContext.taskRegistryFactory();

    return new TaskProgressTracker(exportTask, contextLog, writeConcurrency, registryFactory);
}

@Override
public Stream<RESULT> consume(
ComputationResult<ALGO, ALGO_RESULT, CONFIG> computationResult,
Expand All @@ -158,48 +69,41 @@ public Stream<RESULT> consume(
});
}

void writeToNeo(
private void writeToNeo(
AbstractResultBuilder<?> resultBuilder,
ComputationResult<ALGO, ALGO_RESULT, CONFIG> computationResult,
ExecutionContext executionContext
) {
try (ProgressTimer ignored = ProgressTimer.start(resultBuilder::withWriteMillis)) {
var log = executionContext.log();
var nodePropertyExporterBuilder = executionContext.nodePropertyExporterBuilder();
var taskRegistryFactory = executionContext.taskRegistryFactory();
var terminationFlag = computationResult.algorithm().terminationFlag;
var nodePropertyWriter = new NodePropertyWriter(
log,
nodePropertyExporterBuilder,
taskRegistryFactory,
terminationFlag
);

var graph = computationResult.graph();
var graphStore = computationResult.graphStore();
var config = computationResult.config();
var progressTracker = createProgressTracker(
graph.nodeCount(),
config.writeConcurrency(),
executionContext
);
var writeMode = computationResult.graphStore().capabilities().writeMode();
var nodePropertySchema = graph.schema().nodeSchema().unionProperties();
var resultStore = config.resolveResultStore(computationResult.resultStore());
var nodeProperties = nodePropertyListFunction.apply(computationResult);

validatePropertiesCanBeWritten(
writeMode,
nodePropertySchema,
nodeProperties
var nodePropertiesWritten = nodePropertyWriter.writeNodeProperties(
graph,
graphStore,
resultStore,
nodeProperties,
config.jobId(),
new StandardLabel(procedureName),
config
);

var resultStore = config.resolveResultStore(computationResult.resultStore());
var exporter = executionContext
.nodePropertyExporterBuilder()
.withIdMap(graph)
.withTerminationFlag(computationResult.algorithm().terminationFlag)
.withProgressTracker(progressTracker)
.withResultStore(resultStore)
.withJobId(config.jobId())
.parallel(DefaultPool.INSTANCE, config.writeConcurrency())
.build();

try {
exporter.write(nodeProperties);
} finally {
progressTracker.release();
}

resultBuilder.withNodeCount(computationResult.graph().nodeCount());
resultBuilder.withNodePropertiesWritten(exporter.propertiesWritten());
resultBuilder.withNodePropertiesWritten(nodePropertiesWritten.value());
}
}
}
Loading

0 comments on commit 37afa78

Please sign in to comment.