Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JAVA-3061 Re-introduce an improved CqlVector, add support for accessing vectors directly as float arrays #1666

Merged
merged 13 commits into from
Jul 8, 2023
Merged
65 changes: 64 additions & 1 deletion core/revapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -6887,7 +6887,70 @@
"code": "java.method.removed",
"old": "method <T> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
"justification": "Refactoring in JAVA-3061"
}
},
{
"code": "java.class.removed",
"old": "class com.datastax.oss.driver.api.core.data.CqlVector.Builder<T>",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.method.removed",
"old": "method com.datastax.oss.driver.api.core.data.CqlVector.Builder com.datastax.oss.driver.api.core.data.CqlVector<T>::builder()",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.method.removed",
"old": "method java.lang.Iterable<T> com.datastax.oss.driver.api.core.data.CqlVector<T>::getValues()",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.generics.formalTypeParameterChanged",
"old": "class com.datastax.oss.driver.api.core.data.CqlVector<T>",
"new": "class com.datastax.oss.driver.api.core.data.CqlVector<T extends java.lang.Number>",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.method.parameterTypeChanged",
"old": "parameter <SubtypeT> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(===com.datastax.oss.driver.api.core.type.CqlVectorType===, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
"new": "parameter <SubtypeT extends java.lang.Number> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(===com.datastax.oss.driver.api.core.type.VectorType===, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.method.parameterTypeParameterChanged",
"old": "parameter <SubtypeT> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.CqlVectorType, ===com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>===)",
"new": "parameter <SubtypeT extends java.lang.Number> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.VectorType, ===com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>===)",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.method.returnTypeTypeParametersChanged",
"old": "method <SubtypeT> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.CqlVectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
"new": "method <SubtypeT extends java.lang.Number> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.VectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.generics.formalTypeParameterChanged",
"old": "method <SubtypeT> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.CqlVectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
"new": "method <SubtypeT extends java.lang.Number> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.VectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.method.parameterTypeParameterChanged",
"old": "parameter <T> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(===com.datastax.oss.driver.api.core.type.reflect.GenericType<T>===)",
"new": "parameter <T extends java.lang.Number> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(===com.datastax.oss.driver.api.core.type.reflect.GenericType<T>===)",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.method.returnTypeTypeParametersChanged",
"old": "method <T> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
"new": "method <T extends java.lang.Number> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
"justification": "Refactorings in PR 1666"
},
{
"code": "java.generics.formalTypeParameterChanged",
"old": "method <T> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
"new": "method <T extends java.lang.Number> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
"justification": "Refactorings in PR 1666"
}
]
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.datastax.oss.driver.api.core.data;

import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
import com.datastax.oss.driver.shaded.guava.common.base.Predicates;
import com.datastax.oss.driver.shaded.guava.common.base.Splitter;
import com.datastax.oss.driver.shaded.guava.common.collect.Iterables;
import com.datastax.oss.driver.shaded.guava.common.collect.Streams;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Representation of a vector as defined in CQL.
 *
 * <p>A CQL vector is a fixed-length array of non-null numeric values. These properties don't map
 * cleanly to an existing class in the standard JDK Collections hierarchy, so we provide this value
 * object instead. Like other value object collections returned by the driver, instances of this
 * class are not immutable; think of these value objects as a representation of a vector stored in
 * the database as an initial step in some additional computation.
 *
 * <p>While we don't implement any Collection APIs we do implement {@link Iterable}. We also attempt
 * to play nice with the Streams API in order to better facilitate integration with data pipelines.
 * Finally, where possible we've tried to make the API of this class similar to the equivalent
 * methods on {@link List}.
 *
 * @param <T> the numeric type of the vector's elements
 */
public class CqlVector<T extends Number> implements Iterable<T> {

  /**
   * Create a new CqlVector containing the specified values.
   *
   * @param vals the values to wrap; none of them may be null
   * @return a CqlVector wrapping those values
   * @throws IllegalArgumentException if any of the values is null
   */
  // Safe: the varargs array is only ever exposed through a List<V> view, never as a V[].
  @SafeVarargs
  public static <V extends Number> CqlVector<V> newInstance(V... vals) {

    // Note that Arrays.asList() guarantees the return of a list which implements RandomAccess
    return new CqlVector<>(Arrays.asList(vals));
  }

