Skip to content

Commit

Permalink
migrate kge stream
Browse files Browse the repository at this point in the history
  • Loading branch information
lassewesth committed Aug 28, 2024
1 parent 1ee4d03 commit f6fd679
Show file tree
Hide file tree
Showing 13 changed files with 243 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.ml.kge;
package org.neo4j.gds.algorithms.machinelearning;

import org.neo4j.gds.algorithms.machinelearning.KGEPredictBaseConfig;
import org.neo4j.gds.annotation.Configuration;
import org.neo4j.gds.core.CypherMapWrapper;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,22 @@

public class MachineLearningAlgorithmsMutateModeBusinessFacade {
private final RequestScopedDependencies requestScopedDependencies;

private final MachineLearningAlgorithmsEstimationModeBusinessFacade estimation;
private final MachineLearningAlgorithms algorithms;
private final AlgorithmProcessingTemplateConvenience algorithmProcessingTemplateConvenience;

private final AlgorithmProcessingTemplateConvenience convenience;

MachineLearningAlgorithmsMutateModeBusinessFacade(
RequestScopedDependencies requestScopedDependencies,
MachineLearningAlgorithmsEstimationModeBusinessFacade estimation,
MachineLearningAlgorithms algorithms,
AlgorithmProcessingTemplateConvenience algorithmProcessingTemplateConvenience
AlgorithmProcessingTemplateConvenience convenience
) {
this.requestScopedDependencies = requestScopedDependencies;
this.estimation = estimation;
this.algorithms = algorithms;
this.algorithmProcessingTemplateConvenience = algorithmProcessingTemplateConvenience;
this.convenience = convenience;
}

public <RESULT> RESULT kge(
Expand All @@ -54,7 +56,7 @@ public <RESULT> RESULT kge(
) {
var mutateStep = new KgeMutateStep(requestScopedDependencies.getTerminationFlag(), configuration);

return algorithmProcessingTemplateConvenience.processRegularAlgorithmInMutateOrWriteMode(
return convenience.processRegularAlgorithmInMutateOrWriteMode(
graphName,
configuration,
KGE,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.applications.algorithms.machinelearning;

import org.neo4j.gds.algorithms.machinelearning.KGEPredictResult;
import org.neo4j.gds.algorithms.machinelearning.KGEPredictStreamConfig;
import org.neo4j.gds.api.GraphName;
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmProcessingTemplateConvenience;
import org.neo4j.gds.applications.algorithms.machinery.StreamResultBuilder;

import java.util.stream.Stream;

import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.KGE;

public class MachineLearningAlgorithmsStreamModeBusinessFacade {
private final AlgorithmProcessingTemplateConvenience convenience;

private final MachineLearningAlgorithmsEstimationModeBusinessFacade estimation;
private final MachineLearningAlgorithms algorithms;

MachineLearningAlgorithmsStreamModeBusinessFacade(
AlgorithmProcessingTemplateConvenience convenience,
MachineLearningAlgorithmsEstimationModeBusinessFacade estimation,
MachineLearningAlgorithms algorithms
) {
this.convenience = convenience;
this.algorithms = algorithms;
this.estimation = estimation;
}

public <RESULT> Stream<RESULT> kge(
GraphName graphName,
KGEPredictStreamConfig configuration,
StreamResultBuilder<KGEPredictStreamConfig, KGEPredictResult, RESULT> resultBuilder
) {
return convenience.processRegularAlgorithmInStreamMode(
graphName,
configuration,
KGE,
estimation::kge,
(graph, __) -> algorithms.kge(graph, configuration),
resultBuilder
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,22 @@
public final class MachineLearningApplications {
private final MachineLearningAlgorithmsEstimationModeBusinessFacade estimation;
private final MachineLearningAlgorithmsMutateModeBusinessFacade mutation;
private final MachineLearningAlgorithmsStreamModeBusinessFacade streaming;

private MachineLearningApplications(
MachineLearningAlgorithmsEstimationModeBusinessFacade estimation,
MachineLearningAlgorithmsMutateModeBusinessFacade mutation
MachineLearningAlgorithmsMutateModeBusinessFacade mutation,
MachineLearningAlgorithmsStreamModeBusinessFacade streaming
) {
this.estimation = estimation;
this.mutation = mutation;
this.streaming = streaming;
}

public static MachineLearningApplications create(
RequestScopedDependencies requestScopedDependencies,
ProgressTrackerCreator progressTrackerCreator,
AlgorithmProcessingTemplateConvenience algorithmProcessingTemplateConvenience
AlgorithmProcessingTemplateConvenience convenience
) {
var algorithms = new MachineLearningAlgorithms(
progressTrackerCreator,
Expand All @@ -50,10 +53,15 @@ public static MachineLearningApplications create(
requestScopedDependencies,
estimation,
algorithms,
algorithmProcessingTemplateConvenience
convenience
);
var streaming = new MachineLearningAlgorithmsStreamModeBusinessFacade(
convenience,
estimation,
algorithms
);

return new MachineLearningApplications(estimation, mutation);
return new MachineLearningApplications(estimation, mutation, streaming);
}

public MachineLearningAlgorithmsEstimationModeBusinessFacade estimate() {
Expand All @@ -63,4 +71,8 @@ public MachineLearningAlgorithmsEstimationModeBusinessFacade estimate() {
public MachineLearningAlgorithmsMutateModeBusinessFacade mutate() {
return mutation;
}

public MachineLearningAlgorithmsStreamModeBusinessFacade stream() {
return streaming;
}
}
Loading

0 comments on commit f6fd679

Please sign in to comment.