Skip to content

Commit

Permalink
Check index once per unique index name during save embeddings. (#387)
Browse files Browse the repository at this point in the history
## Motivation and Context (Why the change? What's the scenario?)

I had the same problem as the submitter of this issue:
#289. For every
embedding during the save process, a call is made to list indexes. Using
Azure AI Search, I was getting rate limit errors, i.e.
`Azure.RequestFailedException with the message "You are sending too many
requests. Please try again later."`

+ Fix Qdrant connector to handle missing collections similarly to other DBs

## High level description (Approach, Design)

Made an adjustment to simply cache indexes that were already
checked/created during the save embeddings process. This skips the
unnecessary call to ListIndexes for every single embedding.

---------

Co-authored-by: spenavagd <[email protected]>
Co-authored-by: Devis Lucato <[email protected]>
  • Loading branch information
3 people authored Apr 25, 2024
1 parent c631a64 commit c0b2eef
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 55 deletions.
3 changes: 3 additions & 0 deletions examples/003-dotnet-Serverless/003-dotnet-Serverless.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
<ItemGroup>
<ProjectReference Include="..\..\service\Core\Core.csproj" />
<ProjectReference Include="..\..\extensions\LlamaSharp\LlamaSharp\LlamaSharp.csproj" />

<PackageReference Include="Microsoft.KernelMemory.MemoryDb.Postgres" Version="0.38.240423.1" Condition="'$(SolutionName)' != 'KernelMemoryDev'" />
<ProjectReference Include="..\..\extensions\Postgres\Postgres\Postgres.csproj" Condition="'$(SolutionName)' == 'KernelMemoryDev'" />
</ItemGroup>

<ItemGroup>
Expand Down
4 changes: 4 additions & 0 deletions examples/003-dotnet-Serverless/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.KernelMemory.AI.OpenAI;
using Microsoft.KernelMemory.ContentStorage.DevTools;
using Microsoft.KernelMemory.MemoryStorage.DevTools;
using Microsoft.KernelMemory.Postgres;

/* Use MemoryServerlessClient to run the default import pipeline
* in the same process, without distributed queues.
Expand All @@ -36,6 +37,7 @@ public static async Task Main()
var searchClientConfig = new SearchClientConfig();
var azDocIntelConfig = new AzureAIDocIntelConfig();
var azureAISearchConfig = new AzureAISearchConfig();
var postgresConfig = new PostgresConfig();

new ConfigurationBuilder()
.AddJsonFile("appsettings.json")
Expand All @@ -47,6 +49,7 @@ public static async Task Main()
.BindSection("KernelMemory:Services:LlamaSharp", llamaConfig)
.BindSection("KernelMemory:Services:AzureAIDocIntel", azDocIntelConfig)
.BindSection("KernelMemory:Services:AzureAISearch", azureAISearchConfig)
.BindSection("KernelMemory:Services:Postgres", postgresConfig)
.BindSection("KernelMemory:Retrieval:SearchClient", searchClientConfig);

s_memory = new KernelMemoryBuilder()
Expand All @@ -60,6 +63,7 @@ public static async Task Main()
// .WithAzureBlobsStorage(new AzureBlobsConfig {...}) // Store files in Azure Blobs
// .WithSimpleVectorDb(SimpleVectorDbConfig.Persistent) // Store memories on disk
// .WithAzureAISearchMemoryDb(azureAISearchConfig) // Store memories in Azure AI Search
// .WithPostgresMemoryDb(postgresConfig) // Store memories in Postgres
// .WithQdrantMemoryDb("http://127.0.0.1:6333") // Store memories in Qdrant
.Build<MemoryServerless>();

Expand Down
22 changes: 18 additions & 4 deletions extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,17 @@ public async Task<string> UpsertAsync(string index, MemoryRecord record, Cancell
var client = this.GetSearchClient(index);
AzureAISearchMemoryRecord localRecord = AzureAISearchMemoryRecord.FromMemoryRecord(record);

await client.IndexDocumentsAsync(
IndexDocumentsBatch.Upload(new[] { localRecord }),
new IndexDocumentsOptions { ThrowOnAnyError = true },
cancellationToken: cancellationToken).ConfigureAwait(false);
try
{
await client.IndexDocumentsAsync(
IndexDocumentsBatch.Upload(new[] { localRecord }),
new IndexDocumentsOptions { ThrowOnAnyError = true },
cancellationToken: cancellationToken).ConfigureAwait(false);
}
catch (RequestFailedException e) when (IsIndexNotFoundException(e))
{
throw new IndexNotFound(e.Message, e);
}

return record.Id;
}
Expand Down Expand Up @@ -400,6 +407,13 @@ private SearchClient GetSearchClient(string index)
return client;
}

private static bool IsIndexNotFoundException(RequestFailedException e)
{
return e.Status == 404
&& e.Message.Contains("index", StringComparison.OrdinalIgnoreCase)
&& e.Message.Contains("not found", StringComparison.OrdinalIgnoreCase);
}

private static void ValidateSchema(MemoryDbSchema schema)
{
schema.Validate(vectorSizeRequired: true);
Expand Down
42 changes: 27 additions & 15 deletions extensions/Postgres/Postgres/Db/PostgresDbClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.MemoryStorage;
using Npgsql;
using NpgsqlTypes;
using Pgvector;
Expand Down Expand Up @@ -289,7 +290,7 @@ public async Task DeleteTableAsync(
this._log.LogTrace("Deleting table. SQL: {0}", cmd.CommandText);
await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
catch (Npgsql.PostgresException e) when (IsNotFoundException(e))
catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
{
this._log.LogTrace("Table not found: {0}", tableName);
}
Expand All @@ -315,12 +316,14 @@ public async Task UpsertAsync(

NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);

await using (connection)
try
{
using NpgsqlCommand cmd = connection.CreateCommand();
await using (connection)
{
using NpgsqlCommand cmd = connection.CreateCommand();

#pragma warning disable CA2100 // SQL reviewed
cmd.CommandText = $@"
cmd.CommandText = $@"
INSERT INTO {tableName}
({this._colId}, {this._colEmbedding}, {this._colTags}, {this._colContent}, {this._colPayload})
VALUES
Expand All @@ -333,16 +336,25 @@ DO UPDATE SET
{this._colPayload} = @payload
";

cmd.Parameters.AddWithValue("@id", record.Id);
cmd.Parameters.AddWithValue("@embedding", record.Embedding);
cmd.Parameters.AddWithValue("@tags", NpgsqlDbType.Array | NpgsqlDbType.Text, record.Tags.ToArray() ?? emptyTags);
cmd.Parameters.AddWithValue("@content", NpgsqlDbType.Text, record.Content ?? EmptyContent);
cmd.Parameters.AddWithValue("@payload", NpgsqlDbType.Jsonb, record.Payload ?? EmptyPayload);
cmd.Parameters.AddWithValue("@id", record.Id);
cmd.Parameters.AddWithValue("@embedding", record.Embedding);
cmd.Parameters.AddWithValue("@tags", NpgsqlDbType.Array | NpgsqlDbType.Text, record.Tags.ToArray() ?? emptyTags);
cmd.Parameters.AddWithValue("@content", NpgsqlDbType.Text, record.Content ?? EmptyContent);
cmd.Parameters.AddWithValue("@payload", NpgsqlDbType.Jsonb, record.Payload ?? EmptyPayload);
#pragma warning restore CA2100

this._log.LogTrace("Upserting record '{0}' in table '{1}'", record.Id, tableName);
this._log.LogTrace("Upserting record '{0}' in table '{1}'", record.Id, tableName);

await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
}
catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
{
throw new IndexNotFound(e.Message, e);
}
catch (Exception e)
{
throw new PostgresException(e.Message, e);
}
}

Expand Down Expand Up @@ -436,7 +448,7 @@ OFFSET @offset
result.Add((this.ReadEntry(dataReader, withEmbeddings), similarity));
}
}
catch (Npgsql.PostgresException e) when (IsNotFoundException(e))
catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
{
this._log.LogTrace("Table not found: {0}", tableName);
}
Expand Down Expand Up @@ -530,7 +542,7 @@ OFFSET @offset
result.Add(this.ReadEntry(dataReader, withEmbeddings));
}
}
catch (Npgsql.PostgresException e) when (IsNotFoundException(e))
catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
{
this._log.LogTrace("Table not found: {0}", tableName);
}
Expand Down Expand Up @@ -572,7 +584,7 @@ public async Task DeleteAsync(
{
await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
catch (Npgsql.PostgresException e) when (IsNotFoundException(e))
catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e))
{
this._log.LogTrace("Table not found: {0}", tableName);
}
Expand Down Expand Up @@ -637,7 +649,7 @@ private string WithTableNamePrefix(string tableName)
return $"{this._tableNamePrefix}{tableName}";
}

private static bool IsNotFoundException(Npgsql.PostgresException e)
private static bool IsTableNotFoundException(Npgsql.PostgresException e)
{
return (e.SqlState == PgErrUndefinedTable || e.Message.Contains("does not exist", StringComparison.OrdinalIgnoreCase));
}
Expand Down
13 changes: 12 additions & 1 deletion extensions/Qdrant/Qdrant/Client/QdrantClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http;
using Microsoft.KernelMemory.MemoryStorage;

namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Client;

Expand Down Expand Up @@ -177,7 +178,11 @@ public async Task UpsertVectorsAsync(
var (response, content) = await this.ExecuteHttpRequestAsync(request, cancellationToken).ConfigureAwait(false);
this.ValidateResponse(response, content, nameof(this.UpsertVectorsAsync));

if (JsonSerializer.Deserialize<UpsertVectorResponse>(content)?.Status != "ok")
UpsertVectorResponse? qdrantResponse = JsonSerializer.Deserialize<UpsertVectorResponse>(content);
ArgumentNullExceptionEx.ThrowIfNull(qdrantResponse, nameof(qdrantResponse), "Qdrant response is NULL");
ArgumentNullExceptionEx.ThrowIfNull(qdrantResponse.Status, nameof(qdrantResponse.Status), "Qdrant response status is NULL");

if (qdrantResponse.Status != "ok")
{
this._log.LogWarning("Vector upserts failed");
}
Expand Down Expand Up @@ -423,6 +428,12 @@ private void ValidateResponse(HttpResponseMessage response, string content, stri
}
else
{
if (response.StatusCode == HttpStatusCode.NotFound && responseContent.Contains("Not found: Collection", StringComparison.OrdinalIgnoreCase))
{
this._log.LogWarning("Qdrant collection not found: {0}, {1}", response.StatusCode, responseContent);
throw new IndexNotFound(responseContent);
}

if (!responseContent.Contains("already exists", StringComparison.OrdinalIgnoreCase))
{
this._log.LogWarning("Qdrant responded with error: {0}, {1}", response.StatusCode, responseContent);
Expand Down
89 changes: 63 additions & 26 deletions extensions/Qdrant/Qdrant/QdrantMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,17 @@ public Task DeleteIndexAsync(
string index,
CancellationToken cancellationToken = default)
{
index = NormalizeIndexName(index);
return this._qdrantClient.DeleteCollectionAsync(index, cancellationToken);
try
{
index = NormalizeIndexName(index);
return this._qdrantClient.DeleteCollectionAsync(index, cancellationToken);
}
catch (IndexNotFound)
{
this._log.LogInformation("Index not found, nothing to delete");
}

return Task.CompletedTask;
}

/// <inheritdoc />
Expand Down Expand Up @@ -139,14 +148,25 @@ public async Task<string> UpsertAsync(
}

Embedding textEmbedding = await this._embeddingGenerator.GenerateEmbeddingAsync(text, cancellationToken).ConfigureAwait(false);
List<(QdrantPoint<DefaultQdrantPayload>, double)> results = await this._qdrantClient.GetSimilarListAsync(
collectionName: index,
target: textEmbedding,
scoreThreshold: minRelevance,
requiredTags: requiredTags,
limit: limit,
withVectors: withEmbeddings,
cancellationToken: cancellationToken).ConfigureAwait(false);

List<(QdrantPoint<DefaultQdrantPayload>, double)> results;
try
{
results = await this._qdrantClient.GetSimilarListAsync(
collectionName: index,
target: textEmbedding,
scoreThreshold: minRelevance,
requiredTags: requiredTags,
limit: limit,
withVectors: withEmbeddings,
cancellationToken: cancellationToken).ConfigureAwait(false);
}
catch (IndexNotFound e)
{
this._log.LogWarning(e, "Index not found");
// Nothing to return
yield break;
}

foreach (var point in results)
{
Expand Down Expand Up @@ -174,13 +194,23 @@ public async IAsyncEnumerable<MemoryRecord> GetListAsync(
requiredTags.AddRange(filters.Select(filter => filter.GetFilters().Select(x => $"{x.Key}{Constants.ReservedEqualsChar}{x.Value}")));
}

List<QdrantPoint<DefaultQdrantPayload>> results = await this._qdrantClient.GetListAsync(
collectionName: index,
requiredTags: requiredTags,
offset: 0,
limit: limit,
withVectors: withEmbeddings,
cancellationToken: cancellationToken).ConfigureAwait(false);
List<QdrantPoint<DefaultQdrantPayload>> results;
try
{
results = await this._qdrantClient.GetListAsync(
collectionName: index,
requiredTags: requiredTags,
offset: 0,
limit: limit,
withVectors: withEmbeddings,
cancellationToken: cancellationToken).ConfigureAwait(false);
}
catch (IndexNotFound e)
{
this._log.LogWarning(e, "Index not found");
// Nothing to return
yield break;
}

foreach (var point in results)
{
Expand All @@ -196,17 +226,24 @@ public async Task DeleteAsync(
{
index = NormalizeIndexName(index);

QdrantPoint<DefaultQdrantPayload>? existingPoint = await this._qdrantClient
.GetVectorByPayloadIdAsync(index, record.Id, cancellationToken: cancellationToken)
.ConfigureAwait(false);
if (existingPoint == null)
try
{
this._log.LogTrace("No record with ID {0} found, nothing to delete", record.Id);
return;
}
QdrantPoint<DefaultQdrantPayload>? existingPoint = await this._qdrantClient
.GetVectorByPayloadIdAsync(index, record.Id, cancellationToken: cancellationToken)
.ConfigureAwait(false);
if (existingPoint == null)
{
this._log.LogTrace("No record with ID {0} found, nothing to delete", record.Id);
return;
}

this._log.LogTrace("Point ID {0} found, deleting...", existingPoint.Id);
await this._qdrantClient.DeleteVectorsAsync(index, new List<Guid> { existingPoint.Id }, cancellationToken).ConfigureAwait(false);
this._log.LogTrace("Point ID {0} found, deleting...", existingPoint.Id);
await this._qdrantClient.DeleteVectorsAsync(index, new List<Guid> { existingPoint.Id }, cancellationToken).ConfigureAwait(false);
}
catch (IndexNotFound e)
{
this._log.LogInformation(e, "Index not found, nothing to delete");
}
}

#region private ================================================================================
Expand Down
1 change: 1 addition & 0 deletions service/Abstractions/MemoryStorage/IMemoryDb.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Task DeleteIndexAsync(
/// <param name="record">Vector + payload to save</param>
/// <param name="cancellationToken">Task cancellation token</param>
/// <returns>Record ID</returns>
/// <exception cref="IndexNotFound">Error returned if the index where to write doesn't exist</exception>
Task<string> UpsertAsync(
string index,
MemoryRecord record,
Expand Down
Loading

0 comments on commit c0b2eef

Please sign in to comment.