  /**
   * Create a new CqlVector that "wraps" an existing List. The list is not copied, so modifications
   * to the passed List will also be reflected in the returned CqlVector.
   *
   * @param list the collection of values to wrap; it may not be null nor contain null values
   * @return a CqlVector wrapping those values
   * @throws IllegalArgumentException if the list is null or contains a null value
   */
  public static <V extends Number> CqlVector<V> newInstance(List<V> list) {
    Preconditions.checkArgument(list != null, "Input list should not be null");
    return new CqlVector<>(list);
  }

  /**
   * Create a new CqlVector instance from the specified string representation. Note that this method
   * is intended to mirror {@link #toString()}; passing this method the output from a <code>toString
   * </code> call on some CqlVector should return a CqlVector that is equal to the origin instance.
   * This includes the empty vector, whose representation is {@code "[]"}.
   *
   * <p>The first and last characters of the string are treated as the enclosing delimiters and
   * discarded; the remainder is split on {@code ", "} and each token is handed to the codec.
   *
   * @param str a String representation of a CqlVector
   * @param subtypeCodec the codec used to parse each individual element of the vector
   * @return a new CqlVector built from the String representation
   * @throws IllegalArgumentException if the string is null, empty, too short to hold both
   *     delimiters, or if the codec parses any element to null
   */
  public static <V extends Number> CqlVector<V> from(
      @NonNull String str, @NonNull TypeCodec<V> subtypeCodec) {
    Preconditions.checkArgument(str != null, "Cannot create CqlVector from null string");
    Preconditions.checkArgument(!str.isEmpty(), "Cannot create CqlVector from empty string");
    // A single character cannot contain both delimiters; fail clearly instead of letting
    // substring() throw a StringIndexOutOfBoundsException.
    Preconditions.checkArgument(
        str.length() >= 2, "Cannot create CqlVector from malformed string: %s", str);

    String contents = str.substring(1, str.length() - 1);
    if (contents.isEmpty()) {
      // Splitting an empty string yields a single empty token rather than no tokens, which would
      // then be passed to the codec. Short-circuit so that from(toString()) round-trips for an
      // empty vector.
      return new CqlVector<>(new ArrayList<V>());
    }

    ArrayList<V> vals =
        Streams.stream(Splitter.on(", ").split(contents))
            .map(subtypeCodec::parse)
            .collect(Collectors.toCollection(ArrayList::new));
    return new CqlVector<>(vals);
  }

  /** The backing list; shared with the caller when created via {@link #newInstance(List)}. */
  private final List<T> list;

  private CqlVector(@NonNull List<T> list) {

    Preconditions.checkArgument(
        Iterables.all(list, Predicates.notNull()), "CqlVectors cannot contain null values");
    this.list = list;
  }

  /**
   * Retrieve the value at the specified index. Modelled after {@link List#get(int)}
   *
   * @param idx the index to retrieve
   * @return the value at the specified index
   */
  public T get(int idx) {
    return list.get(idx);
  }

  /**
   * Update the value at the specified index. Modelled after {@link List#set(int, Object)}
   *
   * @param idx the index to set
   * @param val the new value for the specified index
   * @return the old value for the specified index
   */
  public T set(int idx, T val) {
    return list.set(idx, val);
  }

  /**
   * Return the size of this vector. Modelled after {@link List#size()}
   *
   * @return the vector size
   */
  public int size() {
    return this.list.size();
  }

  /**
   * Return a CqlVector consisting of the contents of a portion of this vector. Modelled after
   * {@link List#subList(int, int)}
   *
   * @param from the index to start from (inclusive)
   * @param to the index to end on (exclusive)
   * @return a new CqlVector wrapping the sublist
   */
  public CqlVector<T> subVector(int from, int to) {
    return new CqlVector<>(this.list.subList(from, to));
  }

