diff --git a/algo/build.gradle b/algo/build.gradle index 4475b7218d..b367f61933 100644 --- a/algo/build.gradle +++ b/algo/build.gradle @@ -33,6 +33,7 @@ dependencies { implementation project(':core') implementation project(':core-utils') implementation project(':core-write') + implementation project(':gds-values') implementation project(':graph-schema-api') implementation project(':licensing') implementation project(':logging') diff --git a/core/build.gradle b/core/build.gradle index a13f7b178a..91fcc1cf57 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -23,6 +23,7 @@ dependencies { implementation project(':annotations') implementation project(':config-api') implementation project(':core-utils') + implementation project(':gds-values') implementation project(':licensing') implementation project(':logging') implementation project(':graph-schema-api') diff --git a/core/src/main/java/org/neo4j/gds/core/loading/GdsNeo4jValueConverter.java b/core/src/main/java/org/neo4j/gds/core/loading/GdsNeo4jValueConverter.java new file mode 100644 index 0000000000..5ff67da63c --- /dev/null +++ b/core/src/main/java/org/neo4j/gds/core/loading/GdsNeo4jValueConverter.java @@ -0,0 +1,113 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.core.loading; + +import org.jetbrains.annotations.NotNull; +import org.neo4j.gds.values.Array; +import org.neo4j.gds.values.GdsNoValue; +import org.neo4j.gds.values.GdsValue; +import org.neo4j.gds.values.primitive.PrimitiveValues; +import org.neo4j.values.AnyValue; +import org.neo4j.values.SequenceValue; +import org.neo4j.values.storable.ArrayValue; +import org.neo4j.values.storable.IntegralValue; +import org.neo4j.values.storable.NoValue; +import org.neo4j.values.storable.Value; +import org.neo4j.values.storable.ValueGroup; +import org.neo4j.values.virtual.ListValue; + +import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; + +public final class GdsNeo4jValueConverter { + + public static GdsValue toValue(@NotNull AnyValue value) { + if (value == NoValue.NO_VALUE) { + return GdsNoValue.NO_VALUE; + } + if (value.isSequenceValue()) { // ArrayValue or ListValue + return convertSequenceValueOrFail((SequenceValue) value); + } + if (value instanceof Value storableValue && storableValue.valueGroup() == ValueGroup.NUMBER) { + if (storableValue instanceof org.neo4j.values.storable.FloatValue floatValue) { + return PrimitiveValues.floatingPointValue(floatValue.floatValue()); + } else if (storableValue instanceof org.neo4j.values.storable.DoubleValue doubleValue) { + return PrimitiveValues.floatingPointValue(doubleValue.doubleValue()); + } else if (storableValue instanceof IntegralValue integralValue) { + return PrimitiveValues.longValue(integralValue.longValue()); + } + } + throw new IllegalArgumentException(formatWithLocale( + "Unsupported conversion to GDS Value from Neo4j Value with type `%s`.", + value.getTypeName() + )); + } + + private static GdsValue convertSequenceValueOrFail(SequenceValue value) { + if (value instanceof ListValue listValue) { + return convertListValueOrFail(listValue); + } else if (value instanceof ArrayValue arrayValue) { + return convertArrayValueOrFail(arrayValue); + } else { + throw failOnBadInput(value); + } + } + + @NotNull + private static Array convertListValueOrFail(ListValue listValue) { + if (listValue.isEmpty()) { + // encode as long array + return PrimitiveValues.EMPTY_LONG_ARRAY; + } + try { + return convertArrayValueOrFail(listValue.toStorableArray()); + } catch (RuntimeException e) { + throw failOnBadInput(listValue); + } + } + + @NotNull + private static Array convertArrayValueOrFail(ArrayValue array) { + if (array.valueGroup() != ValueGroup.NUMBER_ARRAY) { + throw failOnBadInput(array); + } + if (array.isEmpty()) { + return PrimitiveValues.EMPTY_LONG_ARRAY; + } + var arrayCopy = array.asObjectCopy(); + if (arrayCopy instanceof long[]) { + return PrimitiveValues.longArray((long[]) arrayCopy); + } else if (arrayCopy instanceof double[]) { + return PrimitiveValues.doubleArray((double[]) arrayCopy); + } else { + throw failOnBadInput(array); + } + } + + private static IllegalArgumentException failOnBadInput(SequenceValue badInput) { + return new IllegalArgumentException( + formatWithLocale( + "Unsupported conversion to GDS Value from Neo4j Value `%s`.", + badInput + ) + ); + } + + private GdsNeo4jValueConverter() {} +} diff --git a/core/src/main/java/org/neo4j/gds/core/loading/RelationshipImportResult.java b/core/src/main/java/org/neo4j/gds/core/loading/RelationshipImportResult.java index a4e151a157..d3e74f66e2 100644 --- a/core/src/main/java/org/neo4j/gds/core/loading/RelationshipImportResult.java +++ b/core/src/main/java/org/neo4j/gds/core/loading/RelationshipImportResult.java @@ -30,12 +30,10 @@ import org.neo4j.gds.api.PropertyState; import org.neo4j.gds.api.RelationshipProperty; import org.neo4j.gds.api.RelationshipPropertyStore; -import org.neo4j.gds.api.ValueTypes; import org.neo4j.gds.api.nodeproperties.ValueType; import org.neo4j.gds.api.schema.Direction; import org.neo4j.gds.api.schema.MutableRelationshipSchema; import org.neo4j.gds.api.schema.MutableRelationshipSchemaEntry; -import org.neo4j.values.storable.NumberType; import java.util.Collection; import java.util.HashMap; @@ -157,7 +155,7 @@ private static RelationshipPropertyStore constructRelationshipPropertyStore( ), propertyMapping.defaultValue().isUserDefined() ? propertyMapping.defaultValue() - : ValueTypes.fromNumberType(NumberType.FLOATING_POINT).fallbackValue(), + : ValueType.DOUBLE.fallbackValue(), propertyMapping.aggregation() ) ); diff --git a/core/src/main/java/org/neo4j/gds/core/loading/ValueConverter.java b/core/src/main/java/org/neo4j/gds/core/loading/ValueConverter.java deleted file mode 100644 index 26218bb3c2..0000000000 --- a/core/src/main/java/org/neo4j/gds/core/loading/ValueConverter.java +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.core.loading; - -import org.jetbrains.annotations.NotNull; -import org.neo4j.gds.api.nodeproperties.ValueType; -import org.neo4j.gds.compat.Neo4jProxy; -import org.neo4j.values.AnyValue; -import org.neo4j.values.storable.ArrayValue; -import org.neo4j.values.storable.DoubleArray; -import org.neo4j.values.storable.DoubleValue; -import org.neo4j.values.storable.FloatArray; -import org.neo4j.values.storable.FloatingPointValue; -import org.neo4j.values.storable.IntegralValue; -import org.neo4j.values.storable.LongArray; -import org.neo4j.values.storable.LongValue; -import org.neo4j.values.storable.NoValue; -import org.neo4j.values.storable.Value; -import org.neo4j.values.storable.ValueGroup; -import org.neo4j.values.storable.Values; -import org.neo4j.values.virtual.ListValue; - -import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; - -public final class ValueConverter { - - public static ValueType valueType(Value value) { - if (value instanceof IntegralValue) { - return ValueType.LONG; - } else if (value instanceof FloatingPointValue) { - return ValueType.DOUBLE; - } else if (value instanceof LongArray) { - return ValueType.LONG_ARRAY; - } else if (value instanceof DoubleArray) { - return ValueType.DOUBLE_ARRAY; - } else if (value instanceof FloatArray) { - return ValueType.FLOAT_ARRAY; - } else { - throw new UnsupportedOperationException(formatWithLocale( - "Loading of values of type %s is currently not supported", - value.getTypeName() - )); - } - } - - public static Value toValue(@NotNull AnyValue value) { - if (value == NoValue.NO_VALUE) { - return NoValue.NO_VALUE; - } else if (value.isSequenceValue()) { - return castToNumericArrayOrFail(value); - } else if (value instanceof Value) { - var storableValue = (Value) value; - if (storableValue.valueGroup() != ValueGroup.NUMBER) { - throw new IllegalArgumentException(formatWithLocale( - "Unsupported GDS node property of type `%s`.", - storableValue.getTypeName() - )); - } - return storableValue; - } else { - throw new IllegalArgumentException(formatWithLocale( - "Unsupported GDS node property of type `%s`.", - value.getTypeName() - )); - } - } - private static ArrayValue castToNumericArrayOrFail(AnyValue value) { - if (value instanceof ListValue) { - return castToNumericArrayOrFail((ListValue) value); - } else if (value instanceof ArrayValue){ - return assertNumberArray((ArrayValue) value); - } else { - throw failOnBadList(value); - } - } - - private static ArrayValue assertNumberArray(ArrayValue array) { - if (array.valueGroup() != ValueGroup.NUMBER_ARRAY) { - throw failOnBadList(array); - } - return array; - } - @NotNull - private static ArrayValue castToNumericArrayOrFail(ListValue listValue) { - if (listValue.isEmpty()) { - // encode as long array - return Values.EMPTY_LONG_ARRAY; - } - - var firstValue = listValue.head(); - try { - int size = Neo4jProxy.sequenceSizeAsInt(listValue); - if (firstValue instanceof LongValue) { - var longArray = new long[size]; - var iterator = listValue.iterator(); - for (int i = 0; i < size && iterator.hasNext(); i++) { - longArray[i] = ((LongValue) iterator.next()).longValue(); - } - return Values.longArray(longArray); - } else if (firstValue instanceof DoubleValue) { - var doubleArray = new double[size]; - var iterator = listValue.iterator(); - for (int i = 0; i < size && iterator.hasNext(); i++) { - doubleArray[i] = ((DoubleValue) iterator.next()).doubleValue(); - } - return Values.doubleArray(doubleArray); - } else { - throw failOnBadList(listValue); - } - } catch (ClassCastException c) { - throw failOnBadList(listValue); - } - } - - private static IllegalArgumentException failOnBadList(AnyValue badList) { - return new IllegalArgumentException(formatWithLocale( - "Only lists of uniformly typed numbers are supported as GDS node properties, but found an unsupported list `%s`.", - badList - )); - } - - private ValueConverter() {} -} diff --git a/core/src/main/java/org/neo4j/gds/core/loading/construction/NodesBuilder.java b/core/src/main/java/org/neo4j/gds/core/loading/construction/NodesBuilder.java index 690ef52ac4..2d78034aba 100644 --- a/core/src/main/java/org/neo4j/gds/core/loading/construction/NodesBuilder.java +++ b/core/src/main/java/org/neo4j/gds/core/loading/construction/NodesBuilder.java @@ -37,10 +37,8 @@ import org.neo4j.gds.core.loading.nodeproperties.NodePropertiesFromStoreBuilder; import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet; import org.neo4j.gds.core.utils.paged.HugeAtomicGrowingBitSet; -import org.neo4j.values.storable.Value; -import org.neo4j.values.storable.Values; +import org.neo4j.gds.values.GdsValue; -import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.concurrent.atomic.LongAdder; @@ -148,33 +146,19 @@ public void addNode(long originalId, NodeLabel nodeLabel) { this.addNode(originalId, NodeLabelTokens.ofNodeLabel(nodeLabel)); } - public void addNodeWithPropertiesAsObjects(long originalId, Map propertiesAsObjects) { - this.addNodeWithPropertiesAsObjects(originalId, propertiesAsObjects, NodeLabelTokens.empty()); - } - - public void addNodeWithPropertiesAsObjects(long originalId, Map propertiesAsObjects, NodeLabel... nodeLabels) { - this.addNodeWithPropertiesAsObjects(originalId, propertiesAsObjects, NodeLabelTokens.ofNodeLabels(nodeLabels)); - } - - public void addNodeWithPropertiesAsObjects(long originalId, Map propertiesAsObjects, NodeLabelToken nodeLabels) { - var properties = new HashMap(propertiesAsObjects.size()); - propertiesAsObjects.forEach((key, value) -> properties.put(key, Values.of(value))); - this.addNode(originalId, properties, nodeLabels); - } - - public void addNode(long originalId, Map properties) { + public void addNode(long originalId, Map properties) { this.addNode(originalId, properties, NodeLabelTokens.empty()); } - public void addNode(long originalId, Map properties, NodeLabelToken nodeLabels) { + public void addNode(long originalId, Map properties, NodeLabelToken nodeLabels) { this.addNode(originalId, nodeLabels, PropertyValues.of(properties)); } - public void addNode(long originalId, Map properties, NodeLabel... nodeLabels) { + public void addNode(long originalId, Map properties, NodeLabel... nodeLabels) { this.addNode(originalId, properties, NodeLabelTokens.ofNodeLabels(nodeLabels)); } - public void addNode(long originalId, Map properties, NodeLabel nodeLabel) { + public void addNode(long originalId, Map properties, NodeLabel nodeLabel) { this.addNode(originalId, properties, NodeLabelTokens.ofNodeLabel(nodeLabel)); } diff --git a/core/src/main/java/org/neo4j/gds/core/loading/construction/PropertyValues.java b/core/src/main/java/org/neo4j/gds/core/loading/construction/PropertyValues.java index 26bb98e65b..18046f397e 100644 --- a/core/src/main/java/org/neo4j/gds/core/loading/construction/PropertyValues.java +++ b/core/src/main/java/org/neo4j/gds/core/loading/construction/PropertyValues.java @@ -19,8 +19,8 @@ */ package org.neo4j.gds.core.loading.construction; -import org.neo4j.gds.core.loading.ValueConverter; -import org.neo4j.values.storable.Value; +import org.neo4j.gds.core.loading.GdsNeo4jValueConverter; +import org.neo4j.gds.values.GdsValue; import org.neo4j.values.virtual.MapValue; import java.util.Map; @@ -29,7 +29,7 @@ public abstract class PropertyValues { - public abstract void forEach(BiConsumer consumer); + public abstract void forEach(BiConsumer consumer); public abstract boolean isEmpty(); @@ -37,25 +37,25 @@ public abstract class PropertyValues { public abstract Iterable propertyKeys(); - public abstract Value get(String key); + public abstract GdsValue get(String key); public static PropertyValues of(MapValue mapValue) { return new CypherPropertyValues(mapValue); } - public static PropertyValues of(Map map) { + public static PropertyValues of(Map map) { return new NativePropertyValues(map); } private static final class NativePropertyValues extends PropertyValues { - private final Map properties; + private final Map properties; - private NativePropertyValues(Map properties) { + private NativePropertyValues(Map properties) { this.properties = properties; } @Override - public void forEach(BiConsumer consumer) { + public void forEach(BiConsumer consumer) { this.properties.forEach(consumer); } @@ -75,7 +75,7 @@ public Set propertyKeys() { } @Override - public Value get(String key) { + public GdsValue get(String key) { return properties.get(key); } } @@ -88,9 +88,9 @@ private CypherPropertyValues(MapValue properties) { } @Override - public void forEach(BiConsumer consumer) { + public void forEach(BiConsumer consumer) { this.properties.foreach((k, v) -> { - consumer.accept(k, ValueConverter.toValue(v)); + consumer.accept(k, GdsNeo4jValueConverter.toValue(v)); }); } @@ -110,8 +110,8 @@ public Iterable propertyKeys() { } @Override - public Value get(String key) { - return ValueConverter.toValue(properties.get(key)); + public GdsValue get(String key) { + return GdsNeo4jValueConverter.toValue(properties.get(key)); } } } diff --git a/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/DoubleArrayNodePropertiesBuilder.java b/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/DoubleArrayNodePropertiesBuilder.java index 6c051d7592..0bf8d489ab 100644 --- a/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/DoubleArrayNodePropertiesBuilder.java +++ b/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/DoubleArrayNodePropertiesBuilder.java @@ -27,7 +27,9 @@ import org.neo4j.gds.core.concurrency.Concurrency; import org.neo4j.gds.core.concurrency.DefaultPool; import org.neo4j.gds.core.concurrency.ParallelUtil; +import org.neo4j.gds.utils.GdsNeo4jValueConversion; import org.neo4j.gds.utils.Neo4jValueConversion; +import org.neo4j.gds.values.GdsValue; import org.neo4j.values.storable.Value; import java.util.Arrays; @@ -60,6 +62,11 @@ public void setValue(long neoNodeId, Value value) { set(neoNodeId, Neo4jValueConversion.getDoubleArray(value)); } + @Override + public void setValue(long neoNodeId, GdsValue value) { + set(neoNodeId, GdsNeo4jValueConversion.getDoubleArray(value)); + } + @Override public DoubleArrayNodePropertyValues build(long size, PartialIdMap idMap, long highestOriginalId) { var propertiesByNeoIds = builder.build(); diff --git a/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/DoubleNodePropertiesBuilder.java b/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/DoubleNodePropertiesBuilder.java index eae9bccce9..3817c14416 100644 --- a/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/DoubleNodePropertiesBuilder.java +++ b/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/DoubleNodePropertiesBuilder.java @@ -27,7 +27,9 @@ import org.neo4j.gds.core.concurrency.Concurrency; import org.neo4j.gds.core.concurrency.DefaultPool; import org.neo4j.gds.core.concurrency.ParallelUtil; +import org.neo4j.gds.utils.GdsNeo4jValueConversion; import org.neo4j.gds.utils.Neo4jValueConversion; +import org.neo4j.gds.values.GdsValue; import org.neo4j.values.storable.Value; import java.lang.invoke.MethodHandles; @@ -84,6 +86,12 @@ public void setValue(long neoNodeId, Value value) { set(neoNodeId, doubleValue); } + @Override + public void setValue(long neoNodeId, GdsValue value) { + double doubleValue = GdsNeo4jValueConversion.getDoubleValue(value); + set(neoNodeId, doubleValue); + } + @Override public DoubleNodePropertyValues build(long size, PartialIdMap idMap, long highestOriginalId) { var propertiesByNeoIds = builder.build(); diff --git a/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/FloatArrayNodePropertiesBuilder.java b/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/FloatArrayNodePropertiesBuilder.java index dde4f886a7..4ed9f7be4c 100644 --- a/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/FloatArrayNodePropertiesBuilder.java +++ b/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/FloatArrayNodePropertiesBuilder.java @@ -27,7 +27,9 @@ import org.neo4j.gds.core.concurrency.Concurrency; import org.neo4j.gds.core.concurrency.DefaultPool; import org.neo4j.gds.core.concurrency.ParallelUtil; +import org.neo4j.gds.utils.GdsNeo4jValueConversion; import org.neo4j.gds.utils.Neo4jValueConversion; +import org.neo4j.gds.values.GdsValue; import org.neo4j.values.storable.Value; import java.util.Arrays; @@ -58,6 +60,11 @@ public void setValue(long neoNodeId, Value value) { set(neoNodeId, Neo4jValueConversion.getFloatArray(value)); } + @Override + public void setValue(long neoNodeId, GdsValue value) { + set(neoNodeId, GdsNeo4jValueConversion.getFloatArray(value)); + } + @Override public FloatArrayNodePropertyValues build(long size, PartialIdMap idMap, long highestOriginalId) { var propertiesByNeoIds = builder.build(); diff --git a/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/InnerNodePropertiesBuilder.java b/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/InnerNodePropertiesBuilder.java index 4e61d9bd7f..cda18ba6d9 100644 --- a/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/InnerNodePropertiesBuilder.java +++ b/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/InnerNodePropertiesBuilder.java @@ -21,12 +21,15 @@ import org.neo4j.gds.api.PartialIdMap; import org.neo4j.gds.api.properties.nodes.NodePropertyValues; +import org.neo4j.gds.values.GdsValue; import org.neo4j.values.storable.Value; public interface InnerNodePropertiesBuilder { void setValue(long neoNodeId, Value value); + void setValue(long neoNodeId, GdsValue value); + /** * Builds the underlying node properties and performs a remapping * to the internal id space using the given id map. diff --git a/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/LongArrayNodePropertiesBuilder.java b/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/LongArrayNodePropertiesBuilder.java index e3595dd365..4c386fe873 100644 --- a/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/LongArrayNodePropertiesBuilder.java +++ b/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/LongArrayNodePropertiesBuilder.java @@ -27,7 +27,9 @@ import org.neo4j.gds.core.concurrency.Concurrency; import org.neo4j.gds.core.concurrency.DefaultPool; import org.neo4j.gds.core.concurrency.ParallelUtil; +import org.neo4j.gds.utils.GdsNeo4jValueConversion; import org.neo4j.gds.utils.Neo4jValueConversion; +import org.neo4j.gds.values.GdsValue; import org.neo4j.values.storable.Value; import java.util.Arrays; @@ -58,6 +60,11 @@ public void setValue(long neoNodeId, Value value) { set(neoNodeId, Neo4jValueConversion.getLongArray(value)); } + @Override + public void setValue(long neoNodeId, GdsValue value) { + set(neoNodeId, GdsNeo4jValueConversion.getLongArray(value)); + } + @Override public LongArrayNodePropertyValues build(long size, PartialIdMap idMap, long highestOriginalId) { var propertiesByNeoIds = builder.build(); diff --git a/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/LongNodePropertiesBuilder.java b/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/LongNodePropertiesBuilder.java index c85c8ad0a4..d631b44120 100644 --- a/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/LongNodePropertiesBuilder.java +++ b/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/LongNodePropertiesBuilder.java @@ -28,7 +28,9 @@ import org.neo4j.gds.core.concurrency.Concurrency; import org.neo4j.gds.core.concurrency.DefaultPool; import org.neo4j.gds.core.concurrency.ParallelUtil; +import org.neo4j.gds.utils.GdsNeo4jValueConversion; import org.neo4j.gds.utils.Neo4jValueConversion; +import org.neo4j.gds.values.GdsValue; import org.neo4j.values.storable.Value; import java.lang.invoke.MethodHandles; @@ -92,6 +94,12 @@ public void setValue(long neoNodeId, Value value) { set(neoNodeId, longValue); } + @Override + public void setValue(long neoNodeId, GdsValue value) { + var longValue = GdsNeo4jValueConversion.getLongValue(value); + set(neoNodeId, longValue); + } + @Override public NodePropertyValues build(long size, PartialIdMap idMap, long highestOriginalId) { var propertiesByNeoIds = builder.build(); diff --git a/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/NodePropertiesFromStoreBuilder.java b/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/NodePropertiesFromStoreBuilder.java index c15fa3e799..d8a3128bf6 100644 --- a/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/NodePropertiesFromStoreBuilder.java +++ b/core/src/main/java/org/neo4j/gds/core/loading/nodeproperties/NodePropertiesFromStoreBuilder.java @@ -26,9 +26,15 @@ import org.neo4j.gds.collections.hsa.HugeSparseCollections; import org.neo4j.gds.core.concurrency.Concurrency; import org.neo4j.gds.core.loading.HighLimitIdMap; -import org.neo4j.gds.core.loading.ValueConverter; import org.neo4j.gds.mem.MemoryEstimation; import org.neo4j.gds.mem.MemoryEstimations; +import org.neo4j.gds.values.GdsNoValue; +import org.neo4j.gds.values.GdsValue; +import org.neo4j.values.storable.DoubleArray; +import org.neo4j.values.storable.FloatArray; +import org.neo4j.values.storable.FloatingPointValue; +import org.neo4j.values.storable.IntegralValue; +import org.neo4j.values.storable.LongArray; import org.neo4j.values.storable.Value; import org.neo4j.values.storable.Values; @@ -83,6 +89,15 @@ public void set(long neoNodeId, Value value) { } } + public void set(long neoNodeId, GdsValue value) { + if (value != null && value != GdsNoValue.NO_VALUE) { + if (innerBuilder.get() == null) { + initializeWithType(value); + } + innerBuilder.get().setValue(neoNodeId, value); + } + } + public NodePropertyValues build(IdMap idMap) { if (innerBuilder.get() == null) { if (defaultValue.getObject() != null) { @@ -104,12 +119,20 @@ public NodePropertyValues build(IdMap idMap) { // This is synchronized as we want to prevent the creation of multiple InnerNodePropertiesBuilders of which only once survives. private synchronized void initializeWithType(Value value) { if (innerBuilder.get() == null) { - var valueType = ValueConverter.valueType(value); + var valueType = valueType(value); var newBuilder = newInnerBuilder(valueType); innerBuilder.compareAndSet(null, newBuilder); } } + // This is synchronized as we want to prevent the creation of multiple InnerNodePropertiesBuilders of which only once survives. + private synchronized void initializeWithType(GdsValue value) { + if (innerBuilder.get() == null) { + var newBuilder = newInnerBuilder(value.type()); + innerBuilder.compareAndSet(null, newBuilder); + } + } + private InnerNodePropertiesBuilder newInnerBuilder(ValueType valueType) { switch (valueType) { case LONG: @@ -129,4 +152,23 @@ private InnerNodePropertiesBuilder newInnerBuilder(ValueType valueType) { )); } } + + private ValueType valueType(Value value) { + if (value instanceof IntegralValue) { + return ValueType.LONG; + } else if (value instanceof FloatingPointValue) { + return ValueType.DOUBLE; + } else if (value instanceof LongArray) { + return ValueType.LONG_ARRAY; + } else if (value instanceof DoubleArray) { + return ValueType.DOUBLE_ARRAY; + } else if (value instanceof FloatArray) { + return ValueType.FLOAT_ARRAY; + } else { + throw new UnsupportedOperationException(formatWithLocale( + "Loading of values of type %s is currently not supported", + value.getTypeName() + )); + } + } } diff --git a/core/src/main/java/org/neo4j/gds/utils/GdsNeo4jValueConversion.java b/core/src/main/java/org/neo4j/gds/utils/GdsNeo4jValueConversion.java new file mode 100644 index 0000000000..80ff26bbaf --- /dev/null +++ b/core/src/main/java/org/neo4j/gds/utils/GdsNeo4jValueConversion.java @@ -0,0 +1,192 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.utils; + +import org.neo4j.gds.values.DoubleArray; +import org.neo4j.gds.values.FloatArray; +import org.neo4j.gds.values.FloatingPointArray; +import org.neo4j.gds.values.FloatingPointValue; +import org.neo4j.gds.values.GdsValue; +import org.neo4j.gds.values.IntegralArray; +import org.neo4j.gds.values.IntegralValue; +import org.neo4j.gds.values.LongArray; + +import java.util.Locale; +import java.util.function.IntToLongFunction; + +import static org.neo4j.gds.api.ValueConversion.exactDoubleToLong; +import static org.neo4j.gds.api.ValueConversion.exactLongToDouble; +import static org.neo4j.gds.api.ValueConversion.exactLongToFloat; +import static org.neo4j.gds.api.ValueConversion.notOverflowingDoubleToFloat; +import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; + +public final class GdsNeo4jValueConversion { + + public static long getLongValue(GdsValue value) { + if (value instanceof IntegralValue) { + return ((IntegralValue) value).longValue(); + } else if (value instanceof FloatingPointValue) { + return exactDoubleToLong(((FloatingPointValue) value).doubleValue()); + } else { + throw conversionError(value, "Long"); + } + } + + public static double getDoubleValue(GdsValue value) { + if (value instanceof FloatingPointValue) { + return ((FloatingPointValue) value).doubleValue(); + } else if (value instanceof IntegralValue) { + return exactLongToDouble(((IntegralValue) value).longValue()); + } else { + throw conversionError(value, "Double"); + } + } + + public static long[] getLongArray(GdsValue value) { + if (value instanceof LongArray) { + return ((LongArray) value).longArrayValue(); + } else if (value instanceof FloatingPointArray) { + return floatToLongArray((FloatingPointArray) value); + } + else { + throw conversionError(value, "Long Array"); + } + } + + public static double[] getDoubleArray(GdsValue value) { + if (value instanceof DoubleArray) { + return ((DoubleArray) value).doubleArrayValue(); + } else if (value instanceof FloatArray) { + return floatToDoubleArray((FloatArray) value); + } else if (value instanceof IntegralArray) { + return integralToDoubleArray((IntegralArray) value); + } else { + throw conversionError(value, "Double Array"); + } + } + + public static float[] getFloatArray(GdsValue value) { + if (value instanceof FloatArray) { + return ((FloatArray) value).floatArrayValue(); + } else if (value instanceof DoubleArray) { + return doubleToFloatArray((DoubleArray) value); + } else if (value instanceof IntegralArray) { + return longToFloatArray((IntegralArray) value); + }else { + throw conversionError(value, "Float Array"); + } + } + + private static double[] integralToDoubleArray(IntegralArray intArray) { + var result = new double[intArray.length()]; + + IntToLongFunction longValueProvider = resolvelongValueProvider(intArray); + + try { + for (int idx = 0; idx < intArray.length(); idx++) { + result[idx] = exactLongToDouble(longValueProvider.applyAsLong(idx)); + } + } catch (UnsupportedOperationException e) { + throw conversionError(intArray, "Double Array", e.getMessage()); + } + + return result; + } + + private static double[] floatToDoubleArray(FloatArray floatArray) { + var result = new double[floatArray.length()]; + + for (int idx = 0; idx < floatArray.length(); idx++) { + result[idx] = floatArray.doubleValue(idx); + } + + return result; + } + + private static float[] doubleToFloatArray(DoubleArray doubleArray) { + var result = new float[doubleArray.length()]; + + try { + for (int idx = 0; idx < doubleArray.length(); idx++) { + result[idx] = notOverflowingDoubleToFloat(doubleArray.doubleValue(idx)); + } + } catch (UnsupportedOperationException e) { + throw conversionError(doubleArray, "Float Array", e.getMessage()); + } + + return result; + } + + private static float[] longToFloatArray(IntegralArray integralArray) { + var result = new float[integralArray.length()]; + + IntToLongFunction longValueProvider = resolvelongValueProvider(integralArray); + + try { + for (int idx = 0; idx < integralArray.length(); idx++) { + result[idx] = exactLongToFloat(longValueProvider.applyAsLong(idx)); + } + } catch (UnsupportedOperationException e) { + throw conversionError(integralArray, "Float Array", e.getMessage()); + } + + return result; + } + + private static IntToLongFunction resolvelongValueProvider(IntegralArray integralArray) { + if (integralArray instanceof LongArray) { + return ((LongArray) integralArray)::longValue; + } + + throw new IllegalStateException(String.format( + Locale.US, + "Did not expect array of type %s.", integralArray.getClass().getSimpleName() + )); + } + + private static long[] floatToLongArray(FloatingPointArray floatArray) { + var result = new long[floatArray.length()]; + + try { + for (int idx = 0; idx < floatArray.length(); idx++) { + result[idx] = exactDoubleToLong(floatArray.doubleValue(idx)); + } + } catch (UnsupportedOperationException e) { + throw conversionError(floatArray, "Long Array", e.getMessage()); + } + + return result; + } + + private static UnsupportedOperationException conversionError(GdsValue value, String expected) { + return conversionError(value, expected, ""); + } + + private static UnsupportedOperationException conversionError(GdsValue value, String expected, String context) { + return new UnsupportedOperationException(formatWithLocale( + "Cannot safely convert %s into a %s. %s", + value, + expected, + context + )); + } + + private GdsNeo4jValueConversion() {} +} diff --git a/core/src/test/java/org/neo4j/gds/core/loading/NodePropertiesFromStoreBuilderTest.java b/core/src/test/java/org/neo4j/gds/core/loading/NodePropertiesFromStoreBuilderTest.java index ea4359ce8d..82413f6982 100644 --- a/core/src/test/java/org/neo4j/gds/core/loading/NodePropertiesFromStoreBuilderTest.java +++ b/core/src/test/java/org/neo4j/gds/core/loading/NodePropertiesFromStoreBuilderTest.java @@ -27,11 +27,12 @@ import org.neo4j.gds.TestSupport; import org.neo4j.gds.api.DefaultValue; import org.neo4j.gds.api.nodeproperties.ValueType; +import org.neo4j.gds.api.properties.nodes.LongArrayNodePropertyValues; import org.neo4j.gds.api.properties.nodes.NodePropertyValues; import org.neo4j.gds.core.concurrency.Concurrency; import org.neo4j.gds.core.loading.nodeproperties.NodePropertiesFromStoreBuilder; -import org.neo4j.values.storable.Value; -import org.neo4j.values.storable.Values; +import org.neo4j.gds.values.GdsValue; +import org.neo4j.gds.values.primitive.PrimitiveValues; import java.util.OptionalDouble; import java.util.OptionalLong; @@ -43,11 +44,9 @@ import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.params.provider.Arguments.arguments; import static org.neo4j.gds.TestSupport.idMap; import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; @@ -81,7 +80,7 @@ void testEmptyLongProperties() { @Test void returnsValuesThatHaveBeenSet() { - var properties = createNodeProperties(2L, 42.0, b -> b.set(1, Values.of(1.0))); + var properties = createNodeProperties(2L, 42.0, b -> b.set(1, PrimitiveValues.create(1.0))); assertEquals(1.0, properties.doubleValue(1)); assertEquals(42.0, properties.doubleValue(0)); @@ -94,7 +93,7 @@ void shouldReturnLongArrays() { NodePropertyValues properties = createNodeProperties( 2L, defaultValue, - b -> b.set(1, Values.of(data)) + b -> b.set(1, PrimitiveValues.create(data)) ); assertArrayEquals(data, properties.longArrayValue(1)); @@ -108,7 +107,7 @@ void shouldReturnDoubleArrays() { NodePropertyValues properties = createNodeProperties( 2L, defaultValue, - b -> b.set(1, Values.of(data)) + b -> b.set(1, PrimitiveValues.create(data)) ); assertArrayEquals(data, properties.doubleArrayValue(1)); @@ -122,7 +121,7 @@ void shouldReturnFloatArrays() { NodePropertyValues properties = createNodeProperties( 2L, defaultValue, - b -> b.set(1, Values.of(data)) + b -> b.set(1, PrimitiveValues.create(data)) ); assertArrayEquals(data, properties.floatArrayValue(1)); @@ -137,7 +136,7 @@ void shouldCastFromFloatArrayToDoubleArray() { NodePropertyValues properties = createNodeProperties( 2L, defaultValue, - b -> b.set(1, Values.of(floatData)) + b -> b.set(1, PrimitiveValues.create(floatData)) ); assertArrayEquals(floatData, properties.floatArrayValue(1)); @@ -156,7 +155,7 @@ void shouldCastFromDoubleArrayToFloatArray() { NodePropertyValues properties = createNodeProperties( 2L, defaultValue, - b -> b.set(1, Values.of(doubleData)) + b -> b.set(1, PrimitiveValues.create(doubleData)) ); assertArrayEquals(doubleData, properties.doubleArrayValue(1)); @@ -166,11 +165,11 @@ void shouldCastFromDoubleArrayToFloatArray() { @Test void dimensions() { - var longs = createNodeProperties(2, -6L, b -> b.set(1, Values.of(69L))); - var doubles = createNodeProperties(2, 420D, b -> b.set(1, Values.of(13.37D))); - var floatArray = createNodeProperties(2, new float[2], b -> b.set(1, Values.of(new float[]{42.2F, 1337.1F}))); - var doubleArray = createNodeProperties(2, new double[3], b -> b.set(1, Values.of(new double[]{1D, 1D, 0D}))); - var longArray = createNodeProperties(2, new long[0], b -> b.set(1, Values.of(new long[0]))); + var longs = createNodeProperties(2, -6L, b -> b.set(1, PrimitiveValues.create(69L))); + var doubles = createNodeProperties(2, 420D, b -> b.set(1, PrimitiveValues.create(13.37D))); + var floatArray = createNodeProperties(2, new float[2], b -> b.set(1, PrimitiveValues.create(new float[]{42.2F, 1337.1F}))); + var doubleArray = createNodeProperties(2, new double[3], b -> b.set(1, PrimitiveValues.create(new double[]{1D, 1D, 0D}))); + var longArray = createNodeProperties(2, new long[0], b -> b.set(1, PrimitiveValues.create(new long[0]))); assertThat(longs.dimension()).contains(1); assertThat(doubles.dimension()).contains(1); @@ -181,36 +180,43 @@ void dimensions() { @Test void dimensionsWithNulls() { - var floatArray = createNodeProperties(3, null, b -> b.set(1, Values.of(new float[]{42.2F, 1337.1F}))); - var doubleArray = createNodeProperties(3, new double[3], b -> b.set(1, Values.of(null))); - var longArray = createNodeProperties(3, new long[0], b -> b.set(0, Values.of(null))); + var floatArray = createNodeProperties(3, null, b -> b.set(1, PrimitiveValues.create(new float[]{42.2F, 1337.1F}))); + var doubleArray = createNodeProperties(3, new double[3], b -> b.set(1, PrimitiveValues.create(null))); + var longArray = createNodeProperties(3, new long[0], b -> b.set(0, PrimitiveValues.create(null))); assertThat(floatArray.dimension()).isEmpty(); assertThat(doubleArray.dimension()).contains(3); assertThat(longArray.dimension()).contains(0); } - static Stream unsupportedValues() { - return Stream.of( - arguments(Values.stringValue("42L")), - arguments(Values.shortArray(new short[]{(short) 42})), - arguments(Values.byteArray(new byte[]{(byte) 42})), - arguments(Values.booleanValue(true)), - arguments(Values.charValue('c')) - ); + @Test + void shouldSupportByteArray() { + var data = PrimitiveValues.byteArray(new byte[]{(byte) 42}); + var nodeProperties = createNodeProperties(2L, DefaultValue.forLongArray(), b -> b.set(1, data)); + assertThat(nodeProperties).isInstanceOf(LongArrayNodePropertyValues.class); + assertThat(nodeProperties.longArrayValue(0)).isEqualTo(DefaultValue.forLongArray().longArrayValue()); + assertThat(nodeProperties.longArrayValue(1)).isEqualTo(data.longArrayValue()); } - @ParameterizedTest - @MethodSource("org.neo4j.gds.core.loading.NodePropertiesFromStoreBuilderTest#unsupportedValues") - void shouldFailOnUnSupportedTypes(Value data) { - assertThatThrownBy(() -> createNodeProperties( - 2L, - null, - b -> b.set(1, data) - )).isInstanceOf(UnsupportedOperationException.class) - .hasMessageContaining("Loading of values of type"); + @Test + void shouldSupportShortArray() { + var data = PrimitiveValues.shortArray(new short[]{(short) 42}); + var nodeProperties = createNodeProperties(2L, DefaultValue.forLongArray(), b -> b.set(1, data)); + assertThat(nodeProperties).isInstanceOf(LongArrayNodePropertyValues.class); + assertThat(nodeProperties.longArrayValue(0)).isEqualTo(DefaultValue.forLongArray().longArrayValue()); + assertThat(nodeProperties.longArrayValue(1)).isEqualTo(data.longArrayValue()); } + @Test + void shouldSupportIntArray() { + var data = PrimitiveValues.intArray(new int[]{(int) 42}); + var nodeProperties = createNodeProperties(2L, DefaultValue.forLongArray(), b -> b.set(1, data)); + assertThat(nodeProperties).isInstanceOf(LongArrayNodePropertyValues.class); + assertThat(nodeProperties.longArrayValue(0)).isEqualTo(DefaultValue.forLongArray().longArrayValue()); + assertThat(nodeProperties.longArrayValue(1)).isEqualTo(data.longArrayValue()); + } + + private static Stream invalidValueTypeCombinations() { Supplier> scalarValues = () -> Stream.of(2L, 2D).map(Arguments::of); Supplier> arrayValues = () -> Stream.of(new double[]{1D}, new long[]{1L}).map(Arguments::of); @@ -225,7 +231,7 @@ private static Stream invalidValueTypeCombinations() { @MethodSource("invalidValueTypeCombinations") void failOnInvalidDefaultType(Object defaultValue, Object propertyValue) { Assertions.assertThatThrownBy(() -> createNodeProperties(1L, defaultValue, b -> { - b.set(0, Values.of(propertyValue)); + b.set(0, PrimitiveValues.create(propertyValue)); })).isInstanceOf(IllegalArgumentException.class) .hasMessageContaining(formatWithLocale("Expected type of default value to be `%s`.", propertyValue.getClass().getSimpleName())); } @@ -240,7 +246,7 @@ void returnsDefaultOnMissingEntries() { @Test void returnNaNIfItWasSet() { - var properties = createNodeProperties(2L, 42.0, b -> b.set(1, Values.of(Double.NaN))); + var properties = createNodeProperties(2L, 42.0, b -> b.set(1, PrimitiveValues.create(Double.NaN))); assertEquals(42.0, properties.doubleValue(0)); assertEquals(Double.NaN, properties.doubleValue(1)); @@ -249,8 +255,8 @@ void returnNaNIfItWasSet() { @Test void trackMaxValue() { var properties = createNodeProperties(2L, 0.0, b -> { - b.set(0, Values.of(42)); - b.set(1, Values.of(21)); + b.set(0, PrimitiveValues.create(42)); + b.set(1, PrimitiveValues.create(21)); }); var maxPropertyValue = properties.getMaxLongPropertyValue(); assertTrue(maxPropertyValue.isPresent()); @@ -260,8 +266,8 @@ void trackMaxValue() { @Test void hasSize() { var properties = createNodeProperties(2L, 0.0, b -> { - b.set(0, Values.of(42.0)); - b.set(1, Values.of(21.0)); + b.set(0, PrimitiveValues.create(42.0)); + b.set(1, PrimitiveValues.create(21.0)); }); assertEquals(2, properties.nodeCount()); } @@ -275,8 +281,9 @@ void shouldHandleNullValues() { new Concurrency(1) ); - builder.set(0, null); - builder.set(1, Values.longValue(42L)); + GdsValue value = null; + builder.set(0, value); + builder.set(1, PrimitiveValues.longValue(42L)); var properties = builder.build(idMap(nodeCount)); @@ -303,7 +310,7 @@ void threadSafety() throws InterruptedException { // that value, while the other thread will write 2^42 in the meantime. If that happens, // this thread would overwrite a new maxValue. for (int i = 0; i < nodeSize; i += 2) { - builder.set(i, Values.of(i == 1338 ? 0x1p41 : 2.0)); + builder.set(i, PrimitiveValues.create(i == 1338 ? 0x1p41 : 2.0)); } }); pool.execute(() -> { @@ -312,7 +319,7 @@ void threadSafety() throws InterruptedException { // second task, sets the value 1 on every other node, except for 1337 which is set to 2^42 // Depending on thread scheduling, the write for 2^42 might be overwritten by the first thread for (int i = 1; i < nodeSize; i += 2) { - builder.set(i, Values.of(i == 1337 ? 0x1p42 : 1.0)); + builder.set(i, PrimitiveValues.create(i == 1337 ? 0x1p42 : 1.0)); } }); diff --git a/core/src/test/java/org/neo4j/gds/utils/GdsNeo4jValueConversionTest.java b/core/src/test/java/org/neo4j/gds/utils/GdsNeo4jValueConversionTest.java new file mode 100644 index 0000000000..9163446101 --- /dev/null +++ b/core/src/test/java/org/neo4j/gds/utils/GdsNeo4jValueConversionTest.java @@ -0,0 +1,214 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.utils; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.neo4j.gds.api.ValueConversion; +import org.neo4j.gds.values.FloatingPointValue; +import org.neo4j.gds.values.GdsValue; +import org.neo4j.gds.values.IntegralValue; +import org.neo4j.gds.values.primitive.PrimitiveValues; + +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +class GdsNeo4jValueConversionTest { + + @ParameterizedTest + @MethodSource("org.neo4j.gds.utils.GdsNeo4jValueConversionTest#longConversion") + void testGettingALong(GdsValue value, Long expected) { + if (expected != null) { + Assertions.assertEquals(expected, GdsNeo4jValueConversion.getLongValue(value)); + } else { + assertThrows(UnsupportedOperationException.class, () -> GdsNeo4jValueConversion.getLongValue(value)); + } + } + + static Stream longConversion() { + return Stream.of( + arguments(PrimitiveValues.longValue(42L), 42L), + arguments(PrimitiveValues.longValue(42), 42L), + arguments(PrimitiveValues.longValue((short) 42), 42L), + arguments(PrimitiveValues.longValue((byte) 42), 42L), + arguments(PrimitiveValues.floatingPointValue(42.0F), 42L), + arguments(PrimitiveValues.floatingPointValue(42.0), 42L), + + arguments(PrimitiveValues.longArray(new long[]{42L}), null), + arguments(PrimitiveValues.floatingPointValue(42.12F), null), + arguments(PrimitiveValues.floatingPointValue(42.12), null) + ); + } + + @Test + void shouldConvertEmptyLongArrayToDoubleArray() { + assertThat(GdsNeo4jValueConversion.getDoubleArray(PrimitiveValues.longArray(new long[0]))).isEqualTo(new double[0]); + } + + @ParameterizedTest + @MethodSource("org.neo4j.gds.utils.GdsNeo4jValueConversionTest#doubleConversion") + void testGettingADouble(GdsValue value, Double expected) { + if (expected != null) { + assertEquals(expected, GdsNeo4jValueConversion.getDoubleValue(value), 0.1); + } else { + assertThrows(UnsupportedOperationException.class, () -> GdsNeo4jValueConversion.getDoubleValue(value)); + } + } + + static Stream doubleConversion() { + return Stream.of( + arguments(PrimitiveValues.floatingPointValue(42.1), 42.1D), + arguments(PrimitiveValues.floatingPointValue(42.1F), 42.1D), + arguments(PrimitiveValues.longValue(42L), 42.0D), + arguments(PrimitiveValues.longValue(42), 42.0D), + arguments(PrimitiveValues.longValue((short) 42), 42.0D), + arguments(PrimitiveValues.longValue((byte) 42), 42.0D), + + arguments(PrimitiveValues.doubleArray(new double[]{42.0}), null), + arguments(PrimitiveValues.longValue(1L << 54 + 1), null) + ); + } + + static float getFloatValue(GdsValue value) { + if (value instanceof FloatingPointValue) { + var doubleValue = ((FloatingPointValue) value).doubleValue(); + return ValueConversion.notOverflowingDoubleToFloat(doubleValue); + } else if (value instanceof IntegralValue) { + return ValueConversion.exactLongToFloat(((IntegralValue) value).longValue()); + } else { + throw new UnsupportedOperationException("Failed to convert to float"); + } + } + + @ParameterizedTest + @MethodSource("org.neo4j.gds.utils.GdsNeo4jValueConversionTest#floatConversion") + void testGettingAFloat(GdsValue value, Float expected) { + if (expected != null) { + assertEquals(expected, getFloatValue(value), 0.1); + } else { + assertThrows(UnsupportedOperationException.class, () -> getFloatValue(value)); + } + } + + static Stream floatConversion() { + return Stream.of( + arguments(PrimitiveValues.floatingPointValue(42.1F), 42.1F), + arguments(PrimitiveValues.floatingPointValue(42.1D), 42.1F), + arguments(PrimitiveValues.longValue(42L), 42.0F), + arguments(PrimitiveValues.longValue(42), 42.0F), + arguments(PrimitiveValues.longValue((short) 42), 42.0F), + arguments(PrimitiveValues.longValue((byte) 42), 42.0F), + + arguments(PrimitiveValues.doubleArray(new double[]{42.0}), null), + arguments(PrimitiveValues.longArray(new long[]{42}), null), + arguments(PrimitiveValues.floatArray(new float[]{42.0F}), null), + arguments(PrimitiveValues.longValue(Long.MAX_VALUE), null), + arguments(PrimitiveValues.longValue(Long.MIN_VALUE), null), + arguments(PrimitiveValues.floatingPointValue(Float.MAX_VALUE * 2.0D), null), + arguments(PrimitiveValues.floatingPointValue(-Float.MAX_VALUE * 2.0D), null) + ); + } + + @ParameterizedTest + @MethodSource("org.neo4j.gds.utils.GdsNeo4jValueConversionTest#longArrayConversion") + void testGettingALongArray(GdsValue value, long[] expected) { + if (expected != null) { + assertArrayEquals(expected, GdsNeo4jValueConversion.getLongArray(value)); + } else { + assertThrows(UnsupportedOperationException.class, () -> GdsNeo4jValueConversion.getLongArray(value)); + } + } + + static Stream longArrayConversion() { + return Stream.of( + arguments(PrimitiveValues.longArray(new long[]{42L}), new long[]{42L}), + arguments(PrimitiveValues.intArray(new int[]{42}), new long[]{42L}), + arguments(PrimitiveValues.floatArray(new float[]{42.0F}), new long[]{42L}), + arguments(PrimitiveValues.floatArray(new float[]{42.42F}), null), + arguments(PrimitiveValues.doubleArray(new double[]{42.0}), new long[]{42L}), + arguments(PrimitiveValues.doubleArray(new double[]{42.42d}), null) + ); + } + + @Test + void testLongArrayConversionErrorShowsCause() { + var input = PrimitiveValues.floatArray(new float[] {42.0F, 13.37F, 256.0F}); + assertThatThrownBy(() -> GdsNeo4jValueConversion.getLongArray(input)) + .hasMessage("Cannot safely convert FloatArray[42.0, 13.37, 256.0] into a Long Array." + + " Cannot safely convert 13.37 into an long value"); + } + + @ParameterizedTest + @MethodSource("org.neo4j.gds.utils.GdsNeo4jValueConversionTest#doubleArrayConversion") + void testGettingADoubleArray(GdsValue value, double[] expected) { + if (expected != null) { + assertArrayEquals(expected, GdsNeo4jValueConversion.getDoubleArray(value), 0.1); + } else { + assertThrows(UnsupportedOperationException.class, () -> GdsNeo4jValueConversion.getDoubleArray(value)); + } + } + + static Stream doubleArrayConversion() { + return Stream.of( + arguments(PrimitiveValues.doubleArray(new double[]{42.0}), new double[]{42.0}), + arguments(PrimitiveValues.floatArray(new float[]{42.0F}), new double[]{42.0}), + arguments(PrimitiveValues.longArray(new long[]{42}), new double[]{42}), + arguments(PrimitiveValues.intArray(new int[]{42}), new double[]{42}), + arguments(PrimitiveValues.longArray(new long[]{9007199254740993L}), null) + ); + } + + @Test + void testDoubleArrayConversionErrorShowsCause() { + var input = PrimitiveValues.longArray(new long[] {42, 9007199254740993L, -100}); + assertThatThrownBy(() -> GdsNeo4jValueConversion.getDoubleArray(input)) + .hasMessage("Cannot safely convert LongArray[42, 9007199254740993, -100] into a Double Array." + + " Cannot safely convert 9007199254740993 into an double value"); + } + + @ParameterizedTest + @MethodSource("org.neo4j.gds.utils.GdsNeo4jValueConversionTest#floatArrayConversion") + void testGettingAFloatArray(GdsValue value, float[] expected) { + if (expected != null) { + assertArrayEquals(expected, GdsNeo4jValueConversion.getFloatArray(value), 0.1f); + } else { + assertThrows(UnsupportedOperationException.class, () -> GdsNeo4jValueConversion.getFloatArray(value)); + } + } + + static Stream floatArrayConversion() { + return Stream.of( + arguments(PrimitiveValues.floatArray(new float[]{42.0f}), new float[]{42.0f}), + arguments(PrimitiveValues.doubleArray(new double[]{42.0}), new float[]{42.0f}), + arguments(PrimitiveValues.doubleArray(new double[]{Double.MAX_VALUE}), null), + arguments(PrimitiveValues.longArray(new long[]{42}), new float[]{42.0f}), + arguments(PrimitiveValues.longArray(new long[]{9007199254740993L}), null) + ); + } +} diff --git a/gds-values/build.gradle b/gds-values/build.gradle new file mode 100644 index 0000000000..061ee89a77 --- /dev/null +++ b/gds-values/build.gradle @@ -0,0 +1,13 @@ +apply plugin: 'java-library' + +description = 'Neo4j Graph Data Science :: GDS Values' +group = 'org.neo4j.gds' + +dependencies { + api project(':graph-projection-api') + compileOnly openGds.jetbrains.annotations + testImplementation platform(openGds.junit5bom) + testImplementation openGds.junit.pioneer + testImplementation openGds.junit5.jupiter.engine + testImplementation openGds.assertj.core +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/Array.java b/gds-values/src/main/java/org/neo4j/gds/values/Array.java new file mode 100644 index 0000000000..d303c39b34 --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/Array.java @@ -0,0 +1,31 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values; + +public interface Array extends GdsValue { + int length(); + boolean equals(byte[] other); + boolean equals(short[] other); + boolean equals(int[] other); + boolean equals(long[] other); + boolean equals(float[] other); + boolean equals(double[] other); + String toString(); +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/ArrayEquals.java b/gds-values/src/main/java/org/neo4j/gds/values/ArrayEquals.java new file mode 100644 index 0000000000..b024cc92ae --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/ArrayEquals.java @@ -0,0 +1,341 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values; + +import java.util.Arrays; + +/** + * Static methods for checking the equality of arrays of primitives. + * + * This class handles only evaluation of a[] == b[] where type( a ) != type( b ), ei. byte[] == int[] and such. + * byte[] == byte[] evaluation can be done using Arrays.equals(). + */ +public final class ArrayEquals { + private ArrayEquals() {} + + // TYPED COMPARISON + + public static boolean byteAndShort(byte[] a, short[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if (a[i] != b[i]) { + return false; + } + } + return true; + } + } + + public static boolean byteAndInt(byte[] a, int[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if (a[i] != b[i]) { + return false; + } + } + + return true; + } + } + + public static boolean byteAndLong(byte[] a, long[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if ((long)a[i] != b[i]) { + return false; + } + } + + return true; + } + } + + public static boolean byteAndFloat(byte[] a, float[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if ((float)a[i] != b[i]) { + return false; + } + } + + return true; + } + } + + public static boolean byteAndDouble(byte[] a, double[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if ((double)a[i] != b[i]) { + return false; + } + } + + return true; + } + } + + public static boolean shortAndInt(short[] a, int[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if (a[i] != b[i]) { + return false; + } + } + + return true; + } + } + + public static boolean shortAndLong(short[] a, long[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if ((long)a[i] != b[i]) { + return false; + } + } + + return true; + } + } + + public static boolean shortAndFloat(short[] a, float[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if ((float)a[i] != b[i]) { + return false; + } + } + + return true; + } + } + + public static boolean shortAndDouble(short[] a, double[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if ((double)a[i] != b[i]) { + return false; + } + } + + return true; + } + } + + public static boolean intAndLong(int[] a, long[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if ((long)a[i] != b[i]) { + return false; + } + } + + return true; + } + } + + public static boolean intAndFloat(int[] a, float[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if ((float)a[i] != b[i]) { + return false; + } + } + + return true; + } + } + + public static boolean intAndDouble(int[] a, double[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if ((double)a[i] != b[i]) { + return false; + } + } + + return true; + } + } + + public static boolean longAndFloat(long[] a, float[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if ((float)a[i] != b[i]) { + return false; + } + } + + return true; + } + } + + public static boolean longAndDouble(long[] a, double[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if ((double)a[i] != b[i]) { + return false; + } + } + + return true; + } + } + + public static boolean floatAndDouble(float[] a, double[] b) { + if (a.length != b.length) { + return false; + } else { + for(int i = 0; i < a.length; ++i) { + if ((double)a[i] != b[i]) { + return false; + } + } + + return true; + } + } + + // NON-TYPED COMPARISON + + public static boolean byteAndObject(byte[] a, Object b) { + if (b instanceof byte[]) { + return Arrays.equals(a, (byte[])b); + } else if (b instanceof short[]) { + return byteAndShort(a, (short[])b); + } else if (b instanceof int[]) { + return byteAndInt(a, (int[])b); + } else if (b instanceof long[]) { + return byteAndLong(a, (long[])b); + } else if (b instanceof float[]) { + return byteAndFloat(a, (float[])b); + } else { + return b instanceof double[] ? byteAndDouble(a, (double[])b) : false; + } + } + + public static boolean shortAndObject(short[] a, Object b) { + if (b instanceof byte[]) { + return byteAndShort((byte[])b, a); + } else if (b instanceof short[]) { + return Arrays.equals(a, (short[])b); + } else if (b instanceof int[]) { + return shortAndInt(a, (int[])b); + } else if (b instanceof long[]) { + return shortAndLong(a, (long[])b); + } else if (b instanceof float[]) { + return shortAndFloat(a, (float[])b); + } else { + return b instanceof double[] ? shortAndDouble(a, (double[])b) : false; + } + } + + public static boolean intAndObject(int[] a, Object b) { + if (b instanceof byte[]) { + return byteAndInt((byte[])b, a); + } else if (b instanceof short[]) { + return shortAndInt((short[])b, a); + } else if (b instanceof int[]) { + return Arrays.equals(a, (int[])b); + } else if (b instanceof long[]) { + return intAndLong(a, (long[])b); + } else if (b instanceof float[]) { + return intAndFloat(a, (float[])b); + } else { + return b instanceof double[] ? intAndDouble(a, (double[])b) : false; + } + } + + public static boolean longAndObject(long[] a, Object b) { + if (b instanceof byte[]) { + return byteAndLong((byte[])b, a); + } else if (b instanceof short[]) { + return shortAndLong((short[])b, a); + } else if (b instanceof int[]) { + return intAndLong((int[])b, a); + } else if (b instanceof long[]) { + return Arrays.equals(a, (long[])b); + } else if (b instanceof float[]) { + return longAndFloat(a, (float[])b); + } else { + return b instanceof double[] ? longAndDouble(a, (double[])b) : false; + } + } + + public static boolean floatAndObject(float[] a, Object b) { + if (b instanceof byte[]) { + return byteAndFloat((byte[])b, a); + } else if (b instanceof short[]) { + return shortAndFloat((short[])b, a); + } else if (b instanceof int[]) { + return intAndFloat((int[])b, a); + } else if (b instanceof long[]) { + return longAndFloat((long[])b, a); + } else if (b instanceof float[]) { + return Arrays.equals(a, (float[])b); + } else { + return b instanceof double[] ? floatAndDouble(a, (double[])b) : false; + } + } + + public static boolean doubleAndObject(double[] a, Object b) { + if (b instanceof byte[]) { + return byteAndDouble((byte[])b, a); + } else if (b instanceof short[]) { + return shortAndDouble((short[])b, a); + } else if (b instanceof int[]) { + return intAndDouble((int[])b, a); + } else if (b instanceof long[]) { + return longAndDouble((long[])b, a); + } else if (b instanceof float[]) { + return floatAndDouble((float[])b, a); + } else { + return b instanceof double[] ? Arrays.equals(a, (double[])b) : false; + } + } +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/DoubleArray.java b/gds-values/src/main/java/org/neo4j/gds/values/DoubleArray.java new file mode 100644 index 0000000000..f065b0f424 --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/DoubleArray.java @@ -0,0 +1,23 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values; + +public interface DoubleArray extends FloatingPointArray { +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/FloatArray.java b/gds-values/src/main/java/org/neo4j/gds/values/FloatArray.java new file mode 100644 index 0000000000..3e7ab8c0c7 --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/FloatArray.java @@ -0,0 +1,24 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values; + +public interface FloatArray extends FloatingPointArray { + float[] floatArrayValue(); +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/FloatingPointArray.java b/gds-values/src/main/java/org/neo4j/gds/values/FloatingPointArray.java new file mode 100644 index 0000000000..daa2bc3611 --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/FloatingPointArray.java @@ -0,0 +1,25 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values; + +public interface FloatingPointArray extends Array { + double doubleValue(int idx); + double[] doubleArrayValue(); +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/FloatingPointValue.java b/gds-values/src/main/java/org/neo4j/gds/values/FloatingPointValue.java new file mode 100644 index 0000000000..ef972e75ec --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/FloatingPointValue.java @@ -0,0 +1,24 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values; + +public interface FloatingPointValue extends GdsValue { + double doubleValue(); +} diff --git a/core/src/main/java/org/neo4j/gds/api/ValueTypes.java b/gds-values/src/main/java/org/neo4j/gds/values/GdsNoValue.java similarity index 59% rename from core/src/main/java/org/neo4j/gds/api/ValueTypes.java rename to gds-values/src/main/java/org/neo4j/gds/values/GdsNoValue.java index 21ffc8912b..0810cb7755 100644 --- a/core/src/main/java/org/neo4j/gds/api/ValueTypes.java +++ b/gds-values/src/main/java/org/neo4j/gds/values/GdsNoValue.java @@ -17,25 +17,22 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ -package org.neo4j.gds.api; +package org.neo4j.gds.values; import org.neo4j.gds.api.nodeproperties.ValueType; -import org.neo4j.values.storable.NumberType; -public final class ValueTypes { +import static org.neo4j.gds.api.nodeproperties.ValueType.UNKNOWN; - public static ValueType fromNumberType(NumberType nt) { - switch (nt) { - case FLOATING_POINT: - return ValueType.DOUBLE; - case INTEGRAL: - return ValueType.LONG; - case NO_NUMBER: - return ValueType.UNKNOWN; - default: - throw new IllegalArgumentException("Unexpected value: " + nt + " (sad java 😞)"); - } +public class GdsNoValue implements GdsValue { + public static final GdsNoValue NO_VALUE = new GdsNoValue(); + + @Override + public ValueType type() { + return UNKNOWN; } - private ValueTypes() {} + @Override + public Object asObject() { + return null; + } } diff --git a/gds-values/src/main/java/org/neo4j/gds/values/GdsValue.java b/gds-values/src/main/java/org/neo4j/gds/values/GdsValue.java new file mode 100644 index 0000000000..d3b040a561 --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/GdsValue.java @@ -0,0 +1,27 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values; + +import org.neo4j.gds.api.nodeproperties.ValueType; + +public interface GdsValue { + ValueType type(); + Object asObject(); +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/IntegralArray.java b/gds-values/src/main/java/org/neo4j/gds/values/IntegralArray.java new file mode 100644 index 0000000000..9279a86d2b --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/IntegralArray.java @@ -0,0 +1,23 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values; + +public interface IntegralArray extends Array { +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/IntegralValue.java b/gds-values/src/main/java/org/neo4j/gds/values/IntegralValue.java new file mode 100644 index 0000000000..23a9890a42 --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/IntegralValue.java @@ -0,0 +1,24 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values; + +public interface IntegralValue extends GdsValue { + long longValue(); +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/LongArray.java b/gds-values/src/main/java/org/neo4j/gds/values/LongArray.java new file mode 100644 index 0000000000..f43d7645d1 --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/LongArray.java @@ -0,0 +1,25 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values; + +public interface LongArray extends IntegralArray { + long[] longArrayValue(); + long longValue(int idx); +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/primitive/ByteLongArrayImpl.java b/gds-values/src/main/java/org/neo4j/gds/values/primitive/ByteLongArrayImpl.java new file mode 100644 index 0000000000..6e5da36c37 --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/primitive/ByteLongArrayImpl.java @@ -0,0 +1,117 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.neo4j.gds.api.nodeproperties.ValueType; +import org.neo4j.gds.values.ArrayEquals; +import org.neo4j.gds.values.GdsValue; +import org.neo4j.gds.values.LongArray; + +import java.util.Arrays; + +public class ByteLongArrayImpl implements LongArray { + + private final byte[] value; + + public ByteLongArrayImpl(byte[] value) { + this.value = value; + } + + @Override + public long[] longArrayValue() { + var copy = new long[value.length]; + for (int i = 0; i < value.length; i++) { + copy[i] = value[i]; + } + return copy; + } + + @Override + public long longValue(int idx) { + return value[idx]; + } + + @Override + public int length() { + return value.length; + } + + @Override + public ValueType type() { + return ValueType.LONG_ARRAY; + } + + @Override + public byte[] asObject() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o instanceof ByteLongArrayImpl) { + return equals(((ByteLongArrayImpl) o).value); + } + if (o instanceof GdsValue) { + return ArrayEquals.byteAndObject(value, ((GdsValue) o).asObject()); + } + return false; + } + + @Override + public boolean equals(byte[] o) { + return Arrays.equals(value, o); + } + + @Override + public boolean equals(short[] o) { + return ArrayEquals.byteAndShort(value, o); + } + + @Override + public boolean equals(int[] o) { + return ArrayEquals.byteAndInt(value, o); + } + + @Override + public boolean equals(long[] other) { + return ArrayEquals.byteAndLong(value, other); + } + + @Override + public boolean equals(float[] o) { + return ArrayEquals.byteAndFloat(value, o); + } + + @Override + public boolean equals(double[] o) { + return ArrayEquals.byteAndDouble(value, o); + } + + @Override + public int hashCode() { + return Arrays.hashCode(value); + } + + @Override + public String toString() { + return "LongArray" + Arrays.toString(value); + } +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/primitive/DoubleArrayImpl.java b/gds-values/src/main/java/org/neo4j/gds/values/primitive/DoubleArrayImpl.java new file mode 100644 index 0000000000..a6e859edee --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/primitive/DoubleArrayImpl.java @@ -0,0 +1,112 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.neo4j.gds.api.nodeproperties.ValueType; +import org.neo4j.gds.values.ArrayEquals; +import org.neo4j.gds.values.DoubleArray; +import org.neo4j.gds.values.GdsValue; + +import java.util.Arrays; + +public class DoubleArrayImpl implements DoubleArray { + + private final double[] value; + + public DoubleArrayImpl(double[] value) { + this.value = value; + } + + @Override + public double[] doubleArrayValue() { + return Arrays.copyOf(value, value.length); + } + + @Override + public double doubleValue(int idx) { + return value[idx]; + } + + @Override + public int length() { + return value.length; + } + + @Override + public ValueType type() { + return ValueType.DOUBLE_ARRAY; + } + + @Override + public double[] asObject() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o instanceof DoubleArray) { + return equals(((DoubleArray) o).doubleArrayValue()); + } else if (o instanceof GdsValue) { + return ArrayEquals.doubleAndObject(value, ((GdsValue) o).asObject()); + } + return false; + } + + @Override + public boolean equals(byte[] o) { + return ArrayEquals.byteAndDouble(o, value); + } + + @Override + public boolean equals(short[] o) { + return ArrayEquals.shortAndDouble(o, value); + } + + @Override + public boolean equals(int[] o) { + return ArrayEquals.intAndDouble(o, value); + } + + @Override + public boolean equals(long[] other) { + return ArrayEquals.longAndDouble(other, value); + } + + @Override + public boolean equals(float[] o) { + return ArrayEquals.floatAndDouble(o, value); + } + + @Override + public boolean equals(double[] o) { + return Arrays.equals(value, o); + } + + @Override + public int hashCode() { + return Arrays.hashCode(value); + } + + @Override + public String toString() { + return "DoubleArray" + Arrays.toString(value); + } +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/primitive/FloatArrayImpl.java b/gds-values/src/main/java/org/neo4j/gds/values/primitive/FloatArrayImpl.java new file mode 100644 index 0000000000..7739fe646a --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/primitive/FloatArrayImpl.java @@ -0,0 +1,121 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.neo4j.gds.api.nodeproperties.ValueType; +import org.neo4j.gds.values.ArrayEquals; +import org.neo4j.gds.values.FloatArray; +import org.neo4j.gds.values.GdsValue; + +import java.util.Arrays; + +public class FloatArrayImpl implements FloatArray { + + private final float[] value; + + public FloatArrayImpl(float[] value) { + this.value = value; + } + + @Override + public ValueType type() { + return ValueType.FLOAT_ARRAY; + } + + @Override + public double[] asObject() { + return doubleArrayValue(); + } + + @Override + public double[] doubleArrayValue() { + var copy = new double[value.length]; + for (int i = 0; i < value.length; i++) { + copy[i] = value[i]; + } + return copy; + } + + @Override + public double doubleValue(int idx) { + return value[idx]; + } + + @Override + public int length() { + return value.length; + } + + @Override + public float[] floatArrayValue() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o instanceof FloatArray) { + return equals(((FloatArray) o).floatArrayValue()); + } else if (o instanceof GdsValue) { + return ArrayEquals.floatAndObject(value, ((GdsValue) o).asObject()); + } + return false; + } + + @Override + public boolean equals(byte[] o) { + return ArrayEquals.byteAndFloat(o, value); + } + + @Override + public boolean equals(short[] o) { + return ArrayEquals.shortAndFloat(o, value); + } + + @Override + public boolean equals(int[] o) { + return ArrayEquals.intAndFloat(o, value); + } + + @Override + public boolean equals(long[] other) { + return ArrayEquals.longAndFloat(other, value); + } + + @Override + public boolean equals(float[] o) { + return Arrays.equals(value, o); + } + + @Override + public boolean equals(double[] o) { + return ArrayEquals.floatAndDouble(value, o); + } + + @Override + public int hashCode() { + return Arrays.hashCode(value); + } + + @Override + public String toString() { + return "FloatArray" + Arrays.toString(value); + } +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/primitive/FloatingPointValueImpl.java b/gds-values/src/main/java/org/neo4j/gds/values/primitive/FloatingPointValueImpl.java new file mode 100644 index 0000000000..8f4296f34c --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/primitive/FloatingPointValueImpl.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.neo4j.gds.api.nodeproperties.ValueType; +import org.neo4j.gds.values.FloatingPointValue; + +import java.util.Objects; + +public class FloatingPointValueImpl implements FloatingPointValue { + private final double value; + + public FloatingPointValueImpl(double value) { + this.value = value; + } + + @Override + public double doubleValue() { + return value; + } + + @Override + public ValueType type() { + return ValueType.DOUBLE; + } + + @Override + public Double asObject() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o instanceof FloatingPointValue) { + FloatingPointValue that = (FloatingPointValue) o; + return Double.compare(value, that.doubleValue()) == 0; + } + return false; + } + + @Override + public int hashCode() { + return Objects.hashCode(value); + } +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/primitive/IntLongArrayImpl.java b/gds-values/src/main/java/org/neo4j/gds/values/primitive/IntLongArrayImpl.java new file mode 100644 index 0000000000..673bcab3ce --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/primitive/IntLongArrayImpl.java @@ -0,0 +1,117 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.neo4j.gds.api.nodeproperties.ValueType; +import org.neo4j.gds.values.ArrayEquals; +import org.neo4j.gds.values.GdsValue; +import org.neo4j.gds.values.LongArray; + +import java.util.Arrays; + +public class IntLongArrayImpl implements LongArray { + + private final int[] value; + + public IntLongArrayImpl(int[] value) { + this.value = value; + } + + @Override + public long[] longArrayValue() { + var copy = new long[value.length]; + for (int i = 0; i < value.length; i++) { + copy[i] = value[i]; + } + return copy; + } + + @Override + public long longValue(int idx) { + return value[idx]; + } + + @Override + public int length() { + return value.length; + } + + @Override + public ValueType type() { + return ValueType.LONG_ARRAY; + } + + @Override + public long[] asObject() { + return longArrayValue(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o instanceof IntLongArrayImpl) { + return equals(((IntLongArrayImpl) o).value); + } + if (o instanceof GdsValue) { + return ArrayEquals.intAndObject(value, ((GdsValue) o).asObject()); + } + return false; + } + + @Override + public boolean equals(byte[] o) { + return ArrayEquals.byteAndInt(o, value); + } + + @Override + public boolean equals(short[] o) { + return ArrayEquals.shortAndInt(o, value); + } + + @Override + public boolean equals(int[] o) { + return Arrays.equals(value, o); + } + + @Override + public boolean equals(long[] other) { + return ArrayEquals.intAndLong(value, other); + } + + @Override + public boolean equals(float[] o) { + return ArrayEquals.intAndFloat(value, o); + } + + @Override + public boolean equals(double[] o) { + return ArrayEquals.intAndDouble(value, o); + } + + @Override + public int hashCode() { + return Arrays.hashCode(value); + } + + @Override + public String toString() { + return "LongArray" + Arrays.toString(value); + } +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/primitive/LongArrayImpl.java b/gds-values/src/main/java/org/neo4j/gds/values/primitive/LongArrayImpl.java new file mode 100644 index 0000000000..8077b9da4a --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/primitive/LongArrayImpl.java @@ -0,0 +1,112 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.neo4j.gds.api.nodeproperties.ValueType; +import org.neo4j.gds.values.ArrayEquals; +import org.neo4j.gds.values.GdsValue; +import org.neo4j.gds.values.LongArray; + +import java.util.Arrays; + +public class LongArrayImpl implements LongArray { + + private final long[] value; + + public LongArrayImpl(long[] value) { + this.value = value; + } + + @Override + public long[] longArrayValue() { + return Arrays.copyOf(value, value.length); + } + + @Override + public long longValue(int idx) { + return value[idx]; + } + + @Override + public int length() { + return value.length; + } + + @Override + public ValueType type() { + return ValueType.LONG_ARRAY; + } + + @Override + public long[] asObject() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o instanceof LongArray) { + return equals(((LongArray) o).longArrayValue()); + } else if (o instanceof GdsValue) { + return ArrayEquals.longAndObject(value, ((GdsValue) o).asObject()); + } + return false; + } + + @Override + public boolean equals(byte[] o) { + return ArrayEquals.byteAndLong(o, value); + } + + @Override + public boolean equals(short[] o) { + return ArrayEquals.shortAndLong(o, value); + } + + @Override + public boolean equals(int[] o) { + return ArrayEquals.intAndLong(o, value); + } + + @Override + public boolean equals(long[] other) { + return Arrays.equals(value, other); + } + + @Override + public boolean equals(float[] o) { + return ArrayEquals.longAndFloat(value, o); + } + + @Override + public boolean equals(double[] o) { + return ArrayEquals.longAndDouble(value, o); + } + + @Override + public int hashCode() { + return Arrays.hashCode(value); + } + + @Override + public String toString() { + return "LongArray" + Arrays.toString(value); + } +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/primitive/LongValueImpl.java b/gds-values/src/main/java/org/neo4j/gds/values/primitive/LongValueImpl.java new file mode 100644 index 0000000000..57717f4ee7 --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/primitive/LongValueImpl.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.neo4j.gds.api.nodeproperties.ValueType; +import org.neo4j.gds.values.IntegralValue; + +import java.util.Objects; + +public class LongValueImpl implements IntegralValue { + private final long value; + + public LongValueImpl(long value) { + this.value = value; + } + + @Override + public long longValue() { + return value; + } + + @Override + public ValueType type() { + return ValueType.LONG; + } + + @Override + public Long asObject() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o instanceof IntegralValue) { + IntegralValue that = (IntegralValue) o; + return value == that.longValue(); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hashCode(value); + } +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/primitive/PrimitiveValues.java b/gds-values/src/main/java/org/neo4j/gds/values/primitive/PrimitiveValues.java new file mode 100644 index 0000000000..12edd969ce --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/primitive/PrimitiveValues.java @@ -0,0 +1,162 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.jetbrains.annotations.Nullable; +import org.neo4j.gds.values.Array; +import org.neo4j.gds.values.DoubleArray; +import org.neo4j.gds.values.FloatArray; +import org.neo4j.gds.values.FloatingPointValue; +import org.neo4j.gds.values.GdsNoValue; +import org.neo4j.gds.values.GdsValue; +import org.neo4j.gds.values.IntegralValue; +import org.neo4j.gds.values.LongArray; + +import java.util.Locale; +import java.util.Objects; + +public final class PrimitiveValues { + private static final long[] EMPTY_LONGS = new long[0]; + public static final GdsNoValue NO_VALUE = GdsNoValue.NO_VALUE; + public static final LongArray EMPTY_LONG_ARRAY = longArray(EMPTY_LONGS); + + public static GdsValue create(@Nullable Object value) { + GdsValue of = of(value); + if (of != null) { + return of; + } + Objects.requireNonNull(value); + throw new IllegalArgumentException( + String.format( + Locale.ENGLISH, + "[%s:%s] is not a supported property value", value, value.getClass().getName() + ) + ); + } + + private static @Nullable GdsValue of(Object value) { + if (value == null) return NO_VALUE; + if (value instanceof Number) { + return numberValue((Number) value); + } + if (value instanceof Object[]) { + return arrayValue((Object[]) value); + } + if (value instanceof byte[]) { + return byteArray((byte[]) value); + } + if (value instanceof short[]) { + return shortArray((short[]) value); + } + if (value instanceof int[]) { + return intArray((int[]) value); + } + if (value instanceof long[]) { + return longArray((long[]) value); + } + if (value instanceof float[]) { + return floatArray((float[]) value); + } + if (value instanceof double[]) { + return doubleArray((double[]) value); + } + return null; + } + + private static GdsValue numberValue(Number number) { + if (number instanceof Long longNumber) { + return longValue(longNumber); + } else if (number instanceof Integer intNumber) { + return longValue(intNumber); + } else if (number instanceof Double doubleNumber) { + return floatingPointValue(doubleNumber); + } else if (number instanceof Byte byteNumber) { + return longValue(byteNumber); + } else if (number instanceof Float floatNumber) { + return floatingPointValue(floatNumber); + } else if (number instanceof Short shortNumber) { + return longValue(shortNumber); + } else { + throw new UnsupportedOperationException("Unsupported type of Number " + number); + } + } + + private static @Nullable Array arrayValue(Object[] value) { + if (value instanceof Float[]) { + return floatArray(copy(value, new float[value.length])); + } + if (value instanceof Double[]) { + return doubleArray(copy(value, new double[value.length])); + } + if (value instanceof Long[]) { + return longArray(copy(value, new long[value.length])); + } + if (value instanceof Integer[]) { + return intArray(copy(value, new int[value.length])); + } + if (value instanceof Short[]) { + return shortArray(copy(value, new short[value.length])); + } + if (value instanceof Byte[]) { + return byteArray(copy(value, new byte[value.length])); + } + return null; + } + + + public static IntegralValue longValue(long value) { + return new LongValueImpl(value); + } + public static FloatingPointValue floatingPointValue(double value) { + return new FloatingPointValueImpl(value); + } + + public static DoubleArray doubleArray(double[] data) { + return new DoubleArrayImpl(data); + } + public static FloatArray floatArray(float[] data) { + return new FloatArrayImpl(data); + } + public static LongArray longArray(long[] data) { + return new LongArrayImpl(data); + } + public static LongArray intArray(int[] data) { + return new IntLongArrayImpl(data); + } + public static LongArray shortArray(short[] data) { + return new ShortLongArrayImpl(data); + } + public static LongArray byteArray(byte[] data) { + return new ByteLongArrayImpl(data); + } + + private static T copy(Object[] value, T target) { + for(int i = 0; i < value.length; ++i) { + if (value[i] == null) { + throw new IllegalArgumentException("Property array value elements may not be null."); + } + java.lang.reflect.Array.set(target, i, value[i]); + } + return target; + } + + private PrimitiveValues() {} + +} diff --git a/gds-values/src/main/java/org/neo4j/gds/values/primitive/ShortLongArrayImpl.java b/gds-values/src/main/java/org/neo4j/gds/values/primitive/ShortLongArrayImpl.java new file mode 100644 index 0000000000..f9ad180498 --- /dev/null +++ b/gds-values/src/main/java/org/neo4j/gds/values/primitive/ShortLongArrayImpl.java @@ -0,0 +1,117 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.neo4j.gds.api.nodeproperties.ValueType; +import org.neo4j.gds.values.ArrayEquals; +import org.neo4j.gds.values.GdsValue; +import org.neo4j.gds.values.LongArray; + +import java.util.Arrays; + +public class ShortLongArrayImpl implements LongArray { + + private final short[] value; + + public ShortLongArrayImpl(short[] value) { + this.value = value; + } + + @Override + public long[] longArrayValue() { + var copy = new long[value.length]; + for (int i = 0; i < value.length; i++) { + copy[i] = value[i]; + } + return copy; + } + + @Override + public long longValue(int idx) { + return value[idx]; + } + + @Override + public int length() { + return value.length; + } + + @Override + public ValueType type() { + return ValueType.LONG_ARRAY; + } + + @Override + public short[] asObject() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o instanceof ShortLongArrayImpl) { + return equals(((ShortLongArrayImpl) o).value); + } + if (o instanceof GdsValue) { + return ArrayEquals.shortAndObject(value, ((GdsValue) o).asObject()); + } + return false; + } + + @Override + public boolean equals(byte[] o) { + return ArrayEquals.byteAndShort(o, value); + } + + @Override + public boolean equals(short[] o) { + return Arrays.equals(value, o); + } + + @Override + public boolean equals(int[] o) { + return ArrayEquals.shortAndInt(value, o); + } + + @Override + public boolean equals(long[] other) { + return ArrayEquals.shortAndLong(value, other); + } + + @Override + public boolean equals(float[] o) { + return ArrayEquals.shortAndFloat(value, o); + } + + @Override + public boolean equals(double[] o) { + return ArrayEquals.shortAndDouble(value, o); + } + + @Override + public int hashCode() { + return Arrays.hashCode(value); + } + + @Override + public String toString() { + return "LongArray" + Arrays.toString(value); + } +} diff --git a/gds-values/src/test/java/org/neo4j/gds/values/primitive/ByteLongArrayImplTest.java b/gds-values/src/test/java/org/neo4j/gds/values/primitive/ByteLongArrayImplTest.java new file mode 100644 index 0000000000..115e1942f6 --- /dev/null +++ b/gds-values/src/test/java/org/neo4j/gds/values/primitive/ByteLongArrayImplTest.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.junit.jupiter.api.Test; +import org.neo4j.gds.api.nodeproperties.ValueType; + +import static org.assertj.core.api.Assertions.assertThat; + +class ByteLongArrayImplTest { + + @Test + void longArrayValue() { + var value = new ByteLongArrayImpl(new byte[]{1, 2, 3, 4, 5, 6, 7, 8, 9}); + assertThat(value.longArrayValue()).isEqualTo(new long[]{1, 2, 3, 4, 5, 6, 7, 8, 9}); + } + + @Test + void longValue() { + var value = new ByteLongArrayImpl(new byte[]{1, 2, 3}); + assertThat(value.longValue(0)).isEqualTo(1); + assertThat(value.longValue(1)).isEqualTo(2); + assertThat(value.longValue(2)).isEqualTo(3); + } + + @Test + void length() { + var value = new ByteLongArrayImpl(new byte[]{1, 2, 3}); + assertThat(value.length()).isEqualTo(3); + } + + @Test + void type() { + var value = new ByteLongArrayImpl(new byte[]{1, 2, 3}); + assertThat(value.type()).isEqualTo(ValueType.LONG_ARRAY); + } + + @Test + void testEquals() { + // empty is equal to the empty constant + assertThat(new ByteLongArrayImpl(new byte[]{})).isEqualTo(PrimitiveValues.EMPTY_LONG_ARRAY); + + var value = new ByteLongArrayImpl(new byte[]{1, 2, 3}); + // equal to another instance if the held values are equal + assertThat(value).isEqualTo(new ByteLongArrayImpl(new byte[]{1, 2, 3})); + // equal to an array if the held array is equal + assertThat(value.equals(new byte[]{1, 2, 3})).isTrue(); + } +} diff --git a/gds-values/src/test/java/org/neo4j/gds/values/primitive/DoubleArrayImplTest.java b/gds-values/src/test/java/org/neo4j/gds/values/primitive/DoubleArrayImplTest.java new file mode 100644 index 0000000000..462cf1b128 --- /dev/null +++ b/gds-values/src/test/java/org/neo4j/gds/values/primitive/DoubleArrayImplTest.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.junit.jupiter.api.Test; +import org.neo4j.gds.api.nodeproperties.ValueType; + +import static org.assertj.core.api.Assertions.assertThat; + +class DoubleArrayImplTest { + + @Test + void doubleArrayValue() { + var value = new DoubleArrayImpl(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9}); + assertThat(value.doubleArrayValue()).isEqualTo(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9}); + } + + @Test + void longValue() { + var value = new DoubleArrayImpl(new double[]{1, 2, 3}); + assertThat(value.doubleValue(0)).isEqualTo(1); + assertThat(value.doubleValue(1)).isEqualTo(2); + assertThat(value.doubleValue(2)).isEqualTo(3); + } + + @Test + void length() { + var value = new DoubleArrayImpl(new double[]{1, 2, 3}); + assertThat(value.length()).isEqualTo(3); + } + + @Test + void type() { + var value = new DoubleArrayImpl(new double[]{1, 2, 3}); + assertThat(value.type()).isEqualTo(ValueType.DOUBLE_ARRAY); + } + + @Test + void testEquals() { + // empty is equal to the empty constant + assertThat(new DoubleArrayImpl(new double[]{})).isEqualTo(PrimitiveValues.EMPTY_LONG_ARRAY); + + var value = new DoubleArrayImpl(new double[]{1, 2, 3}); + // equal to another instance if the held values are equal + assertThat(value).isEqualTo(new DoubleArrayImpl(new double[]{1, 2, 3})); + // equal to an array if the held array is equal + assertThat(value.equals(new double[]{1, 2, 3})).isTrue(); + } +} diff --git a/gds-values/src/test/java/org/neo4j/gds/values/primitive/FloaArrayImplTest.java b/gds-values/src/test/java/org/neo4j/gds/values/primitive/FloaArrayImplTest.java new file mode 100644 index 0000000000..10b4c8f48e --- /dev/null +++ b/gds-values/src/test/java/org/neo4j/gds/values/primitive/FloaArrayImplTest.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.junit.jupiter.api.Test; +import org.neo4j.gds.api.nodeproperties.ValueType; + +import static org.assertj.core.api.Assertions.assertThat; + +class FloaArrayImplTest { + + @Test + void floatArrayValue() { + var value = new FloatArrayImpl(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}); + assertThat(value.floatArrayValue()).isEqualTo(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9}); + } + + @Test + void longValue() { + var value = new FloatArrayImpl(new float[]{1, 2, 3}); + assertThat(value.doubleValue(0)).isEqualTo(1); + assertThat(value.doubleValue(1)).isEqualTo(2); + assertThat(value.doubleValue(2)).isEqualTo(3); + } + + @Test + void length() { + var value = new FloatArrayImpl(new float[]{1, 2, 3}); + assertThat(value.length()).isEqualTo(3); + } + + @Test + void type() { + var value = new FloatArrayImpl(new float[]{1, 2, 3}); + assertThat(value.type()).isEqualTo(ValueType.FLOAT_ARRAY); + } + + @Test + void testEquals() { + // empty is equal to the empty constant + assertThat(new FloatArrayImpl(new float[]{})).isEqualTo(PrimitiveValues.EMPTY_LONG_ARRAY); + + var value = new FloatArrayImpl(new float[]{1, 2, 3}); + // equal to another instance if the held values are equal + assertThat(value).isEqualTo(new FloatArrayImpl(new float[]{1, 2, 3})); + // equal to an array if the held array is equal + assertThat(value.equals(new float[]{1, 2, 3})).isTrue(); + } +} diff --git a/gds-values/src/test/java/org/neo4j/gds/values/primitive/IntLongArrayImplTest.java b/gds-values/src/test/java/org/neo4j/gds/values/primitive/IntLongArrayImplTest.java new file mode 100644 index 0000000000..f7911fad49 --- /dev/null +++ b/gds-values/src/test/java/org/neo4j/gds/values/primitive/IntLongArrayImplTest.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.junit.jupiter.api.Test; +import org.neo4j.gds.api.nodeproperties.ValueType; + +import static org.assertj.core.api.Assertions.assertThat; + +class IntLongArrayImplTest { + + @Test + void longArrayValue() { + var value = new IntLongArrayImpl(new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9}); + assertThat(value.longArrayValue()).isEqualTo(new long[]{1, 2, 3, 4, 5, 6, 7, 8, 9}); + } + + @Test + void longValue() { + var value = new IntLongArrayImpl(new int[]{1, 2, 3}); + assertThat(value.longValue(0)).isEqualTo(1); + assertThat(value.longValue(1)).isEqualTo(2); + assertThat(value.longValue(2)).isEqualTo(3); + } + + @Test + void length() { + var value = new IntLongArrayImpl(new int[]{1, 2, 3}); + assertThat(value.length()).isEqualTo(3); + } + + @Test + void type() { + var value = new IntLongArrayImpl(new int[]{1, 2, 3}); + assertThat(value.type()).isEqualTo(ValueType.LONG_ARRAY); + } + + @Test + void testEquals() { + // empty is equal to the empty constant + assertThat(new IntLongArrayImpl(new int[]{})).isEqualTo(PrimitiveValues.EMPTY_LONG_ARRAY); + + var value = new IntLongArrayImpl(new int[]{1, 2, 3}); + // equal to another instance if the held values are equal + assertThat(value).isEqualTo(new IntLongArrayImpl(new int[]{1, 2, 3})); + // equal to an array if the held array is equal + assertThat(value.equals(new int[]{1, 2, 3})).isTrue(); + } +} diff --git a/gds-values/src/test/java/org/neo4j/gds/values/primitive/LongArrayImplTest.java b/gds-values/src/test/java/org/neo4j/gds/values/primitive/LongArrayImplTest.java new file mode 100644 index 0000000000..186c6165fe --- /dev/null +++ b/gds-values/src/test/java/org/neo4j/gds/values/primitive/LongArrayImplTest.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.junit.jupiter.api.Test; +import org.neo4j.gds.api.nodeproperties.ValueType; + +import static org.assertj.core.api.Assertions.assertThat; + +class LongArrayImplTest { + + @Test + void longArrayValue() { + var value = new LongArrayImpl(new long[]{1, 2, 3, 4, 5, 6, 7, 8, 9}); + assertThat(value.longArrayValue()).isEqualTo(new long[]{1, 2, 3, 4, 5, 6, 7, 8, 9}); + } + + @Test + void longValue() { + var value = new LongArrayImpl(new long[]{1, 2, 3}); + assertThat(value.longValue(0)).isEqualTo(1); + assertThat(value.longValue(1)).isEqualTo(2); + assertThat(value.longValue(2)).isEqualTo(3); + } + + @Test + void length() { + var value = new LongArrayImpl(new long[]{1, 2, 3}); + assertThat(value.length()).isEqualTo(3); + } + + @Test + void type() { + var value = new LongArrayImpl(new long[]{1, 2, 3}); + assertThat(value.type()).isEqualTo(ValueType.LONG_ARRAY); + } + + @Test + void testEquals() { + // empty is equal to the empty constant + assertThat(new LongArrayImpl(new long[]{})).isEqualTo(PrimitiveValues.EMPTY_LONG_ARRAY); + + var value = new LongArrayImpl(new long[]{1, 2, 3}); + // equal to another instance if the held values are equal + assertThat(value).isEqualTo(new LongArrayImpl(new long[]{1, 2, 3})); + // equal to an array if the held array is equal + assertThat(value.equals(new long[]{1, 2, 3})).isTrue(); + } +} diff --git a/gds-values/src/test/java/org/neo4j/gds/values/primitive/PrimitiveValuesTest.java b/gds-values/src/test/java/org/neo4j/gds/values/primitive/PrimitiveValuesTest.java new file mode 100644 index 0000000000..a690eb6d11 --- /dev/null +++ b/gds-values/src/test/java/org/neo4j/gds/values/primitive/PrimitiveValuesTest.java @@ -0,0 +1,134 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.junit.jupiter.api.Test; +import org.neo4j.gds.api.nodeproperties.ValueType; +import org.neo4j.gds.values.FloatingPointValue; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class PrimitiveValuesTest { + + @Test + void shouldCreateLongValuesFromScalarIntegers() { + assertThat(PrimitiveValues.longValue(42)) + .satisfies(value -> { + assertThat(value.type()).isEqualTo(ValueType.LONG); + assertThat(value).isEqualTo(PrimitiveValues.create((byte) 42)); + assertThat(value).isEqualTo(PrimitiveValues.create((short) 42)); + assertThat(value).isEqualTo(PrimitiveValues.create((int) 42)); + assertThat(value).isEqualTo(PrimitiveValues.create((long) 42)); + }); + } + + @Test + void shouldCreateDoubleValuesFromScalarFloatingPoints() { + FloatingPointValue floatValue = PrimitiveValues.floatingPointValue(1F); + assertThat(floatValue).satisfies(value -> { + assertThat(value.type()).isEqualTo(ValueType.DOUBLE); + assertThat(value.doubleValue()).isEqualTo(1F); + assertThat(value).isEqualTo(PrimitiveValues.create(1F)); + }); + FloatingPointValue doubleValue = PrimitiveValues.floatingPointValue(1D); + assertThat(doubleValue).satisfies(value -> { + assertThat(value.type()).isEqualTo(ValueType.DOUBLE); + assertThat(value.doubleValue()).isEqualTo(1D); + assertThat(value).isEqualTo(PrimitiveValues.create(1D)); + }); + assertThat(floatValue).isEqualTo(doubleValue); + } + + @Test + void shouldFailCreatingValuesFromSomeOtherThings() { + assertThatThrownBy(() -> PrimitiveValues.create('c')) + .hasMessageContaining("java.lang.Character") + .hasMessageContaining("is not a supported property value"); + assertThatThrownBy(() -> PrimitiveValues.create("string")) + .hasMessageContaining("java.lang.String") + .hasMessageContaining("is not a supported property value"); + assertThatThrownBy(() -> PrimitiveValues.create(List.of(1, 2))) + .hasMessageContaining("[1, 2]") + .hasMessageContaining("List") + .hasMessageContaining("is not a supported property value"); + assertThatThrownBy(() -> PrimitiveValues.create(new Object[]{'c', "string"})) + .hasMessageContaining("[Ljava.lang.Object;") + .hasMessageContaining("is not a supported property value"); + } + + @Test + void shouldCreateValuesFromIntegerArrays() { + var byteArray = PrimitiveValues.byteArray(new byte[]{1, 2, 3}); + assertThat(byteArray) + .satisfies(arr -> { + assertThat(arr.type()).isEqualTo(ValueType.LONG_ARRAY); + assertThat(arr).isEqualTo(PrimitiveValues.create(new byte[]{1, 2, 3})); + assertThat(arr).isEqualTo(PrimitiveValues.create(new Byte[]{(byte) 1, (byte) 2, (byte) 3})); + }); + var shortArray = PrimitiveValues.shortArray(new short[]{1, 2, 3}); + assertThat(shortArray) + .satisfies(arr -> { + assertThat(arr.type()).isEqualTo(ValueType.LONG_ARRAY); + assertThat(arr).isEqualTo(PrimitiveValues.create(new short[]{1, 2, 3})); + assertThat(arr).isEqualTo(PrimitiveValues.create(new Short[]{(short) 1, (short) 2, (short) 3})); + }); + var intArray = PrimitiveValues.intArray(new int[]{1, 2, 3}); + assertThat(intArray) + .satisfies(arr -> { + assertThat(arr.type()).isEqualTo(ValueType.LONG_ARRAY); + assertThat(arr).isEqualTo(PrimitiveValues.create(new int[]{1, 2, 3})); + assertThat(arr).isEqualTo(PrimitiveValues.create(new Integer[]{(int) 1, (int) 2, (int) 3})); + }); + var longArray = PrimitiveValues.longArray(new long[]{1, 2, 3}); + assertThat(longArray) + .satisfies(arr -> { + assertThat(arr.type()).isEqualTo(ValueType.LONG_ARRAY); + assertThat(arr).isEqualTo(PrimitiveValues.create(new long[]{1, 2, 3})); + assertThat(arr).isEqualTo(PrimitiveValues.create(new Long[]{(long) 1, (long) 2, (long) 3})); + }); + assertThat(longArray) + .isEqualTo(byteArray) + .isEqualTo(shortArray) + .isEqualTo(intArray); + } + + @Test + void shouldCreateValuesFromFloatingPointArrays() { + var floatArray = PrimitiveValues.floatArray(new float[]{1, 2, 3}); + assertThat(floatArray) + .satisfies(arr -> { + assertThat(arr.type()).isEqualTo(ValueType.FLOAT_ARRAY); + assertThat(arr).isEqualTo(PrimitiveValues.create(new float[]{1, 2, 3})); + assertThat(arr).isEqualTo(PrimitiveValues.create(new Float[]{(float) 1, (float) 2, (float) 3})); + }); + var doubleArray = PrimitiveValues.doubleArray(new double[]{1, 2, 3}); + assertThat(doubleArray) + .satisfies(arr -> { + assertThat(arr.type()).isEqualTo(ValueType.DOUBLE_ARRAY); + assertThat(arr).isEqualTo(PrimitiveValues.create(new double[]{1, 2, 3})); + assertThat(arr).isEqualTo(PrimitiveValues.create(new Double[]{(double) 1, (double) 2, (double) 3})); + }); + assertThat(doubleArray) + .isEqualTo(floatArray); + } +} diff --git a/gds-values/src/test/java/org/neo4j/gds/values/primitive/ShortLongArrayImplTest.java b/gds-values/src/test/java/org/neo4j/gds/values/primitive/ShortLongArrayImplTest.java new file mode 100644 index 0000000000..707c78c6ce --- /dev/null +++ b/gds-values/src/test/java/org/neo4j/gds/values/primitive/ShortLongArrayImplTest.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.values.primitive; + +import org.junit.jupiter.api.Test; +import org.neo4j.gds.api.nodeproperties.ValueType; + +import static org.assertj.core.api.Assertions.assertThat; + +class ShortLongArrayImplTest { + + @Test + void longArrayValue() { + var value = new ShortLongArrayImpl(new short[]{1, 2, 3, 4, 5, 6, 7, 8, 9}); + assertThat(value.longArrayValue()).isEqualTo(new long[]{1, 2, 3, 4, 5, 6, 7, 8, 9}); + } + + @Test + void longValue() { + var value = new ShortLongArrayImpl(new short[]{1, 2, 3}); + assertThat(value.longValue(0)).isEqualTo(1); + assertThat(value.longValue(1)).isEqualTo(2); + assertThat(value.longValue(2)).isEqualTo(3); + } + + @Test + void length() { + var value = new ShortLongArrayImpl(new short[]{1, 2, 3}); + assertThat(value.length()).isEqualTo(3); + } + + @Test + void type() { + var value = new ShortLongArrayImpl(new short[]{1, 2, 3}); + assertThat(value.type()).isEqualTo(ValueType.LONG_ARRAY); + } + + @Test + void testEquals() { + // empty is equal to the empty constant + assertThat(new ShortLongArrayImpl(new short[]{})).isEqualTo(PrimitiveValues.EMPTY_LONG_ARRAY); + + var value = new ShortLongArrayImpl(new short[]{1, 2, 3}); + // equal to another instance if the held values are equal + assertThat(value).isEqualTo(new ShortLongArrayImpl(new short[]{1, 2, 3})); + // equal to an array if the held array is equal + assertThat(value.equals(new short[]{1, 2, 3})).isTrue(); + } +} diff --git a/io/core/build.gradle b/io/core/build.gradle index 3dedd92182..dbc5d7a1b5 100644 --- a/io/core/build.gradle +++ b/io/core/build.gradle @@ -22,6 +22,7 @@ dependencies { implementation project(':core') implementation project(':core-utils') implementation project(':graph-schema-api') + implementation project(':gds-values') implementation project(':logging') implementation project(':memory-usage') implementation project(':neo4j-kernel-adapter') diff --git a/io/core/src/main/java/org/neo4j/gds/core/io/file/GraphStoreNodeVisitor.java b/io/core/src/main/java/org/neo4j/gds/core/io/file/GraphStoreNodeVisitor.java index 6e32393e00..e037ee4633 100644 --- a/io/core/src/main/java/org/neo4j/gds/core/io/file/GraphStoreNodeVisitor.java +++ b/io/core/src/main/java/org/neo4j/gds/core/io/file/GraphStoreNodeVisitor.java @@ -22,6 +22,8 @@ import org.neo4j.gds.api.schema.NodeSchema; import org.neo4j.gds.core.loading.construction.NodeLabelTokens; import org.neo4j.gds.core.loading.construction.NodesBuilder; +import org.neo4j.gds.values.GdsValue; +import org.neo4j.gds.values.primitive.PrimitiveValues; import java.util.HashMap; import java.util.Map; @@ -37,12 +39,12 @@ public GraphStoreNodeVisitor(NodeSchema nodeSchema, NodesBuilder nodesBuilder) { @Override protected void exportElement() { - Map props = new HashMap<>(); + Map props = new HashMap<>(); forEachProperty((key, value) -> { - props.put(key, value); + props.put(key, PrimitiveValues.create(value)); }); var nodeLabels = NodeLabelTokens.of(labels()); - nodesBuilder.addNodeWithPropertiesAsObjects(id(), props, nodeLabels); + nodesBuilder.addNode(id(), props, nodeLabels); } public static final class Builder extends NodeVisitor.Builder { diff --git a/legacy-cypher-projection/build.gradle b/legacy-cypher-projection/build.gradle index 497be92090..eddde6f013 100644 --- a/legacy-cypher-projection/build.gradle +++ b/legacy-cypher-projection/build.gradle @@ -25,6 +25,7 @@ dependencies { implementation project(':config-api') implementation project(':core') implementation project(':graph-schema-api') + implementation project(':gds-values') implementation project(':memory-usage') implementation project(':neo4j-values') implementation project(':progress-tracking') diff --git a/legacy-cypher-projection/src/integrationTest/java/org/neo4j/gds/legacycypherprojection/CypherFactoryTest.java b/legacy-cypher-projection/src/integrationTest/java/org/neo4j/gds/legacycypherprojection/CypherFactoryTest.java index cf577d0dfb..669f500ff5 100644 --- a/legacy-cypher-projection/src/integrationTest/java/org/neo4j/gds/legacycypherprojection/CypherFactoryTest.java +++ b/legacy-cypher-projection/src/integrationTest/java/org/neo4j/gds/legacycypherprojection/CypherFactoryTest.java @@ -235,12 +235,12 @@ void failsOnBadMixedList() { assertThatThrownBy(() -> applyInFullAccessTransaction(db, tx -> builder.build().graph())) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Only lists of uniformly typed numbers are supported as GDS node properties") + .hasMessageContaining("Unsupported conversion to GDS Value from Neo4j Value") .hasMessageContaining("List{Long(1), Long(2), Boolean('true')}"); } @Test - void failsOnMixedNumbersList() { + void mixedNumbersListNodeProperty() { var nodeQuery = "RETURN 0 AS id, [1, 2, 1.23] AS list"; var builder = new CypherLoaderBuilder() @@ -248,10 +248,9 @@ void failsOnMixedNumbersList() { .nodeQuery(nodeQuery) .relationshipQuery("RETURN 0 AS source, 0 AS target LIMIT 0"); - assertThatThrownBy(() -> applyInFullAccessTransaction(db, tx -> builder.build().graph())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Only lists of uniformly typed numbers are supported as GDS node properties") - .hasMessageContaining("List{Long(1), Long(2), Double(1"); // omitting the rest of the double for locale reasons + var graph = applyInFullAccessTransaction(db, tx -> builder.build().graph()); + var value = graph.nodeProperties("list").doubleArrayValue(0); + assertThat(value).isEqualTo(new double[]{1.0, 2.0, 1.23}); } @Test @@ -265,7 +264,7 @@ void failsOnBadUniformList() { assertThatThrownBy(() -> applyInFullAccessTransaction(db, tx -> builder.build().graph())) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Only lists of uniformly typed numbers are supported as GDS node properties") + .hasMessageContaining("Unsupported conversion to GDS Value from Neo4j Value") .hasMessageContaining("List{String(\"forty\"), String(\"two\")}"); } diff --git a/legacy-cypher-projection/src/main/java/org/neo4j/gds/legacycypherprojection/NodeSubscriber.java b/legacy-cypher-projection/src/main/java/org/neo4j/gds/legacycypherprojection/NodeSubscriber.java index d86f592b56..f4936bc8c4 100644 --- a/legacy-cypher-projection/src/main/java/org/neo4j/gds/legacycypherprojection/NodeSubscriber.java +++ b/legacy-cypher-projection/src/main/java/org/neo4j/gds/legacycypherprojection/NodeSubscriber.java @@ -19,15 +19,15 @@ */ package org.neo4j.gds.legacycypherprojection; -import org.neo4j.gds.core.loading.ValueConverter; +import org.neo4j.gds.core.loading.GdsNeo4jValueConverter; import org.neo4j.gds.core.loading.construction.NodesBuilder; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.values.CypherNodeLabelTokens; +import org.neo4j.gds.values.GdsValue; import org.neo4j.graphdb.QueryStatistics; import org.neo4j.values.AnyValue; import org.neo4j.values.SequenceValue; import org.neo4j.values.storable.NumberValue; -import org.neo4j.values.storable.Value; import org.neo4j.values.storable.Values; import org.neo4j.values.virtual.VirtualValues; @@ -54,7 +54,7 @@ class NodeSubscriber extends ErrorCachingQuerySubscriber { private long neoId = -1L; private SequenceValue labels = VirtualValues.EMPTY_LIST; - private Map properties; + private Map properties; private int idOffset = UNINITIALIZED; private int labelOffset = UNINITIALIZED; @@ -118,7 +118,7 @@ public void onField(int offset, AnyValue value) { labels = (SequenceValue) value; } else {//properties if ( value != Values.NO_VALUE) { - properties.put(fieldNames[offset], ValueConverter.toValue(value)); + properties.put(fieldNames[offset], GdsNeo4jValueConverter.toValue(value)); } } } diff --git a/native-projection/build.gradle b/native-projection/build.gradle index 1b006a8086..60bb0241a5 100644 --- a/native-projection/build.gradle +++ b/native-projection/build.gradle @@ -25,6 +25,7 @@ dependencies { implementation project(':core') implementation project(':core-utils') implementation project(':graph-schema-api') + implementation project(':gds-values') implementation project(':logging') implementation project(':memory-usage') implementation project(':neo4j-kernel-adapter') diff --git a/settings.gradle b/settings.gradle index 6e12ce145d..fb13cd8d51 100644 --- a/settings.gradle +++ b/settings.gradle @@ -320,6 +320,9 @@ project(':defaults-and-limits-configuration').projectDir = file('defaults-and-li include('neo4j-values') project(':neo4j-values').projectDir = file('neo4j-values') +include('gds-values') +project(':gds-values').projectDir = file('gds-values') + diff --git a/test-utils/build.gradle b/test-utils/build.gradle index def20a93f3..cbe7bc24ce 100644 --- a/test-utils/build.gradle +++ b/test-utils/build.gradle @@ -21,6 +21,7 @@ dependencies { implementation project(':config-api') implementation project(':core') implementation project(':core-utils') + implementation project(':gds-values') implementation project(':legacy-cypher-projection') implementation project(':native-projection') implementation project(':graph-schema-api') diff --git a/test-utils/src/main/java/org/neo4j/gds/gdl/GdlFactory.java b/test-utils/src/main/java/org/neo4j/gds/gdl/GdlFactory.java index 51c3407136..e25d6dd21e 100644 --- a/test-utils/src/main/java/org/neo4j/gds/gdl/GdlFactory.java +++ b/test-utils/src/main/java/org/neo4j/gds/gdl/GdlFactory.java @@ -57,8 +57,9 @@ import org.neo4j.gds.extension.GdlSupportPerMethodExtension; import org.neo4j.gds.mem.MemoryEstimation; import org.neo4j.gds.mem.MemoryEstimations; +import org.neo4j.gds.values.GdsValue; +import org.neo4j.gds.values.primitive.PrimitiveValues; import org.neo4j.values.storable.ArrayValue; -import org.neo4j.values.storable.Value; import org.neo4j.values.storable.Values; import org.s1ck.gdl.GDLHandler; import org.s1ck.gdl.model.Element; @@ -214,12 +215,12 @@ private Nodes loadNodes() { .collect(Collectors.toList()); } - Map propertyValues = new HashMap<>(); + Map propertyValues = new HashMap<>(); vertex.getProperties().forEach((propertyKey, propertyValue) -> { if (propertyValue instanceof List) { propertyValue = convertListProperty((List) propertyValue); } - propertyValues.put(propertyKey, Values.of(propertyValue)); + propertyValues.put(propertyKey, PrimitiveValues.create(propertyValue)); }); nodesBuilder.addNode( diff --git a/triplet-graph-builder/build.gradle b/triplet-graph-builder/build.gradle index d8cad34706..612381acdb 100644 --- a/triplet-graph-builder/build.gradle +++ b/triplet-graph-builder/build.gradle @@ -25,6 +25,7 @@ dependencies { implementation project(':annotations') implementation project(':config-api') implementation project(':core') + implementation project(':gds-values') implementation project(':graph-schema-api') implementation project(':logging') implementation project(':memory-usage') diff --git a/triplet-graph-builder/src/main/java/org/neo4j/gds/projection/RelationshipPropertyExtractor.java b/triplet-graph-builder/src/main/java/org/neo4j/gds/projection/RelationshipPropertyExtractor.java index 61b1dac5d0..33bfb230c2 100644 --- a/triplet-graph-builder/src/main/java/org/neo4j/gds/projection/RelationshipPropertyExtractor.java +++ b/triplet-graph-builder/src/main/java/org/neo4j/gds/projection/RelationshipPropertyExtractor.java @@ -19,10 +19,12 @@ */ package org.neo4j.gds.projection; +import org.neo4j.gds.api.ValueConversion; import org.neo4j.gds.core.Aggregation; -import org.neo4j.values.storable.NumberValue; -import org.neo4j.values.storable.Value; -import org.neo4j.values.storable.Values; +import org.neo4j.gds.values.FloatingPointValue; +import org.neo4j.gds.values.GdsNoValue; +import org.neo4j.gds.values.GdsValue; +import org.neo4j.gds.values.IntegralValue; import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; @@ -32,18 +34,22 @@ private RelationshipPropertyExtractor() { throw new UnsupportedOperationException("No instances"); } - public static double extractValue(Value value, double defaultValue) { + public static double extractValue(GdsValue value, double defaultValue) { return extractValue(Aggregation.NONE, value, defaultValue); } - public static double extractValue(Aggregation aggregation, Value value, double defaultValue) { + public static double extractValue(Aggregation aggregation, GdsValue value, double defaultValue) { // slightly different logic than org.neo4j.values.storable.Values#coerceToDouble // b/c we want to fall back to the default value if the value is empty - if (value instanceof NumberValue) { - double propertyValue = ((NumberValue) value).doubleValue(); + if (value instanceof FloatingPointValue) { + double propertyValue = ((FloatingPointValue) value).doubleValue(); return aggregation.normalizePropertyValue(propertyValue); } - if (Values.NO_VALUE.equals(value)) { + if (value instanceof IntegralValue) { + double propertyValue = ValueConversion.exactLongToDouble(((IntegralValue) value).longValue()); + return aggregation.normalizePropertyValue(propertyValue); + } + if (GdsNoValue.NO_VALUE.equals(value)) { return aggregation.emptyValue(defaultValue); } @@ -51,7 +57,7 @@ public static double extractValue(Aggregation aggregation, Value value, double d // Do we want to do so or is failing on non numeric properties ok? throw new IllegalArgumentException(formatWithLocale( "Unsupported type [%s] of value %s. Please use a numeric property.", - value.valueRepresentation().valueGroup(), + value.type(), value )); } diff --git a/triplet-graph-builder/src/test/java/org/neo4j/gds/projection/GraphImporterTest.java b/triplet-graph-builder/src/test/java/org/neo4j/gds/projection/GraphImporterTest.java index cb3b0d9931..748e124534 100644 --- a/triplet-graph-builder/src/test/java/org/neo4j/gds/projection/GraphImporterTest.java +++ b/triplet-graph-builder/src/test/java/org/neo4j/gds/projection/GraphImporterTest.java @@ -42,7 +42,7 @@ import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker; import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory; import org.neo4j.gds.logging.LogAdapter; -import org.neo4j.values.storable.Values; +import org.neo4j.gds.values.primitive.PrimitiveValues; import java.util.List; import java.util.Map; @@ -168,8 +168,8 @@ void shouldImportNodesWithProperties() { importer.update( i, i + 1, - PropertyValues.of(Map.of("prop", Values.longValue(i))), - PropertyValues.of(Map.of("prop", Values.longValue(i + 1))), + PropertyValues.of(Map.of("prop", PrimitiveValues.longValue(i))), + PropertyValues.of(Map.of("prop", PrimitiveValues.longValue(i + 1))), NodeLabelTokens.empty(), NodeLabelTokens.empty(), RelationshipType.ALL_RELATIONSHIPS, @@ -214,8 +214,8 @@ void shouldImportNodesWithPropertiesWithDifferentSchemas() { importer.update( i, i + 1, - PropertyValues.of(Map.of("prop" + i, Values.longValue(i))), - PropertyValues.of(Map.of("prop" + j, Values.longValue(j))), + PropertyValues.of(Map.of("prop" + i, PrimitiveValues.longValue(i))), + PropertyValues.of(Map.of("prop" + j, PrimitiveValues.longValue(j))), NodeLabelTokens.ofStrings("Label" + i), NodeLabelTokens.ofStrings("Label" + (j)), RelationshipType.ALL_RELATIONSHIPS, @@ -307,7 +307,7 @@ void shouldImportRelationshipsWithProperties() { NodeLabelTokens.empty(), NodeLabelTokens.empty(), RelationshipType.of("REL" + i), - PropertyValues.of(Map.of("prop" + i, Values.longValue(i))) + PropertyValues.of(Map.of("prop" + i, PrimitiveValues.longValue(i))) ); } diff --git a/triplet-graph-builder/src/test/java/org/neo4j/gds/projection/RelationshipPropertyExtractorTest.java b/triplet-graph-builder/src/test/java/org/neo4j/gds/projection/RelationshipPropertyExtractorTest.java index 29ba9f0c88..bfa86067de 100644 --- a/triplet-graph-builder/src/test/java/org/neo4j/gds/projection/RelationshipPropertyExtractorTest.java +++ b/triplet-graph-builder/src/test/java/org/neo4j/gds/projection/RelationshipPropertyExtractorTest.java @@ -22,118 +22,67 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.EnumSource; -import org.junit.jupiter.params.provider.MethodSource; import org.neo4j.gds.core.Aggregation; -import org.neo4j.values.storable.Value; -import org.neo4j.values.storable.Values; - -import java.util.Arrays; -import java.util.stream.Stream; +import org.neo4j.gds.values.primitive.PrimitiveValues; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.params.provider.Arguments.arguments; -import static org.neo4j.gds.TestSupport.crossArguments; -import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; class RelationshipPropertyExtractorTest { @Test void extractValueReadsAnyNumericType() { - Assertions.assertEquals(42.0, RelationshipPropertyExtractor.extractValue(Values.byteValue((byte) 42), 0.0)); - assertEquals(42.0, RelationshipPropertyExtractor.extractValue(Values.shortValue((short) 42), 0.0)); - assertEquals(42.0, RelationshipPropertyExtractor.extractValue(Values.intValue(42), 0.0)); - assertEquals(42.0, RelationshipPropertyExtractor.extractValue(Values.longValue(42), 0.0)); - assertEquals(42.0, RelationshipPropertyExtractor.extractValue(Values.floatValue(42.0F), 0.0)); - assertEquals(42.0, RelationshipPropertyExtractor.extractValue(Values.doubleValue(42.0D), 0.0)); - assertTrue(Double.isNaN(RelationshipPropertyExtractor.extractValue(Values.floatValue(Float.NaN), 0.0))); - assertTrue(Double.isNaN(RelationshipPropertyExtractor.extractValue(Values.doubleValue(Double.NaN), 0.0))); + Assertions.assertEquals(42.0, RelationshipPropertyExtractor.extractValue(PrimitiveValues.longValue((byte) 42), 0.0)); + assertEquals(42.0, RelationshipPropertyExtractor.extractValue(PrimitiveValues.longValue((short) 42), 0.0)); + assertEquals(42.0, RelationshipPropertyExtractor.extractValue(PrimitiveValues.longValue(42), 0.0)); + assertEquals(42.0, RelationshipPropertyExtractor.extractValue(PrimitiveValues.longValue(42), 0.0)); + assertEquals(42.0, RelationshipPropertyExtractor.extractValue(PrimitiveValues.floatingPointValue(42.0F), 0.0)); + assertEquals(42.0, RelationshipPropertyExtractor.extractValue(PrimitiveValues.floatingPointValue(42.0D), 0.0)); + assertTrue(Double.isNaN(RelationshipPropertyExtractor.extractValue(PrimitiveValues.floatingPointValue(Float.NaN), 0.0))); + assertTrue(Double.isNaN(RelationshipPropertyExtractor.extractValue(PrimitiveValues.floatingPointValue(Double.NaN), 0.0))); } @ParameterizedTest @EnumSource(value = Aggregation.class, names = "COUNT", mode = EnumSource.Mode.EXCLUDE) void extractValueReadsAnyNumericTypeWithAggregationExceptCount(Aggregation aggregation) { - assertEquals(42.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.byteValue((byte) 42), 0.0)); - assertEquals(42.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.shortValue((short) 42), 0.0)); - assertEquals(42.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.intValue(42), 0.0)); - assertEquals(42.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.longValue(42), 0.0)); - assertEquals(42.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.floatValue(42.0F), 0.0)); - assertEquals(42.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.doubleValue(42.0D), 0.0)); - assertTrue(Double.isNaN(RelationshipPropertyExtractor.extractValue(aggregation, Values.floatValue(Float.NaN), 0.0))); - assertTrue(Double.isNaN(RelationshipPropertyExtractor.extractValue(aggregation, Values.doubleValue(Double.NaN), 0.0))); + assertEquals(42.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.longValue((byte) 42), 0.0)); + assertEquals(42.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.longValue((short) 42), 0.0)); + assertEquals(42.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.longValue(42), 0.0)); + assertEquals(42.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.longValue(42), 0.0)); + assertEquals(42.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.floatingPointValue(42.0F), 0.0)); + assertEquals(42.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.floatingPointValue(42.0D), 0.0)); + assertTrue(Double.isNaN(RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.floatingPointValue(Float.NaN), 0.0))); + assertTrue(Double.isNaN(RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.floatingPointValue(Double.NaN), 0.0))); } @ParameterizedTest @EnumSource(value = Aggregation.class, names = "COUNT", mode = EnumSource.Mode.INCLUDE) void extractValueReadsAnyNumericTypeWithCountAggregation(Aggregation aggregation) { - assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.byteValue((byte) 42), 0.0)); - assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.shortValue((short) 42), 0.0)); - assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.intValue(42), 0.0)); - assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.longValue(42), 0.0)); - assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.floatValue(42.0F), 0.0)); - assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.doubleValue(42.0D), 0.0)); - assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.floatValue(Float.NaN), 0.0)); - assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.doubleValue(Double.NaN), 0.0)); + assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.longValue((byte) 42), 0.0)); + assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.longValue((short) 42), 0.0)); + assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.longValue(42), 0.0)); + assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.longValue(42), 0.0)); + assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.floatingPointValue(42.0F), 0.0)); + assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.floatingPointValue(42.0D), 0.0)); + assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.floatingPointValue(Float.NaN), 0.0)); + assertEquals(1.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.floatingPointValue(Double.NaN), 0.0)); } @Test void extractValueReturnsDefaultWhenValueDoesNotExist() { - assertEquals(42.0, RelationshipPropertyExtractor.extractValue(Values.NO_VALUE, 42.0)); + assertEquals(42.0, RelationshipPropertyExtractor.extractValue(PrimitiveValues.NO_VALUE, 42.0)); } @ParameterizedTest @EnumSource(value = Aggregation.class, names = "COUNT", mode = EnumSource.Mode.EXCLUDE) void extractValueReturnsDefaultWhenValueDoesNotExistForAggregationsExceptCount(Aggregation aggregation) { - assertEquals(42.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.NO_VALUE, 42.0)); + assertEquals(42.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.NO_VALUE, 42.0)); } @ParameterizedTest @EnumSource(value = Aggregation.class, names = "COUNT", mode = EnumSource.Mode.INCLUDE) void extractValueReturnsZeroWhenValueDoesNotExistForCountAggregation(Aggregation aggregation) { - assertEquals(0.0, RelationshipPropertyExtractor.extractValue(aggregation, Values.NO_VALUE, 42.0)); - } - - @ParameterizedTest - @MethodSource("invalidProperties") - void extractValueFailsForNonNumericTypes(Value value, String typePart, String valuePart) { - IllegalArgumentException exception = assertThrows( - IllegalArgumentException.class, - () -> RelationshipPropertyExtractor.extractValue(value, 42.0) - ); - String expectedErrorMessage = formatWithLocale( - "Unsupported type [%s] of value %s. Please use a numeric property.", - typePart, - valuePart - ); - assertEquals(expectedErrorMessage, exception.getMessage()); - } - - @ParameterizedTest - @MethodSource("invalidPropertyAndAnyAggregation") - void extractValueFailsForNonNumericTypesAndAggregation(Value value, String typePart, String valuePart, Aggregation aggregation) { - IllegalArgumentException exception = assertThrows( - IllegalArgumentException.class, - () -> RelationshipPropertyExtractor.extractValue(aggregation, value, 42.0) - ); - String expectedErrorMessage = formatWithLocale( - "Unsupported type [%s] of value %s. Please use a numeric property.", - typePart, - valuePart - ); - assertEquals(expectedErrorMessage, exception.getMessage()); - } - - static Stream invalidProperties() { - return Stream.of( - arguments(Values.booleanValue(true), "BOOLEAN", "Boolean('true')"), - arguments(Values.stringValue("42"), "TEXT", "String(\"42\")") - ); - } - - static Stream invalidPropertyAndAnyAggregation() { - return crossArguments(RelationshipPropertyExtractorTest::invalidProperties, () -> Arrays.stream(Aggregation.values()).map(Arguments::of)); + assertEquals(0.0, RelationshipPropertyExtractor.extractValue(aggregation, PrimitiveValues.NO_VALUE, 42.0)); } }