Skip to content

Commit

Permalink
termination flag is request scoped
Browse files Browse the repository at this point in the history
  • Loading branch information
lassewesth committed Sep 26, 2023
1 parent c574767 commit aa5f88d
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.procedures.GraphDataScience;
import org.neo4j.gds.procedures.TaskRegistryFactoryService;
import org.neo4j.gds.procedures.TerminationFlagService;
import org.neo4j.gds.procedures.integration.AlgorithmMetaDataSetterService;
import org.neo4j.gds.procedures.integration.CatalogFacadeFactory;
import org.neo4j.gds.procedures.integration.CommunityProcedureFactory;
Expand Down Expand Up @@ -90,6 +91,7 @@ private void registerComponents(Dependencies dependencies, Log log) {
log.info("Progress tracking: " + (progressTrackingEnabled ? "enabled" : "disabled"));
var taskStoreService = new TaskStoreService(progressTrackingEnabled);
var taskRegistryFactoryService = new TaskRegistryFactoryService(progressTrackingEnabled, taskStoreService);
var terminationFlagService = new TerminationFlagService();
var useMaxMemoryEstimation = neo4jConfig.get(MemoryEstimationSettings.validate_using_max_memory_estimation);
log.info("Memory usage guard: " + (useMaxMemoryEstimation ? "maximum" : "minimum") + " estimate");
var userLogServices = new UserLogServices();
Expand All @@ -105,6 +107,7 @@ private void registerComponents(Dependencies dependencies, Log log) {
databaseIdService,
__ -> new NativeExportBuildersProvider(), // we always just offer native writes in OpenGDS
taskRegistryFactoryService,
terminationFlagService,
userLogServices,
userServices,
Optional.empty() // we have no extra checks to do in OpenGDS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
import org.neo4j.gds.core.utils.TerminationFlag;

public class TerminationFlagService {
/**
* Improve this: strip off the ktx service, can happen later
*/
@Deprecated
public TerminationFlag terminationFlag(KernelTransactionService kernelTransactionService) {
var kernelTransaction = kernelTransactionService.getKernelTransaction();
var terminationMonitor = new TransactionTerminationMonitor(kernelTransaction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@
import org.neo4j.gds.core.loading.GraphDropRelationshipResult;
import org.neo4j.gds.core.loading.GraphFilterResult;
import org.neo4j.gds.core.loading.GraphProjectCypherResult;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.warnings.UserLogEntry;
import org.neo4j.gds.procedures.KernelTransactionService;
import org.neo4j.gds.procedures.ProcedureTransactionService;
import org.neo4j.gds.procedures.TaskRegistryFactoryService;
import org.neo4j.gds.procedures.TerminationFlagService;
import org.neo4j.gds.procedures.TransactionContextService;
import org.neo4j.gds.projection.GraphProjectNativeResult;
import org.neo4j.gds.results.MemoryEstimateResult;
Expand Down Expand Up @@ -78,7 +78,7 @@ public class CatalogFacade {
private final ProcedureReturnColumns procedureReturnColumns;
private final ProcedureTransactionService procedureTransactionService;
private final TaskRegistryFactoryService taskRegistryFactoryService;
private final TerminationFlagService terminationFlagService;
private final TerminationFlag terminationFlag;
private final TransactionContextService transactionContextService;
private final User user;
private final UserLogServices userLogServices;
Expand All @@ -93,7 +93,7 @@ public CatalogFacade(
ProcedureReturnColumns procedureReturnColumns,
ProcedureTransactionService procedureTransactionService,
TaskRegistryFactoryService taskRegistryFactoryService,
TerminationFlagService terminationFlagService,
TerminationFlag terminationFlag,
TransactionContextService transactionContextService,
User user,
UserLogServices userLogServices,
Expand All @@ -105,12 +105,12 @@ public CatalogFacade(
this.procedureReturnColumns = procedureReturnColumns;
this.procedureTransactionService = procedureTransactionService;
this.taskRegistryFactoryService = taskRegistryFactoryService;
this.terminationFlagService = terminationFlagService;
this.transactionContextService = transactionContextService;
this.userLogServices = userLogServices;
this.user = user;

this.businessFacade = businessFacade;
this.terminationFlag = terminationFlag;
}

/**
Expand Down Expand Up @@ -176,7 +176,6 @@ public Stream<GraphInfoWithHistogram> listGraphs(String graphName) {
graphName = validateValue(graphName);

var displayDegreeDistribution = procedureReturnColumns.contains("degreeDistribution");
var terminationFlag = terminationFlagService.terminationFlag(kernelTransactionService);

var results = businessFacade.listGraphs(user, graphName, displayDegreeDistribution, terminationFlag);

Expand All @@ -198,7 +197,6 @@ public Stream<GraphProjectNativeResult> nativeProject(
Map<String, Object> configuration
) {
var taskRegistryFactory = taskRegistryFactoryService.getTaskRegistryFactory(databaseId, user);
var terminationFlag = terminationFlagService.terminationFlag(kernelTransactionService);
var transactionContext = transactionContextService.transactionContext(
graphDatabaseService,
procedureTransactionService
Expand Down Expand Up @@ -228,7 +226,6 @@ public Stream<MemoryEstimateResult> estimateNativeProject(
Map<String, Object> configuration
) {
var taskRegistryFactory = taskRegistryFactoryService.getTaskRegistryFactory(databaseId, user);
var terminationFlag = terminationFlagService.terminationFlag(kernelTransactionService);
var transactionContext = transactionContextService.transactionContext(
graphDatabaseService,
procedureTransactionService
Expand Down Expand Up @@ -256,7 +253,6 @@ public Stream<GraphProjectCypherResult> cypherProject(
Map<String, Object> configuration
) {
var taskRegistryFactory = taskRegistryFactoryService.getTaskRegistryFactory(databaseId, user);
var terminationFlag = terminationFlagService.terminationFlag(kernelTransactionService);
var transactionContext = transactionContextService.transactionContext(
graphDatabaseService,
procedureTransactionService
Expand Down Expand Up @@ -285,7 +281,6 @@ public Stream<MemoryEstimateResult> estimateCypherProject(
Map<String, Object> configuration
) {
var taskRegistryFactory = taskRegistryFactoryService.getTaskRegistryFactory(databaseId, user);
var terminationFlag = terminationFlagService.terminationFlag(kernelTransactionService);
var transactionContext = transactionContextService.transactionContext(
graphDatabaseService,
procedureTransactionService
Expand Down Expand Up @@ -537,7 +532,6 @@ public Stream<NodePropertiesWriteResult> writeNodeProperties(
Map<String, Object> configuration
) {
var taskRegistryFactory = taskRegistryFactoryService.getTaskRegistryFactory(databaseId, user);
var terminationFlag = terminationFlagService.terminationFlag(kernelTransactionService);
var userLogRegistryFactory = userLogServices.getUserLogRegistryFactory(databaseId, user);

var result = businessFacade.writeNodeProperties(
Expand All @@ -561,8 +555,6 @@ public Stream<WriteRelationshipPropertiesResult> writeRelationshipProperties(
List<String> relationshipProperties,
Map<String, Object> configuration
) {
var terminationFlag = terminationFlagService.terminationFlag(kernelTransactionService);

var result = businessFacade.writeRelationshipProperties(
user,
databaseId,
Expand All @@ -581,8 +573,6 @@ public Stream<WriteLabelResult> writeNodeLabel(
String nodeLabel,
Map<String, Object> configuration
) {
var terminationFlag = terminationFlagService.terminationFlag(kernelTransactionService);

var result = businessFacade.writeNodeLabel(
user,
databaseId,
Expand All @@ -602,7 +592,6 @@ public Stream<WriteRelationshipResult> writeRelationships(
Map<String, Object> configuration
) {
var taskRegistryFactory = taskRegistryFactoryService.getTaskRegistryFactory(databaseId, user);
var terminationFlag = terminationFlagService.terminationFlag(kernelTransactionService);
var userLogRegistryFactory = userLogServices.getUserLogRegistryFactory(databaseId, user);

var result = businessFacade.writeRelationships(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ public class CatalogFacadeFactory {
private final DatabaseIdService databaseIdService;
private final ExporterBuildersProviderService exporterBuildersProviderService;
private final TaskRegistryFactoryService taskRegistryFactoryService;
private final TerminationFlagService terminationFlagService;
private final UserLogServices userLogServices;
private final UserServices userServices;

Expand All @@ -99,6 +100,7 @@ public CatalogFacadeFactory(
DatabaseIdService databaseIdService,
ExporterBuildersProviderService exporterBuildersProviderService,
TaskRegistryFactoryService taskRegistryFactoryService,
TerminationFlagService terminationFlagService,
UserLogServices userLogServices,
UserServices userServices,
Optional<Function<CatalogBusinessFacade, CatalogBusinessFacade>> businessFacadeDecorator
Expand All @@ -109,6 +111,7 @@ public CatalogFacadeFactory(
this.databaseIdService = databaseIdService;
this.exporterBuildersProviderService = exporterBuildersProviderService;
this.taskRegistryFactoryService = taskRegistryFactoryService;
this.terminationFlagService = terminationFlagService;
this.userLogServices = userLogServices;
this.userServices = userServices;

Expand All @@ -128,7 +131,7 @@ public CatalogFacade createCatalogFacade(Context context) {
var kernelTransactionService = new KernelTransactionService(context);
var procedureTransactionService = new ProcedureTransactionService(context);
var procedureReturnColumns = new ProcedureCallContextReturnColumns(context.procedureCallContext());
var terminationFlagService = new TerminationFlagService();
var terminationFlag = terminationFlagService.terminationFlag(kernelTransactionService);
var transactionContextService = new TransactionContextService();
var user = userServices.getUser(context.securityContext());

Expand Down Expand Up @@ -232,7 +235,7 @@ public CatalogFacade createCatalogFacade(Context context) {
procedureReturnColumns,
procedureTransactionService,
taskRegistryFactoryService,
terminationFlagService,
terminationFlag,
transactionContextService,
user,
userLogServices,
Expand Down

0 comments on commit aa5f88d

Please sign in to comment.