diff --git a/algo/src/test/java/org/neo4j/gds/similarity/JaccardSimilarityTest.java b/algo/src/test/java/org/neo4j/gds/similarity/JaccardSimilarityTest.java deleted file mode 100644 index d472356984..0000000000 --- a/algo/src/test/java/org/neo4j/gds/similarity/JaccardSimilarityTest.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.similarity; - -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; - -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collector; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import static org.assertj.core.api.Assertions.assertThat; - -class JaccardSimilarityTest { - - @ParameterizedTest(name = "{2}") - @MethodSource("listCollectors") - void shouldPassAtAllCasesOfListInput( - Collector> firstListCollector, - Collector> secondListCollector, - String label - ) { - var arr1 = new int[]{1,2,3}; - var arr2 = new int[]{1,2,3}; - List l1 = Arrays.stream(arr1).boxed().collect(firstListCollector); - List l2 = Arrays.stream(arr2).boxed().collect(secondListCollector); - - var similarities = new SimilaritiesFunc(); - var jaccarded = similarities.jaccardSimilarity(l1, l2); - assertThat(jaccarded).isEqualTo(1); - } - - private static Stream listCollectors() { - return Stream.of( - Arguments.of( - Collectors.toUnmodifiableList(), Collectors.toUnmodifiableList(), "Unmodifiable, Unmodifiable" - ), - Arguments.of( - Collectors.toUnmodifiableList(), Collectors.toList(), "Unmodifiable, Modifiable" - ), - Arguments.of( - Collectors.toList(), Collectors.toList(), "Modifiable, Modifiable" - ), - Arguments.of( - Collectors.toList(), Collectors.toUnmodifiableList(), "Modifiable, Unmodifiable" - ) - ); - } - -} diff --git a/algo/src/test/java/org/neo4j/gds/similarity/JaccardWithCypherTest.java b/algo/src/test/java/org/neo4j/gds/similarity/JaccardWithCypherTest.java deleted file mode 100644 index 19be8870de..0000000000 --- a/algo/src/test/java/org/neo4j/gds/similarity/JaccardWithCypherTest.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.similarity; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.neo4j.gds.BaseTest; -import org.neo4j.gds.compat.GraphDatabaseApiProxy; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatNoException; -import static org.assertj.core.api.InstanceOfAssertFactories.DOUBLE; -import static org.junit.jupiter.api.Assertions.assertEquals; - -class JaccardWithCypherTest extends BaseTest { - - @BeforeEach - void setUp() throws Exception { - GraphDatabaseApiProxy.registerFunctions(db, SimilaritiesFunc.class); - } - - @Test - void testJaccardFunctionWithInputFromDatabase() { - assertThatNoException().isThrownBy( - () -> runQueryWithResultConsumer( - "CREATE (t:Test {listone: [1, 5], listtwo: [5, 5]}) RETURN gds.similarity.jaccard(t.listone, t.listtwo) AS score", - result -> { - assertThat(result.hasNext()).isTrue(); - var score = result.next().get("score"); - assertThat(score) - .asInstanceOf(DOUBLE) - .isEqualTo(1.0 / 3.0); - } - ) - ); - } - - @Test - void testJaccardFunction() { - assertThatNoException().isThrownBy( - () -> - runQueryWithResultConsumer( - "RETURN gds.similarity.jaccard([1, 5], [5, 5]) AS score", - result -> assertEquals(1.0 / 3.0, result.next().get("score")) - ) - ); - } -} diff --git a/algo/src/test/java/org/neo4j/gds/similarity/SimilaritiesFuncTest.java b/algo/src/test/java/org/neo4j/gds/similarity/SimilaritiesFuncTest.java index bec83d5284..5c217d23a5 100644 --- a/algo/src/test/java/org/neo4j/gds/similarity/SimilaritiesFuncTest.java +++ b/algo/src/test/java/org/neo4j/gds/similarity/SimilaritiesFuncTest.java @@ -28,9 +28,12 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.stream.Collector; +import java.util.stream.Collectors; import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.junit.jupiter.api.Assertions.assertEquals; class SimilaritiesFuncTest { @@ -157,4 +160,39 @@ void testJaccardWithNulls() { assertEquals(1 / 3.0, new SimilaritiesFunc().jaccardSimilarity(left, right)); } + @ParameterizedTest(name = "{2}") + @MethodSource("listCollectors") + void shouldComputeJaccardAtAllCasesOfListInput( + Collector> firstListCollector, + Collector> secondListCollector, + String label + ) { + var arr1 = new int[]{1,2,3}; + var arr2 = new int[]{1,2,3}; + var l1 = Arrays.stream(arr1).boxed().collect(firstListCollector); + var l2 = Arrays.stream(arr2).boxed().collect(secondListCollector); + + var similarities = new SimilaritiesFunc(); + assertThatNoException().isThrownBy( + () -> assertThat(similarities.jaccardSimilarity(l1, l2)).isEqualTo(1) + ); + } + + private static Stream listCollectors() { + return Stream.of( + Arguments.of( + Collectors.toUnmodifiableList(), Collectors.toUnmodifiableList(), "Unmodifiable, Unmodifiable" + ), + Arguments.of( + Collectors.toUnmodifiableList(), Collectors.toList(), "Unmodifiable, Modifiable" + ), + Arguments.of( + Collectors.toList(), Collectors.toList(), "Modifiable, Modifiable" + ), + Arguments.of( + Collectors.toList(), Collectors.toUnmodifiableList(), "Modifiable, Unmodifiable" + ) + ); + } + } diff --git a/algo/src/test/java/org/neo4j/gds/similarity/SimilaritiesFuncWithCypherTest.java b/algo/src/test/java/org/neo4j/gds/similarity/SimilaritiesFuncWithCypherTest.java index fb59d42a48..d1d2011171 100644 --- a/algo/src/test/java/org/neo4j/gds/similarity/SimilaritiesFuncWithCypherTest.java +++ b/algo/src/test/java/org/neo4j/gds/similarity/SimilaritiesFuncWithCypherTest.java @@ -24,6 +24,9 @@ import org.neo4j.gds.BaseTest; import org.neo4j.gds.compat.GraphDatabaseApiProxy; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.assertj.core.api.InstanceOfAssertFactories.DOUBLE; import static org.junit.jupiter.api.Assertions.assertEquals; class SimilaritiesFuncWithCypherTest extends BaseTest { @@ -80,4 +83,20 @@ void testJaccardFunction() { result -> assertEquals(1.0 / 3.0, result.next().get("score")) ); } + + @Test + void testJaccardFunctionWithInputFromDatabase() { + assertThatNoException().isThrownBy( + () -> runQueryWithResultConsumer( + "CREATE (t:Test {listone: [1, 5], listtwo: [5, 5]}) RETURN gds.similarity.jaccard(t.listone, t.listtwo) AS score", + result -> { + assertThat(result.hasNext()).isTrue(); + var score = result.next().get("score"); + assertThat(score) + .asInstanceOf(DOUBLE) + .isEqualTo(1.0 / 3.0); + } + ) + ); + } }