  /**
   * Return a boolean indicating whether the vector is empty. Modelled after {@link List#isEmpty()}
   *
   * @return true if the vector is empty, false otherwise
   */
  public boolean isEmpty() {
    return this.list.isEmpty();
  }

  /**
   * Create an {@link Iterator} for this vector
   *
   * @return the generated iterator
   */
  @Override
  public Iterator<T> iterator() {
    return this.list.iterator();
  }

  /**
   * Create a {@link Stream} of the values in this vector
   *
   * @return the Stream instance
   */
  public Stream<T> stream() {
    return this.list.stream();
  }

  /**
   * Two vectors are equal when their backing lists are equal, i.e. they hold equal elements in the
   * same order.
   */
  @Override
  public boolean equals(Object o) {
    if (o == this) {
      return true;
    } else if (o instanceof CqlVector) {
      CqlVector<?> that = (CqlVector<?>) o;
      return this.list.equals(that.list);
    } else {
      return false;
    }
  }

  @Override
  public int hashCode() {
    return Objects.hash(list);
  }

  /**
   * Render this vector using the same bracketed, comma-separated format that {@link #from(String,
   * TypeCodec)} accepts.
   */
  @Override
  public String toString() {
    return Iterables.toString(this.list);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ default CqlDuration getCqlDuration(@NonNull CqlIdentifier id) {
* @throws IllegalArgumentException if the id is invalid.
*/
@Nullable
default <ElementT> List<ElementT> getVector(
default <ElementT extends Number> CqlVector<ElementT> getVector(
@NonNull CqlIdentifier id, @NonNull Class<ElementT> elementsClass) {
return getVector(firstIndexOf(id), elementsClass);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,9 @@ default CqlDuration getCqlDuration(int i) {
* @throws IndexOutOfBoundsException if the index is invalid.
*/
@Nullable
default <ElementT> List<ElementT> getVector(int i, @NonNull Class<ElementT> elementsClass) {
return get(i, GenericType.listOf(elementsClass));
default <ElementT extends Number> CqlVector<ElementT> getVector(
int i, @NonNull Class<ElementT> elementsClass) {
return get(i, GenericType.vectorOf(elementsClass));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,9 @@ default CqlDuration getCqlDuration(@NonNull String name) {
* @throws IllegalArgumentException if the name is invalid.
*/
@Nullable
default <ElementT> List<ElementT> getVector(
default <ElementT extends Number> CqlVector<ElementT> getVector(
@NonNull String name, @NonNull Class<ElementT> elementsClass) {
return getList(firstIndexOf(name), elementsClass);
return getVector(firstIndexOf(name), elementsClass);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,9 @@ default SelfT setCqlDuration(@NonNull CqlIdentifier id, @Nullable CqlDuration v)
*/
@NonNull
@CheckReturnValue
default <ElementT> SelfT setVector(
default <ElementT extends Number> SelfT setVector(
@NonNull CqlIdentifier id,
@Nullable List<ElementT> v,
@Nullable CqlVector<ElementT> v,
@NonNull Class<ElementT> elementsClass) {
SelfT result = null;
for (Integer i : allIndicesOf(id)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,9 @@ default SelfT setCqlDuration(int i, @Nullable CqlDuration v) {
*/
@NonNull
@CheckReturnValue
default <ElementT> SelfT setVector(
int i, @Nullable List<ElementT> v, @NonNull Class<ElementT> elementsClass) {
return set(i, v, GenericType.listOf(elementsClass));
default <ElementT extends Number> SelfT setVector(
int i, @Nullable CqlVector<ElementT> v, @NonNull Class<ElementT> elementsClass) {
return set(i, v, GenericType.vectorOf(elementsClass));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,10 @@ default SelfT setCqlDuration(@NonNull String name, @Nullable CqlDuration v) {
*/
@NonNull
@CheckReturnValue
default <ElementT> SelfT setVector(
@NonNull String name, @Nullable List<ElementT> v, @NonNull Class<ElementT> elementsClass) {
default <ElementT extends Number> SelfT setVector(
@NonNull String name,
@Nullable CqlVector<ElementT> v,
@NonNull Class<ElementT> elementsClass) {
SelfT result = null;
for (Integer i : allIndicesOf(name)) {
result = (result == null ? this : result).setVector(i, v, elementsClass);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
package com.datastax.oss.driver.api.core.type.codec;

import com.datastax.oss.driver.api.core.session.SessionBuilder;
import com.datastax.oss.driver.api.core.type.DataTypes;
import com.datastax.oss.driver.api.core.type.codec.registry.MutableCodecRegistry;
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import com.datastax.oss.driver.internal.core.type.DefaultVectorType;
import com.datastax.oss.driver.internal.core.type.codec.SimpleBlobCodec;
import com.datastax.oss.driver.internal.core.type.codec.TimestampCodec;
import com.datastax.oss.driver.internal.core.type.codec.extras.OptionalCodec;
Expand All @@ -36,6 +38,7 @@
import com.datastax.oss.driver.internal.core.type.codec.extras.time.PersistentZonedTimestampCodec;
import com.datastax.oss.driver.internal.core.type.codec.extras.time.TimestampMillisCodec;
import com.datastax.oss.driver.internal.core.type.codec.extras.time.ZonedTimestampCodec;
import com.datastax.oss.driver.internal.core.type.codec.extras.vector.FloatVectorToArrayCodec;
import com.fasterxml.jackson.databind.ObjectMapper;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -479,4 +482,9 @@ public static <T> TypeCodec<T> json(
@NonNull Class<T> javaType, @NonNull ObjectMapper objectMapper) {
return new JsonCodec<>(javaType, objectMapper);
}

/**
 * Builds a new codec that maps CQL float vectors of the specified size to an array of floats.
 *
 * @param dimensions the number of elements in the vector type the codec is created for
 * @return a codec converting between that CQL vector type and {@code float[]}
 */
public static TypeCodec<float[]> floatVectorToArray(int dimensions) {
  // Describe the CQL side first: a vector of FLOAT with the requested number of dimensions.
  DefaultVectorType floatVectorType = new DefaultVectorType(DataTypes.FLOAT, dimensions);
  return new FloatVectorToArrayCodec(floatVectorType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.datastax.oss.driver.api.core.type.codec;

import com.datastax.oss.driver.api.core.data.CqlDuration;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.data.TupleValue;
import com.datastax.oss.driver.api.core.data.UdtValue;
import com.datastax.oss.driver.api.core.type.CustomType;
Expand Down Expand Up @@ -207,12 +208,17 @@ public static TypeCodec<TupleValue> tupleOf(@NonNull TupleType cqlType) {
return new TupleCodec(cqlType);
}

public static <SubtypeT> TypeCodec<List<SubtypeT>> vectorOf(
public static <SubtypeT extends Number> TypeCodec<CqlVector<SubtypeT>> vectorOf(
@NonNull VectorType type, @NonNull TypeCodec<SubtypeT> subtypeCodec) {
return new VectorCodec(
DataTypes.vectorOf(subtypeCodec.getCqlType(), type.getDimensions()), subtypeCodec);
}

/**
 * Builds a new codec that maps a CQL vector with the given number of dimensions to the driver's
 * {@link CqlVector}.
 *
 * <p>The CQL vector type is derived from the CQL type reported by {@code subtypeCodec} combined
 * with {@code dimensions}, so callers do not need to construct a vector type themselves.
 *
 * @param dimensions the number of dimensions of the CQL vector type
 * @param subtypeCodec the codec handling the individual elements of the vector
 * @return a codec for vectors whose elements are handled by {@code subtypeCodec}
 */
public static <SubtypeT extends Number> TypeCodec<CqlVector<SubtypeT>> vectorOf(
int dimensions, @NonNull TypeCodec<SubtypeT> subtypeCodec) {
return new VectorCodec(DataTypes.vectorOf(subtypeCodec.getCqlType(), dimensions), subtypeCodec);
}

/**
* Builds a new codec that maps a CQL user defined type to the driver's {@link UdtValue}, for the
* given type definition.
Expand Down
Loading