Skip to content

Commit

Permalink
Implement BellmanFordMemoryEstimateDefinition
Browse files Browse the repository at this point in the history
  • Loading branch information
IoannisPanagiotas committed Dec 14, 2023
1 parent 1a77637 commit cd9030b
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@

import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
Expand Down Expand Up @@ -63,16 +60,6 @@ public Task progressTask(Graph graphOrGraphStore, BellmanFordBaseConfig config)

@Override
public MemoryEstimation memoryEstimation(CONFIG configuration) {
var builder = MemoryEstimations.builder(BellmanFord.class)
.perNode("frontier", HugeLongArray::memoryEstimation)
.perNode("validBitset", HugeAtomicBitSet::memoryEstimation)
.add(DistanceTracker.memoryEstimation())
.perThread("BellmanFordTask", BellmanFordTask.memoryEstimation());

if(configuration.trackNegativeCycles()) {
builder.perNode("negativeCyclesVertices", HugeLongArray::memoryEstimation);
}

return builder.build();
return new BellmanFordMemoryEstimateDefinition().memoryEstimation(configuration);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.paths.bellmanford;

import org.neo4j.gds.AlgorithmMemoryEstimateDefinition;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;

public class BellmanFordMemoryEstimateDefinition implements AlgorithmMemoryEstimateDefinition<BellmanFordBaseConfig> {

@Override
public MemoryEstimation memoryEstimation(BellmanFordBaseConfig configuration) {
var builder = MemoryEstimations.builder(BellmanFord.class)
.perNode("frontier", HugeLongArray::memoryEstimation)
.perNode("validBitset", HugeAtomicBitSet::memoryEstimation)
.add(DistanceTracker.memoryEstimation())
.perThread("BellmanFordTask", BellmanFordTask.memoryEstimation());

if(configuration.trackNegativeCycles()) {
builder.perNode("negativeCyclesVertices", HugeLongArray::memoryEstimation);
}

return builder.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,28 @@
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.assertions.MemoryEstimationAssert;

import java.util.Map;
import java.util.stream.Stream;

import static org.neo4j.gds.TestSupport.assertMemoryEstimation;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class BellmanFordAlgorithmFactoryTest {
class BellmanFordMemoryEstimateDefinitionTest {

@ParameterizedTest(name = "{0}")
@MethodSource("memoryEstimationSetup")
void memoryEstimation(String description, boolean mutateNegativeCycles, long expectedBytes) {
var config = BellmanFordMutateConfig.of(CypherMapWrapper.create(
Map.of(
"sourceNode", 0L,
"mutateNegativeCycles", mutateNegativeCycles,
"mutateRelationshipType", "foo"
)
));
var algorithmFactory = new BellmanFordAlgorithmFactory<>();
void memoryEstimation(String description, boolean trackNegativeCycles, long expectedBytes) {

assertMemoryEstimation(
() -> algorithmFactory.memoryEstimation(config),
10,
23,
4,
MemoryRange.of(expectedBytes)
);

var config = mock(BellmanFordBaseConfig.class);
when(config.trackNegativeCycles()).thenReturn(trackNegativeCycles);

var memoryEstimation = new BellmanFordMemoryEstimateDefinition();

MemoryEstimationAssert.assertThat(memoryEstimation.memoryEstimation(config))
.memoryRange(10, 23, 4)
.hasSameMinAndMaxEqualTo(expectedBytes);
}

private static Stream<Arguments> memoryEstimationSetup() {
Expand Down

0 comments on commit cd9030b

Please sign in to comment.