Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
Signed-off-by: Yury-Fridlyand <[email protected]>
  • Loading branch information
Yury-Fridlyand committed Sep 26, 2024
1 parent 9e3ce53 commit 7fe92e7
Show file tree
Hide file tree
Showing 5 changed files with 258 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@ public interface VectorSearchBaseCommands {
* @return <code>OK</code>.
* @example
* <pre>{@code
* // TODO
* // Create an index for vectors of size 2:
* client.ftcreate("hash_idx1", IndexType.HASH, new String[] {"hash:"}, new FieldInfo[] {
* new FieldInfo("vec", "VEC", VectorFieldFlat.builder(DistanceMetric.L2, 2).build())
* }).get();
* // Create a 6-dimensional JSON index using the HNSW algorithm:
* client.ftcreate("json_idx1", IndexType.JSON, new String[] {"json:"}, new FieldInfo[] {
* new FieldInfo("$.vec", "VEC", VectorFieldHnsw.builder(DistanceMetric.L2, 6).numberOfEdges(32).build())
* }).get();
* }</pre>
*/
CompletableFuture<String> ftcreate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import java.util.Optional;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.RequiredArgsConstructor;

// TODO examples
public class FTCreateOptions {
/** Kind of dataset an index is built over. */
public enum IndexType {
    /** The data is stored in hashes; field identifiers are field names within those hashes. */
    HASH,
    /** The data is stored in JSONs; field identifiers are JSON Path expressions. */
    JSON
}

Expand Down Expand Up @@ -60,6 +62,22 @@ public static class TagField implements Field {
private Optional<Character> separator;
private final boolean caseSensitive;

/**
 * Create a <code>TAG</code> field without a separator and with the case-sensitive flag turned
 * off.
 */
public TagField() {
    caseSensitive = false;
    separator = Optional.empty();
}

/**
 * Create a <code>TAG</code> field that uses the given separator; the case-sensitive flag is
 * turned off.
 *
 * @param separator The tag separator.
 */
public TagField(char separator) {
    caseSensitive = false;
    this.separator = Optional.of(separator);
}

/**
* Create a <code>TAG</code> field.
*
Expand Down Expand Up @@ -95,24 +113,35 @@ public String[] toArgs() {
}
}

/** Algorithm used by a vector index. */
public enum Algorithm {
    /**
     * Hierarchical Navigable Small World: an approximate nearest neighbors algorithm that uses a
     * multi-layered graph.
     */
    HNSW,
    /**
     * Brute force, linear processing of each vector in the index. It yields exact answers within
     * the bounds of the precision of the distance computations. Because the index is processed
     * linearly, run times for this algorithm can be very high for large indexes.
     */
    FLAT
}

/**
 * Distance metrics used to measure the degree of similarity between two vectors.<br>
 * Each metric calculates a distance between two vectors: the smaller the value is, the closer
 * the two vectors are in the vector space.
 */
public enum DistanceMetric {
    /** Euclidean distance between two vectors. */
    L2,
    /** Inner product of two vectors. */
    IP,
    /** Cosine distance of two vectors. */
    COSINE
}

/** Superclass for vector field implementations, contains common logic. */
@AllArgsConstructor(access = AccessLevel.PROTECTED)
abstract static class VectorField implements Field {
private final Map<String, String> params;
Expand All @@ -123,7 +152,7 @@ public String[] toArgs() {
var args = new ArrayList<String>();
args.add("VECTOR");
args.add(Algorithm);
args.add(Integer.toString(params.size()));
args.add(Integer.toString(params.size() * 2));
params.forEach(
(name, value) -> {
args.add(name);
Expand All @@ -147,7 +176,8 @@ protected VectorFieldHnsw(Map<String, String> params) {
/**
* Init a {@link VectorFieldHnsw}'s builder.
*
* @param distanceMetric {@link DistanceMetric}
* @param distanceMetric {@link DistanceMetric} to measure the degree of similarity between two
* vectors.
* @param dimensions Vector dimension, specified as a positive integer. Maximum: 32768
*/
public static VectorFieldHnswBuilder builder(
Expand All @@ -157,7 +187,7 @@ public static VectorFieldHnswBuilder builder(
}

public static class VectorFieldHnswBuilder extends VectorFieldBuilder<VectorFieldHnswBuilder> {
public VectorFieldHnswBuilder(DistanceMetric distanceMetric, int dimensions) {
VectorFieldHnswBuilder(DistanceMetric distanceMetric, int dimensions) {
super(distanceMetric, dimensions);
}

Expand Down Expand Up @@ -210,7 +240,8 @@ protected VectorFieldFlat(Map<String, String> params) {
/**
* Init a {@link VectorFieldFlat}'s builder.
*
* @param distanceMetric {@link DistanceMetric}
* @param distanceMetric {@link DistanceMetric} to measure the degree of similarity between two
* vectors.
* @param dimensions Vector dimension, specified as a positive integer. Maximum: 32768
*/
public static VectorFieldFlatBuilder builder(
Expand All @@ -220,7 +251,7 @@ public static VectorFieldFlatBuilder builder(
}

public static class VectorFieldFlatBuilder extends VectorFieldBuilder<VectorFieldFlatBuilder> {
public VectorFieldFlatBuilder(DistanceMetric distanceMetric, int dimensions) {
VectorFieldFlatBuilder(DistanceMetric distanceMetric, int dimensions) {
super(distanceMetric, dimensions);
}

Expand All @@ -233,7 +264,7 @@ public VectorFieldFlat build() {
abstract static class VectorFieldBuilder<T extends VectorFieldBuilder<T>> {
protected final Map<String, String> params = new HashMap<>();

public VectorFieldBuilder(DistanceMetric distanceMetric, int dimensions) {
VectorFieldBuilder(DistanceMetric distanceMetric, int dimensions) {
params.put("TYPE", "FLOAT32");
params.put("DIM", Integer.toString(dimensions));
params.put("DISTANCE_METRIC", distanceMetric.toString());
Expand All @@ -252,18 +283,37 @@ public T initialCapacity(int initialCapacity) {
public abstract VectorField build();
}

@RequiredArgsConstructor
/** Field definition to be added into index schema. */
public static class FieldInfo {
private final String identifier;
private final String alias;
private final Field field;

/**
 * Field definition to be added into index schema. The alias is left unset (<code>null</code>).
 *
 * @param identifier Field identifier (name).
 * @param field The {@link Field} itself.
 */
public FieldInfo(String identifier, Field field) {
    // Same result as assigning the fields directly: alias stays null.
    this(identifier, null, field);
}

/**
 * Field definition, with an alias, to be added into index schema.
 *
 * @param identifier Field identifier (name).
 * @param alias Field alias.
 * @param field The {@link Field} itself.
 */
public FieldInfo(String identifier, String alias, Field field) {
    this.field = field;
    this.alias = alias;
    this.identifier = identifier;
}

/** Convert to module API. */
public String[] toArgs() {
var args = new ArrayList<String>();
Expand Down
19 changes: 18 additions & 1 deletion java/integTest/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ tasks.register('startStandalone') {
test.dependsOn 'stopAllBeforeTests'
stopAllBeforeTests.finalizedBy 'clearDirs'
clearDirs.finalizedBy 'startStandalone'
clearDirs.finalizedBy 'startCluster'
// clearDirs.finalizedBy 'startCluster'
test.finalizedBy 'stopAllAfterTests'
test.dependsOn ':client:buildRustRelease'

Expand All @@ -122,3 +122,20 @@ tasks.withType(Test) {
logger.quiet "${desc.className}.${desc.name}: ${result.resultType} ${(result.getEndTime() - result.getStartTime())/1000}s"
}
}

// Exclude every test matching 'glide.modules.*' from the `test` task;
// those tests are selected by the `modulesTest` task instead.
test {
filter {
excludeTestsMatching 'glide.modules.*'
}
}

// Separate Test task that runs only the tests matching 'glide.modules.*'.
tasks.register('modulesTest', Test) {
// Set the server port system properties in doFirst, i.e. just before the task's own test action runs.
// NOTE(review): the ports are hard-coded here; presumably servers are expected to be
// already listening on 6379 (standalone) and 7000 (cluster) - confirm.
doFirst {
systemProperty 'test.server.standalone.ports', 6379
systemProperty 'test.server.cluster.ports', 7000
}

filter {
includeTestsMatching 'glide.modules.*'
}
}
170 changes: 170 additions & 0 deletions java/integTest/src/test/java/glide/modules/VectorSearchTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/** Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 */
package glide.modules;

import static glide.TestUtilities.commonClientConfig;
import static glide.TestUtilities.commonClusterClientConfig;
import static glide.api.BaseClient.OK;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import glide.api.BaseClient;
import glide.api.GlideClient;
import glide.api.GlideClusterClient;
import glide.api.models.commands.vss.FTCreateOptions.DistanceMetric;
import glide.api.models.commands.vss.FTCreateOptions.FieldInfo;
import glide.api.models.commands.vss.FTCreateOptions.IndexType;
import glide.api.models.commands.vss.FTCreateOptions.NumericField;
import glide.api.models.commands.vss.FTCreateOptions.TagField;
import glide.api.models.commands.vss.FTCreateOptions.TextField;
import glide.api.models.commands.vss.FTCreateOptions.VectorFieldFlat;
import glide.api.models.commands.vss.FTCreateOptions.VectorFieldHnsw;
import glide.api.models.exceptions.RequestException;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import lombok.Getter;
import lombok.SneakyThrows;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

/**
 * Integration tests for the vector search <code>ftcreate</code> command, executed with both a
 * standalone client and a cluster client.
 */
public class VectorSearchTests {

/**
 * Clients shared by all tests, wrapped as JUnit {@link Arguments}. The Lombok-generated static
 * getter <code>getClients</code> is the method source of the parameterized tests.
 */
@Getter private static List<Arguments> clients;

/**
 * Creates one standalone client and one cluster client (both configured with <code>
 * requestTimeout(5000)</code>) and stores them in {@link #clients}.
 */
@BeforeAll
@SneakyThrows
public static void init() {
var standaloneClient =
GlideClient.createClient(commonClientConfig().requestTimeout(5000).build()).get();

var clusterClient =
GlideClusterClient.createClient(commonClusterClientConfig().requestTimeout(5000).build())
.get();

clients = List.of(Arguments.of(standaloneClient), Arguments.of(clusterClient));
}

/** Closes every client stored in {@link #clients}. */
@AfterAll
@SneakyThrows
public static void teardown() {
for (var client : clients) {
((BaseClient) client.get()[0]).close();
}
}

/**
 * Creates several indices with different schemas and expects <code>OK</code> for each of them,
 * then checks that creating an index with an already used name and creating an index without
 * fields both fail with a {@link RequestException}.
 *
 * @param client The client to run the commands with.
 */
@SneakyThrows
// autoCloseArguments = false: presumably so that JUnit does not close the shared clients after
// an invocation; they are closed in teardown() instead.
@ParameterizedTest(autoCloseArguments = false)
@MethodSource("getClients")
public void ft_create(BaseClient client) {
// create a few simple indices
assertEquals(
OK,
client
.ftcreate(
UUID.randomUUID().toString(),
IndexType.HASH,
new String[0],
new FieldInfo[] {
new FieldInfo("vec", "vec", VectorFieldHnsw.builder(DistanceMetric.L2, 2).build())
})
.get());
assertEquals(
OK,
client
.ftcreate(
UUID.randomUUID().toString(),
IndexType.JSON,
new String[] {"json:"},
new FieldInfo[] {
new FieldInfo(
"$.vec", "VEC", VectorFieldFlat.builder(DistanceMetric.L2, 6).build())
})
.get());

// create an index with an HNSW vector field with additional parameters
assertEquals(
OK,
client
.ftcreate(
UUID.randomUUID().toString(),
IndexType.HASH,
new String[] {"docs:"},
new FieldInfo[] {
new FieldInfo(
"doc_embedding",
VectorFieldHnsw.builder(DistanceMetric.COSINE, 1536)
.numberOfEdges(40)
.vectorsExaminedOnConstruction(250)
.vectorsExaminedOnRuntime(40)
.build())
})
.get());

// create an index with multiple fields
assertEquals(
OK,
client
.ftcreate(
UUID.randomUUID().toString(),
IndexType.HASH,
new String[] {"blog:post:"},
new FieldInfo[] {
new FieldInfo("title", new TextField()),
new FieldInfo("published_at", new NumericField()),
new FieldInfo("category", new TagField())
})
.get());

// create an index with multiple prefixes; its name is reused below for the duplicate check
var name = UUID.randomUUID().toString();
assertEquals(
OK,
client
.ftcreate(
name,
IndexType.HASH,
new String[] {"author:details:", "book:details:"},
new FieldInfo[] {
new FieldInfo("author_id", new TagField()),
new FieldInfo("author_ids", new TagField()),
new FieldInfo("title", new TextField()),
new FieldInfo("name", new TextField())
})
.get());

// create a duplicate index (same name as the index created just above)
var exception =
assertThrows(
ExecutionException.class,
() ->
client
.ftcreate(
name,
IndexType.HASH,
new String[0],
new FieldInfo[] {new FieldInfo("title", new TextField())})
.get());
assertInstanceOf(RequestException.class, exception.getCause());
assertTrue(exception.getMessage().contains("already exists"));

// create an index without fields
exception =
assertThrows(
ExecutionException.class,
() ->
client
.ftcreate(
UUID.randomUUID().toString(),
IndexType.HASH,
new String[0],
new FieldInfo[0])
.get());
assertInstanceOf(RequestException.class, exception.getCause());
assertTrue(exception.getMessage().contains("arguments are missing"));
}
}
2 changes: 0 additions & 2 deletions utils/cluster_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,6 @@ def get_server_command() -> str:
"yes",
"--logfile",
f"{node_folder}/redis.log",
"--protected-mode",
"no"
]
if load_module:
if len(load_module) == 0:
Expand Down

0 comments on commit 7fe92e7

Please sign in to comment.