Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JAVA-3061 Re-introduce an improved CqlVector, add support for accessing vectors directly as float arrays #1666

Merged
merged 13 commits into from
Jul 8, 2023
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.datastax.oss.driver.api.core.data;

import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
import com.datastax.oss.driver.shaded.guava.common.base.Predicates;
import com.datastax.oss.driver.shaded.guava.common.base.Splitter;
import com.datastax.oss.driver.shaded.guava.common.collect.Iterables;
import com.datastax.oss.driver.shaded.guava.common.collect.Streams;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Representation of a vector as defined in CQL.
*
* <p>A CQL vector is a fixed-length array of non-null numeric values. These properties don't map
* cleanly to an existing class in the standard JDK Collections hierarchy so we provide this value
* object instead. Like other value object collections returned by the driver instances of this
* class are not immutable; think of these value objects as a representation of a vector stored in
* the database as an initial step in some additional computation.
*
* <p>While we don't implement any Collection APIs we do implement Iterable. We also attempt to play
* nice with the Streams API in order to better facilitate integration with data pipelines. Finally,
* where possible we've tried to make the API of this class similar to the equivalent methods on
* {@link List}.
*/
public class CqlVector<T extends Number> implements Iterable<T> {

/**
* Create a new CqlVector containing the specified values.
*
* @param vals the collection of values to wrap.
* @return a CqlVector wrapping those values
*/
public static <V extends Number> CqlVector<V> newInstance(V... vals) {

// Note that Array.asList() guarantees the return of an array which implements RandomAccess
return new CqlVector(Arrays.asList(vals));
}

/**
* Create a new CqlVector that "wraps" an existing ArrayList. Modifications to the passed
* ArrayList will also be reflected in the returned CqlVector.
hhughes marked this conversation as resolved.
Show resolved Hide resolved
*
* @param list the collection of values to wrap.
* @return a CqlVector wrapping those values
*/
public static <V extends Number> CqlVector<V> newInstance(List<V> list) {
Preconditions.checkArgument(list != null, "Input list should not be null");
return new CqlVector(list);
}

private final List<T> list;

private CqlVector(@NonNull List<T> list) {

Preconditions.checkArgument(
Iterables.all(list, Predicates.notNull()), "CqlVectors cannot contain null values");
this.list = list;
}

/**
* Retrieve the value at the specified index. Modelled after {@link List#get(int)}
*
* @param idx the index to retrieve
* @return the value at the specified index
*/
public T get(int idx) {
return list.get(idx);
}

/**
* Update the value at the specified index. Modelled after {@link List#set(int, Object)}
*
* @param idx the index to set
* @param val the new value for the specified index
* @return the old value for the specified index
*/
public T set(int idx, T val) {
return list.set(idx, val);
}

/**
* Return the size of this vector. Modelled after {@link List#size()}
*
* @return the vector size
*/
public int size() {
return this.list.size();
}

/**
* Return a CqlVector consisting of the contents of a portion of this vector. Modelled after
* {@link List#subList(int, int)}
*
* @param from the index to start from (inclusive)
* @param to the index to end on (exclusive)
* @return a new CqlVector wrapping the sublist
*/
public CqlVector<T> subVector(int from, int to) {
return new CqlVector<T>(this.list.subList(from, to));
}

/**
* Create an {@link Iterator} for this vector
*
* @return the generated iterator
*/
@Override
public Iterator<T> iterator() {
return this.list.iterator();
}

/**
* Create a {@link Stream} of the values in this vector
*
* @return the Stream instance
*/
public Stream<T> stream() {
return this.list.stream();
}

@Override
public boolean equals(Object o) {
if (o == this) {
return true;
} else if (o instanceof CqlVector) {
CqlVector that = (CqlVector) o;
return this.list.equals(that.list);
} else {
return false;
}
}

@Override
public int hashCode() {
return Objects.hash(list);
}

@Override
public String toString() {
return Iterables.toString(this.list);
}

/**
* Create a new CqlVector instance from the specified string representation. Note that this method
* is intended to mirror {@link #toString()}; calling this method from a <code>toString</code>
* call on some CqlVector should return a CqlVector that is equal to the origin instance.
*
* @param str a String representation of a CqlVector
* @param subtypeCodec
* @return
*/
public CqlVector<T> from(@NonNull String str, @NonNull TypeCodec<T> subtypeCodec) {
ArrayList<T> vals =
Streams.stream(Splitter.on(", ").split(str.substring(1, str.length() - 1)))
.map(subtypeCodec::parse)
.collect(Collectors.toCollection(ArrayList::new));
return CqlVector.newInstance(vals);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ default CqlDuration getCqlDuration(@NonNull CqlIdentifier id) {
* @throws IllegalArgumentException if the id is invalid.
*/
@Nullable
default <ElementT> List<ElementT> getVector(
default <ElementT extends Number> CqlVector<ElementT> getVector(
@NonNull CqlIdentifier id, @NonNull Class<ElementT> elementsClass) {
return getVector(firstIndexOf(id), elementsClass);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,9 @@ default CqlDuration getCqlDuration(int i) {
* @throws IndexOutOfBoundsException if the index is invalid.
*/
@Nullable
default <ElementT> List<ElementT> getVector(int i, @NonNull Class<ElementT> elementsClass) {
return get(i, GenericType.listOf(elementsClass));
default <ElementT extends Number> CqlVector<ElementT> getVector(
int i, @NonNull Class<ElementT> elementsClass) {
return get(i, GenericType.vectorOf(elementsClass));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,9 @@ default CqlDuration getCqlDuration(@NonNull String name) {
* @throws IllegalArgumentException if the name is invalid.
*/
@Nullable
default <ElementT> List<ElementT> getVector(
default <ElementT extends Number> CqlVector<ElementT> getVector(
@NonNull String name, @NonNull Class<ElementT> elementsClass) {
return getList(firstIndexOf(name), elementsClass);
return getVector(firstIndexOf(name), elementsClass);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,9 @@ default SelfT setCqlDuration(@NonNull CqlIdentifier id, @Nullable CqlDuration v)
*/
@NonNull
@CheckReturnValue
default <ElementT> SelfT setVector(
default <ElementT extends Number> SelfT setVector(
@NonNull CqlIdentifier id,
@Nullable List<ElementT> v,
@Nullable CqlVector<ElementT> v,
@NonNull Class<ElementT> elementsClass) {
SelfT result = null;
for (Integer i : allIndicesOf(id)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,9 @@ default SelfT setCqlDuration(int i, @Nullable CqlDuration v) {
*/
@NonNull
@CheckReturnValue
default <ElementT> SelfT setVector(
int i, @Nullable List<ElementT> v, @NonNull Class<ElementT> elementsClass) {
return set(i, v, GenericType.listOf(elementsClass));
default <ElementT extends Number> SelfT setVector(
int i, @Nullable CqlVector<ElementT> v, @NonNull Class<ElementT> elementsClass) {
return set(i, v, GenericType.vectorOf(elementsClass));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,10 @@ default SelfT setCqlDuration(@NonNull String name, @Nullable CqlDuration v) {
*/
@NonNull
@CheckReturnValue
default <ElementT> SelfT setVector(
@NonNull String name, @Nullable List<ElementT> v, @NonNull Class<ElementT> elementsClass) {
default <ElementT extends Number> SelfT setVector(
@NonNull String name,
@Nullable CqlVector<ElementT> v,
@NonNull Class<ElementT> elementsClass) {
SelfT result = null;
for (Integer i : allIndicesOf(name)) {
result = (result == null ? this : result).setVector(i, v, elementsClass);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.datastax.oss.driver.api.core.type.codec;

import com.datastax.oss.driver.api.core.data.CqlDuration;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.data.TupleValue;
import com.datastax.oss.driver.api.core.data.UdtValue;
import com.datastax.oss.driver.api.core.type.CustomType;
Expand Down Expand Up @@ -207,7 +208,7 @@ public static TypeCodec<TupleValue> tupleOf(@NonNull TupleType cqlType) {
return new TupleCodec(cqlType);
}

public static <SubtypeT> TypeCodec<List<SubtypeT>> vectorOf(
public static <SubtypeT extends Number> TypeCodec<CqlVector<SubtypeT>> vectorOf(
@NonNull VectorType type, @NonNull TypeCodec<SubtypeT> subtypeCodec) {
return new VectorCodec(
DataTypes.vectorOf(subtypeCodec.getCqlType(), type.getDimensions()), subtypeCodec);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
*/
package com.datastax.oss.driver.api.core.type.reflect;

import com.datastax.oss.driver.api.core.data.CqlDuration;
import com.datastax.oss.driver.api.core.data.GettableByIndex;
import com.datastax.oss.driver.api.core.data.TupleValue;
import com.datastax.oss.driver.api.core.data.UdtValue;
import com.datastax.oss.driver.api.core.data.*;
absurdfarce marked this conversation as resolved.
Show resolved Hide resolved
import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
import com.datastax.oss.driver.shaded.guava.common.primitives.Primitives;
import com.datastax.oss.driver.shaded.guava.common.reflect.TypeParameter;
Expand Down Expand Up @@ -147,6 +144,23 @@ public static <T> GenericType<Set<T>> setOf(@NonNull GenericType<T> elementType)
return new GenericType<>(token);
}

@NonNull
public static <T extends Number> GenericType<CqlVector<T>> vectorOf(
@NonNull Class<T> elementType) {
TypeToken<CqlVector<T>> token =
new TypeToken<CqlVector<T>>() {}.where(
new TypeParameter<T>() {}, TypeToken.of(elementType));
return new GenericType<>(token);
}

@NonNull
public static <T extends Number> GenericType<CqlVector<T>> vectorOf(
@NonNull GenericType<T> elementType) {
TypeToken<CqlVector<T>> token =
new TypeToken<CqlVector<T>>() {}.where(new TypeParameter<T>() {}, elementType.token);
return new GenericType<>(token);
}

@NonNull
public static <K, V> GenericType<Map<K, V>> mapOf(
@NonNull Class<K> keyType, @NonNull Class<V> valueType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.datastax.oss.driver.internal.core.type.codec;

import com.datastax.oss.driver.api.core.ProtocolVersion;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.type.DataType;
import com.datastax.oss.driver.api.core.type.VectorType;
import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
Expand All @@ -32,21 +33,21 @@
import java.util.NoSuchElementException;
import java.util.stream.Collectors;

public class VectorCodec<SubtypeT> implements TypeCodec<List<SubtypeT>> {
public class VectorCodec<SubtypeT extends Number> implements TypeCodec<CqlVector<SubtypeT>> {

private final VectorType cqlType;
private final GenericType<List<SubtypeT>> javaType;
private final GenericType<CqlVector<SubtypeT>> javaType;
private final TypeCodec<SubtypeT> subtypeCodec;

public VectorCodec(VectorType cqlType, TypeCodec<SubtypeT> subtypeCodec) {
this.cqlType = cqlType;
this.subtypeCodec = subtypeCodec;
this.javaType = GenericType.listOf(subtypeCodec.getJavaType());
this.javaType = GenericType.vectorOf(subtypeCodec.getJavaType());
}

@NonNull
@Override
public GenericType<List<SubtypeT>> getJavaType() {
public GenericType<CqlVector<SubtypeT>> getJavaType() {
return this.javaType;
}

Expand All @@ -59,7 +60,7 @@ public DataType getCqlType() {
@Nullable
@Override
public ByteBuffer encode(
@Nullable List<SubtypeT> value, @NonNull ProtocolVersion protocolVersion) {
@Nullable CqlVector<SubtypeT> value, @NonNull ProtocolVersion protocolVersion) {
if (value == null || cqlType.getDimensions() <= 0) {
return null;
}
Expand Down Expand Up @@ -103,7 +104,7 @@ public ByteBuffer encode(

@Nullable
@Override
public List<SubtypeT> decode(
public CqlVector<SubtypeT> decode(
@Nullable ByteBuffer bytes, @NonNull ProtocolVersion protocolVersion) {
if (bytes == null || bytes.remaining() == 0) {
return null;
Expand Down Expand Up @@ -133,27 +134,29 @@ Elements should at least precede themselves with their size (along the lines of
/* Restore the input ByteBuffer to its original state */
bytes.rewind();

return rv;
return CqlVector.newInstance(rv);
}

@NonNull
@Override
public String format(@Nullable List<SubtypeT> value) {
public String format(@Nullable CqlVector<SubtypeT> value) {
return value == null ? "NULL" : Iterables.toString(value);
}

@Nullable
@Override
public List<SubtypeT> parse(@Nullable String value) {
public CqlVector<SubtypeT> parse(@Nullable String value) {
return (value == null || value.isEmpty() || value.equalsIgnoreCase("NULL"))
? null
: this.from(value);
}

private List<SubtypeT> from(@Nullable String value) {
private CqlVector<SubtypeT> from(@Nullable String value) {

return Streams.stream(Splitter.on(", ").split(value.substring(1, value.length() - 1)))
.map(subtypeCodec::parse)
.collect(Collectors.toCollection(ArrayList::new));
ArrayList<SubtypeT> vals =
Streams.stream(Splitter.on(", ").split(value.substring(1, value.length() - 1)))
.map(subtypeCodec::parse)
.collect(Collectors.toCollection(ArrayList::new));
return CqlVector.newInstance(vals);
}
}
Loading