Skip to content

Commit

Permalink
[Internal] ContainerBuilder: Fixes Builder to Set Full Text Policy (#…
Browse files Browse the repository at this point in the history
…4852)

# Pull Request Template

## Description

This PR fixes an issue in the `ContainerBuilder` that was preventing
setting the Full Text Policy in Container Properties.

## Type of change

Please delete options that are not relevant.

- [x] Bug fix (non-breaking change which fixes an issue)

## Closing issues

To automatically close an issue: closes #4820
  • Loading branch information
kundadebdatta authored Oct 25, 2024
1 parent 4a70bc3 commit 018dd20
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ public async Task<ContainerResponse> CreateIfNotExistsAsync(
containerProperties.VectorEmbeddingPolicy = this.vectorEmbeddingPolicy;
}

if (this.fullTextPolicy != null)
{
containerProperties.FullTextPolicy = this.fullTextPolicy;
}

return containerProperties;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ await databaseForVectorEmbedding.DefineContainer(containerName, partitionKeyPath
Assert.AreEqual(fullTextPaths.Count, containerSettings.IndexingPolicy.FullTextIndexes.Count());
Assert.AreEqual(fullTextPath1, containerSettings.IndexingPolicy.FullTextIndexes[0].Path);
Assert.AreEqual(fullTextPath2, containerSettings.IndexingPolicy.FullTextIndexes[1].Path);
Assert.AreEqual(fullTextPath1, containerSettings.IndexingPolicy.FullTextIndexes[2].Path);
Assert.AreEqual(fullTextPath3, containerSettings.IndexingPolicy.FullTextIndexes[2].Path);
}
finally
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace Microsoft.Azure.Cosmos.Tests.Fluent
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Fluent;
Expand Down Expand Up @@ -337,6 +338,184 @@ await containerFluentDefinitionForCreate
It.IsAny<CancellationToken>()), Times.Once);
}

[TestMethod]
public async Task ValidateFullTextPolicyAndIndexUsingContainerBuilder()
{
string defaultLanguage = "en-US", fullTextPath1 = "/fts1", fullTextPath2 = "/fts2", fullTextPath3 = "/fts3";

Collection<FullTextPath> fullTextPaths = new Collection<FullTextPath>()
{
new Cosmos.FullTextPath()
{
Path = fullTextPath1,
Language = "en-US",
},
new Cosmos.FullTextPath()
{
Path = fullTextPath2,
Language = "en-US",
},
new Cosmos.FullTextPath()
{
Path = fullTextPath3,
Language = "en-US",
},
};

Mock<ContainerResponse> mockContainerResponse = new Mock<ContainerResponse>();
mockContainerResponse
.Setup(x => x.StatusCode)
.Returns(HttpStatusCode.Created);

Mock<Database> mockContainers = new Mock<Database>();
Mock<CosmosClient> mockClient = new Mock<CosmosClient>();
mockContainers.Setup(m => m.Client).Returns(mockClient.Object);
mockContainers
.Setup(c => c.CreateContainerAsync(
It.IsAny<ContainerProperties>(),
It.IsAny<int?>(),
It.IsAny<RequestOptions>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(mockContainerResponse.Object);
mockContainers
.Setup(c => c.Id)
.Returns(Guid.NewGuid().ToString());

ContainerBuilder containerFluentDefinitionForCreate = new ContainerBuilder(
mockContainers.Object,
containerName,
partitionKey);

ContainerResponse response = await containerFluentDefinitionForCreate
.WithFullTextPolicy(
defaultLanguage: defaultLanguage,
fullTextPaths: fullTextPaths)
.Attach()
.WithIndexingPolicy()
.WithFullTextIndex()
.Path(fullTextPath1)
.Attach()
.WithFullTextIndex()
.Path(fullTextPath2)
.Attach()
.WithFullTextIndex()
.Path(fullTextPath3)
.Attach()
.Attach()
.CreateAsync();

Assert.AreEqual(HttpStatusCode.Created, response.StatusCode);
mockContainers.Verify(c => c.CreateContainerAsync(
It.Is<ContainerProperties>((settings) => settings.FullTextPolicy.FullTextPaths.Count == 3
&& fullTextPath1.Equals(settings.FullTextPolicy.FullTextPaths[0].Path)
&& "en-US".Equals(settings.FullTextPolicy.FullTextPaths[0].Language)
&& fullTextPath2.Equals(settings.FullTextPolicy.FullTextPaths[1].Path)
&& "en-US".Equals(settings.FullTextPolicy.FullTextPaths[1].Language)
&& fullTextPath3.Equals(settings.FullTextPolicy.FullTextPaths[2].Path)
&& "en-US".Equals(settings.FullTextPolicy.FullTextPaths[2].Language)
&& fullTextPath1.Equals(settings.IndexingPolicy.FullTextIndexes[0].Path)
&& fullTextPath2.Equals(settings.IndexingPolicy.FullTextIndexes[1].Path)
&& fullTextPath3.Equals(settings.IndexingPolicy.FullTextIndexes[2].Path)),
It.IsAny<int?>(),
It.IsAny<RequestOptions>(),
It.IsAny<CancellationToken>()), Times.Once);
}

[TestMethod]
public async Task ValidateVectorEmbeddingsAndIndexingPolicyUsingContainerBuilder()
{
string vector1Path = "/vector1", vector2Path = "/vector2", vector3Path = "/vector3";

Collection<Embedding> embeddings = new Collection<Embedding>()
{
new ()
{
Path = vector1Path,
DataType = VectorDataType.Int8,
DistanceFunction = DistanceFunction.DotProduct,
Dimensions = 1200,
},
new ()
{
Path = vector2Path,
DataType = VectorDataType.Uint8,
DistanceFunction = DistanceFunction.Cosine,
Dimensions = 3,
},
new ()
{
Path = vector3Path,
DataType = VectorDataType.Float32,
DistanceFunction = DistanceFunction.Euclidean,
Dimensions = 400,
},
};

Mock<ContainerResponse> mockContainerResponse = new Mock<ContainerResponse>();
mockContainerResponse
.Setup(x => x.StatusCode)
.Returns(HttpStatusCode.Created);

Mock<Database> mockContainers = new Mock<Database>();
Mock<CosmosClient> mockClient = new Mock<CosmosClient>();
mockContainers.Setup(m => m.Client).Returns(mockClient.Object);
mockContainers
.Setup(c => c.CreateContainerAsync(
It.IsAny<ContainerProperties>(),
It.IsAny<int?>(),
It.IsAny<RequestOptions>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(mockContainerResponse.Object);
mockContainers
.Setup(c => c.Id)
.Returns(Guid.NewGuid().ToString());

ContainerBuilder containerFluentDefinitionForCreate = new ContainerBuilder(
mockContainers.Object,
containerName,
partitionKey);

ContainerResponse response = await containerFluentDefinitionForCreate
.WithVectorEmbeddingPolicy(embeddings)
.Attach()
.WithIndexingPolicy()
.WithVectorIndex()
.Path(vector1Path, VectorIndexType.Flat)
.Attach()
.WithVectorIndex()
.Path(vector2Path, VectorIndexType.QuantizedFlat)
.WithQuantizationByteSize(3)
.WithVectorIndexShardKey(new string[] { "/Country" })
.Attach()
.WithVectorIndex()
.Path(vector3Path, VectorIndexType.DiskANN)
.WithQuantizationByteSize(2)
.WithIndexingSearchListSize(5)
.WithVectorIndexShardKey(new string[] { "/ZipCode" })
.Attach()
.Attach()
.CreateAsync();

Assert.AreEqual(HttpStatusCode.Created, response.StatusCode);
mockContainers.Verify(c => c.CreateContainerAsync(
It.Is<ContainerProperties>((settings) => settings.VectorEmbeddingPolicy.Embeddings.Count == 3
&& vector1Path.Equals(settings.VectorEmbeddingPolicy.Embeddings[0].Path)
&& VectorDataType.Int8.Equals(settings.VectorEmbeddingPolicy.Embeddings[0].DataType)
&& DistanceFunction.DotProduct.Equals(settings.VectorEmbeddingPolicy.Embeddings[0].DistanceFunction)
&& 1200.Equals(settings.VectorEmbeddingPolicy.Embeddings[0].Dimensions)
&& vector2Path.Equals(settings.VectorEmbeddingPolicy.Embeddings[1].Path)
&& VectorDataType.Uint8.Equals(settings.VectorEmbeddingPolicy.Embeddings[1].DataType)
&& DistanceFunction.Cosine.Equals(settings.VectorEmbeddingPolicy.Embeddings[1].DistanceFunction)
&& 3.Equals(settings.VectorEmbeddingPolicy.Embeddings[1].Dimensions)
&& vector3Path.Equals(settings.VectorEmbeddingPolicy.Embeddings[2].Path)
&& VectorDataType.Float32.Equals(settings.VectorEmbeddingPolicy.Embeddings[2].DataType)
&& DistanceFunction.Euclidean.Equals(settings.VectorEmbeddingPolicy.Embeddings[2].DistanceFunction)
&& 400.Equals(settings.VectorEmbeddingPolicy.Embeddings[2].Dimensions)),
It.IsAny<int?>(),
It.IsAny<RequestOptions>(),
It.IsAny<CancellationToken>()), Times.Once);
}

private static CosmosClientContext GetContext()
{
Mock<CosmosClientContext> cosmosClientContext = new Mock<CosmosClientContext>();
Expand Down

0 comments on commit 018dd20

Please sign in to comment.