diff --git a/examples/003-dotnet-Serverless/003-dotnet-Serverless.csproj b/examples/003-dotnet-Serverless/003-dotnet-Serverless.csproj index 4f270ded6..62f43c775 100644 --- a/examples/003-dotnet-Serverless/003-dotnet-Serverless.csproj +++ b/examples/003-dotnet-Serverless/003-dotnet-Serverless.csproj @@ -11,6 +11,9 @@ + + + diff --git a/examples/003-dotnet-Serverless/Program.cs b/examples/003-dotnet-Serverless/Program.cs index 0537606cc..f21ae1a97 100644 --- a/examples/003-dotnet-Serverless/Program.cs +++ b/examples/003-dotnet-Serverless/Program.cs @@ -11,6 +11,7 @@ using Microsoft.KernelMemory.AI.OpenAI; using Microsoft.KernelMemory.ContentStorage.DevTools; using Microsoft.KernelMemory.MemoryStorage.DevTools; +using Microsoft.KernelMemory.Postgres; /* Use MemoryServerlessClient to run the default import pipeline * in the same process, without distributed queues. @@ -36,6 +37,7 @@ public static async Task Main() var searchClientConfig = new SearchClientConfig(); var azDocIntelConfig = new AzureAIDocIntelConfig(); var azureAISearchConfig = new AzureAISearchConfig(); + var postgresConfig = new PostgresConfig(); new ConfigurationBuilder() .AddJsonFile("appsettings.json") @@ -47,6 +49,7 @@ public static async Task Main() .BindSection("KernelMemory:Services:LlamaSharp", llamaConfig) .BindSection("KernelMemory:Services:AzureAIDocIntel", azDocIntelConfig) .BindSection("KernelMemory:Services:AzureAISearch", azureAISearchConfig) + .BindSection("KernelMemory:Services:Postgres", postgresConfig) .BindSection("KernelMemory:Retrieval:SearchClient", searchClientConfig); s_memory = new KernelMemoryBuilder() @@ -60,6 +63,7 @@ public static async Task Main() // .WithAzureBlobsStorage(new AzureBlobsConfig {...}) // Store files in Azure Blobs // .WithSimpleVectorDb(SimpleVectorDbConfig.Persistent) // Store memories on disk // .WithAzureAISearchMemoryDb(azureAISearchConfig) // Store memories in Azure AI Search + // .WithPostgresMemoryDb(postgresConfig) // Store memories in Postgres // .WithQdrantMemoryDb("http://127.0.0.1:6333") // Store memories in Qdrant .Build(); diff --git a/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs b/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs index 2c6be3274..69ee5d400 100644 --- a/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs +++ b/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs @@ -129,10 +129,17 @@ public async Task UpsertAsync(string index, MemoryRecord record, Cancell var client = this.GetSearchClient(index); AzureAISearchMemoryRecord localRecord = AzureAISearchMemoryRecord.FromMemoryRecord(record); - await client.IndexDocumentsAsync( - IndexDocumentsBatch.Upload(new[] { localRecord }), - new IndexDocumentsOptions { ThrowOnAnyError = true }, - cancellationToken: cancellationToken).ConfigureAwait(false); + try + { + await client.IndexDocumentsAsync( + IndexDocumentsBatch.Upload(new[] { localRecord }), + new IndexDocumentsOptions { ThrowOnAnyError = true }, + cancellationToken: cancellationToken).ConfigureAwait(false); + } + catch (RequestFailedException e) when (IsIndexNotFoundException(e)) + { + throw new IndexNotFound(e.Message, e); + } return record.Id; } @@ -400,6 +407,13 @@ private SearchClient GetSearchClient(string index) return client; } + private static bool IsIndexNotFoundException(RequestFailedException e) + { + return e.Status == 404 + && e.Message.Contains("index", StringComparison.OrdinalIgnoreCase) + && e.Message.Contains("not found", StringComparison.OrdinalIgnoreCase); + } + private static void ValidateSchema(MemoryDbSchema schema) { schema.Validate(vectorSizeRequired: true); diff --git a/extensions/Postgres/Postgres/Db/PostgresDbClient.cs b/extensions/Postgres/Postgres/Db/PostgresDbClient.cs index d5f21dd7e..517006069 100644 --- a/extensions/Postgres/Postgres/Db/PostgresDbClient.cs +++ b/extensions/Postgres/Postgres/Db/PostgresDbClient.cs @@ -9,6 +9,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.Diagnostics; +using Microsoft.KernelMemory.MemoryStorage; using Npgsql; using NpgsqlTypes; using Pgvector; @@ -289,7 +290,7 @@ public async Task DeleteTableAsync( this._log.LogTrace("Deleting table. SQL: {0}", cmd.CommandText); await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } - catch (Npgsql.PostgresException e) when (IsNotFoundException(e)) + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) { this._log.LogTrace("Table not found: {0}", tableName); } @@ -315,12 +316,14 @@ public async Task UpsertAsync( NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - await using (connection) + try { - using NpgsqlCommand cmd = connection.CreateCommand(); + await using (connection) + { + using NpgsqlCommand cmd = connection.CreateCommand(); #pragma warning disable CA2100 // SQL reviewed - cmd.CommandText = $@" + cmd.CommandText = $@" INSERT INTO {tableName} ({this._colId}, {this._colEmbedding}, {this._colTags}, {this._colContent}, {this._colPayload}) VALUES @@ -333,16 +336,25 @@ DO UPDATE SET {this._colPayload} = @payload "; - cmd.Parameters.AddWithValue("@id", record.Id); - cmd.Parameters.AddWithValue("@embedding", record.Embedding); - cmd.Parameters.AddWithValue("@tags", NpgsqlDbType.Array | NpgsqlDbType.Text, record.Tags.ToArray() ?? emptyTags); - cmd.Parameters.AddWithValue("@content", NpgsqlDbType.Text, record.Content ?? EmptyContent); - cmd.Parameters.AddWithValue("@payload", NpgsqlDbType.Jsonb, record.Payload ?? EmptyPayload); + cmd.Parameters.AddWithValue("@id", record.Id); + cmd.Parameters.AddWithValue("@embedding", record.Embedding); + cmd.Parameters.AddWithValue("@tags", NpgsqlDbType.Array | NpgsqlDbType.Text, record.Tags.ToArray() ?? emptyTags); + cmd.Parameters.AddWithValue("@content", NpgsqlDbType.Text, record.Content ?? EmptyContent); + cmd.Parameters.AddWithValue("@payload", NpgsqlDbType.Jsonb, record.Payload ?? EmptyPayload); #pragma warning restore CA2100 - this._log.LogTrace("Upserting record '{0}' in table '{1}'", record.Id, tableName); + this._log.LogTrace("Upserting record '{0}' in table '{1}'", record.Id, tableName); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + } + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) + { + throw new IndexNotFound(e.Message, e); + } + catch (Exception e) + { + throw new PostgresException(e.Message, e); } } @@ -436,7 +448,7 @@ OFFSET @offset result.Add((this.ReadEntry(dataReader, withEmbeddings), similarity)); } } - catch (Npgsql.PostgresException e) when (IsNotFoundException(e)) + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) { this._log.LogTrace("Table not found: {0}", tableName); } @@ -530,7 +542,7 @@ OFFSET @offset result.Add(this.ReadEntry(dataReader, withEmbeddings)); } } - catch (Npgsql.PostgresException e) when (IsNotFoundException(e)) + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) { this._log.LogTrace("Table not found: {0}", tableName); } @@ -572,7 +584,7 @@ public async Task DeleteAsync( { await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } - catch (Npgsql.PostgresException e) when (IsNotFoundException(e)) + catch (Npgsql.PostgresException e) when (IsTableNotFoundException(e)) { this._log.LogTrace("Table not found: {0}", tableName); } @@ -637,7 +649,7 @@ private string WithTableNamePrefix(string tableName) return $"{this._tableNamePrefix}{tableName}"; } - private static bool IsNotFoundException(Npgsql.PostgresException e) + private static bool IsTableNotFoundException(Npgsql.PostgresException e) { return (e.SqlState == PgErrUndefinedTable || e.Message.Contains("does not exist", StringComparison.OrdinalIgnoreCase)); } diff --git a/extensions/Qdrant/Qdrant/Client/QdrantClient.cs b/extensions/Qdrant/Qdrant/Client/QdrantClient.cs index f750835d9..728920b76 100644 --- a/extensions/Qdrant/Qdrant/Client/QdrantClient.cs +++ b/extensions/Qdrant/Qdrant/Client/QdrantClient.cs @@ -12,6 +12,7 @@ using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.Diagnostics; using Microsoft.KernelMemory.MemoryDb.Qdrant.Client.Http; +using Microsoft.KernelMemory.MemoryStorage; namespace Microsoft.KernelMemory.MemoryDb.Qdrant.Client; @@ -177,7 +178,11 @@ public async Task UpsertVectorsAsync( var (response, content) = await this.ExecuteHttpRequestAsync(request, cancellationToken).ConfigureAwait(false); this.ValidateResponse(response, content, nameof(this.UpsertVectorsAsync)); - if (JsonSerializer.Deserialize(content)?.Status != "ok") + UpsertVectorResponse? qdrantResponse = JsonSerializer.Deserialize(content); + ArgumentNullExceptionEx.ThrowIfNull(qdrantResponse, nameof(qdrantResponse), "Qdrant response is NULL"); + ArgumentNullExceptionEx.ThrowIfNull(qdrantResponse.Status, nameof(qdrantResponse.Status), "Qdrant response status is NULL"); + + if (qdrantResponse.Status != "ok") { this._log.LogWarning("Vector upserts failed"); } @@ -423,6 +428,12 @@ private void ValidateResponse(HttpResponseMessage response, string content, stri } else { + if (response.StatusCode == HttpStatusCode.NotFound && responseContent.Contains("Not found: Collection", StringComparison.OrdinalIgnoreCase)) + { + this._log.LogWarning("Qdrant collection not found: {0}, {1}", response.StatusCode, responseContent); + throw new IndexNotFound(responseContent); + } + if (!responseContent.Contains("already exists", StringComparison.OrdinalIgnoreCase)) { this._log.LogWarning("Qdrant responded with error: {0}, {1}", response.StatusCode, responseContent); diff --git a/extensions/Qdrant/Qdrant/QdrantMemory.cs b/extensions/Qdrant/Qdrant/QdrantMemory.cs index c07d48e9b..25710fbca 100644 --- a/extensions/Qdrant/Qdrant/QdrantMemory.cs +++ b/extensions/Qdrant/Qdrant/QdrantMemory.cs @@ -71,8 +71,17 @@ public Task DeleteIndexAsync( string index, CancellationToken cancellationToken = default) { - index = NormalizeIndexName(index); - return this._qdrantClient.DeleteCollectionAsync(index, cancellationToken); + try + { + index = NormalizeIndexName(index); + return this._qdrantClient.DeleteCollectionAsync(index, cancellationToken); + } + catch (IndexNotFound) + { + this._log.LogInformation("Index not found, nothing to delete"); + } + + return Task.CompletedTask; } /// @@ -139,14 +148,25 @@ public async Task UpsertAsync( } Embedding textEmbedding = await this._embeddingGenerator.GenerateEmbeddingAsync(text, cancellationToken).ConfigureAwait(false); - List<(QdrantPoint, double)> results = await this._qdrantClient.GetSimilarListAsync( - collectionName: index, - target: textEmbedding, - scoreThreshold: minRelevance, - requiredTags: requiredTags, - limit: limit, - withVectors: withEmbeddings, - cancellationToken: cancellationToken).ConfigureAwait(false); + + List<(QdrantPoint, double)> results; + try + { + results = await this._qdrantClient.GetSimilarListAsync( + collectionName: index, + target: textEmbedding, + scoreThreshold: minRelevance, + requiredTags: requiredTags, + limit: limit, + withVectors: withEmbeddings, + cancellationToken: cancellationToken).ConfigureAwait(false); + } + catch (IndexNotFound e) + { + this._log.LogWarning(e, "Index not found"); + // Nothing to return + yield break; + } foreach (var point in results) { @@ -174,13 +194,23 @@ public async IAsyncEnumerable GetListAsync( requiredTags.AddRange(filters.Select(filter => filter.GetFilters().Select(x => $"{x.Key}{Constants.ReservedEqualsChar}{x.Value}"))); } - List> results = await this._qdrantClient.GetListAsync( - collectionName: index, - requiredTags: requiredTags, - offset: 0, - limit: limit, - withVectors: withEmbeddings, - cancellationToken: cancellationToken).ConfigureAwait(false); + List> results; + try + { + results = await this._qdrantClient.GetListAsync( + collectionName: index, + requiredTags: requiredTags, + offset: 0, + limit: limit, + withVectors: withEmbeddings, + cancellationToken: cancellationToken).ConfigureAwait(false); + } + catch (IndexNotFound e) + { + this._log.LogWarning(e, "Index not found"); + // Nothing to return + yield break; + } foreach (var point in results) { @@ -196,17 +226,24 @@ public async Task DeleteAsync( { index = NormalizeIndexName(index); - QdrantPoint? existingPoint = await this._qdrantClient - .GetVectorByPayloadIdAsync(index, record.Id, cancellationToken: cancellationToken) - .ConfigureAwait(false); - if (existingPoint == null) + try { - this._log.LogTrace("No record with ID {0} found, nothing to delete", record.Id); - return; - } + QdrantPoint? existingPoint = await this._qdrantClient + .GetVectorByPayloadIdAsync(index, record.Id, cancellationToken: cancellationToken) + .ConfigureAwait(false); + if (existingPoint == null) + { + this._log.LogTrace("No record with ID {0} found, nothing to delete", record.Id); + return; + } - this._log.LogTrace("Point ID {0} found, deleting...", existingPoint.Id); - await this._qdrantClient.DeleteVectorsAsync(index, new List { existingPoint.Id }, cancellationToken).ConfigureAwait(false); + this._log.LogTrace("Point ID {0} found, deleting...", existingPoint.Id); + await this._qdrantClient.DeleteVectorsAsync(index, new List { existingPoint.Id }, cancellationToken).ConfigureAwait(false); + } + catch (IndexNotFound e) + { + this._log.LogInformation(e, "Index not found, nothing to delete"); + } } #region private ================================================================================ diff --git a/service/Abstractions/MemoryStorage/IMemoryDb.cs b/service/Abstractions/MemoryStorage/IMemoryDb.cs index ffb6f58b5..7945038e4 100644 --- a/service/Abstractions/MemoryStorage/IMemoryDb.cs +++ b/service/Abstractions/MemoryStorage/IMemoryDb.cs @@ -45,6 +45,7 @@ Task DeleteIndexAsync( /// Vector + payload to save /// Task cancellation token /// Record ID + /// Error returned if the index where to write doesn't exist Task UpsertAsync( string index, MemoryRecord record, diff --git a/service/Core/Handlers/SaveRecordsHandler.cs b/service/Core/Handlers/SaveRecordsHandler.cs index 77c934f4a..d3bc05dee 100644 --- a/service/Core/Handlers/SaveRecordsHandler.cs +++ b/service/Core/Handlers/SaveRecordsHandler.cs @@ -102,6 +102,9 @@ public SaveRecordsHandler( { var embeddingsFound = false; + // TODO: replace with ConditionalWeakTable indexing on this._memoryDbs + var createdIndexes = new HashSet(); + // For each embedding file => For each Memory DB => Upsert record foreach (FileDetailsWithRecordId embeddingFile in GetListOfEmbeddingFiles(pipeline)) { @@ -142,11 +145,21 @@ public SaveRecordsHandler( foreach (IMemoryDb client in this._memoryDbs) { - this._log.LogTrace("Creating index '{0}'", pipeline.Index); - await client.CreateIndexAsync(pipeline.Index, record.Vector.Length, cancellationToken).ConfigureAwait(false); + try + { + await this.CreateIndexOnceAsync(client, createdIndexes, pipeline.Index, record.Vector.Length, cancellationToken).ConfigureAwait(false); - this._log.LogTrace("Saving record {0} in index '{1}'", record.Id, pipeline.Index); - await client.UpsertAsync(pipeline.Index, record, cancellationToken).ConfigureAwait(false); + this._log.LogTrace("Saving record {0} in index '{1}'", record.Id, pipeline.Index); + await client.UpsertAsync(pipeline.Index, record, cancellationToken).ConfigureAwait(false); + } + catch (IndexNotFound e) + { + this._log.LogWarning(e, "Index {0} not found, attempting to create it", pipeline.Index); + await this.CreateIndexOnceAsync(client, createdIndexes, pipeline.Index, record.Vector.Length, cancellationToken, true).ConfigureAwait(false); + + this._log.LogTrace("Retry: Saving record {0} in index '{1}'", record.Id, pipeline.Index); + await client.UpsertAsync(pipeline.Index, record, cancellationToken).ConfigureAwait(false); + } } embeddingFile.File.MarkProcessedBy(this); @@ -168,6 +181,9 @@ public SaveRecordsHandler( { var partitionsFound = false; + // TODO: replace with ConditionalWeakTable indexing on this._memoryDbs + var createdIndexes = new HashSet(); + // Create records only for partitions (text chunks) and synthetic data foreach (FileDetailsWithRecordId file in GetListOfPartitionAndSyntheticFiles(pipeline)) { @@ -206,11 +222,21 @@ public SaveRecordsHandler( foreach (IMemoryDb client in this._memoryDbs) { - this._log.LogTrace("Creating index '{0}'", pipeline.Index); - await client.CreateIndexAsync(pipeline.Index, record.Vector.Length, cancellationToken).ConfigureAwait(false); - - this._log.LogTrace("Saving record {0} in index '{1}'", record.Id, pipeline.Index); - await client.UpsertAsync(pipeline.Index, record, cancellationToken).ConfigureAwait(false); + try + { + await this.CreateIndexOnceAsync(client, createdIndexes, pipeline.Index, record.Vector.Length, cancellationToken).ConfigureAwait(false); + + this._log.LogTrace("Saving record {0} in index '{1}'", record.Id, pipeline.Index); + await client.UpsertAsync(pipeline.Index, record, cancellationToken).ConfigureAwait(false); + } + catch (IndexNotFound e) + { + this._log.LogWarning(e, "Index {0} not found, attempting to create it", pipeline.Index); + await this.CreateIndexOnceAsync(client, createdIndexes, pipeline.Index, record.Vector.Length, cancellationToken, true).ConfigureAwait(false); + + this._log.LogTrace("Retry: saving record {0} in index '{1}'", record.Id, pipeline.Index); + await client.UpsertAsync(pipeline.Index, record, cancellationToken).ConfigureAwait(false); + } } break; @@ -273,6 +299,24 @@ private static IEnumerable GetListOfPartitionAndSynthet .Select(x => new FileDetailsWithRecordId(pipeline, x.Value))); } + private async Task CreateIndexOnceAsync( + IMemoryDb client, + HashSet createdIndexes, + string indexName, + int vectorLength, + CancellationToken cancellationToken, + bool force = false) + { + // TODO: add support for the same client being used multiple times with different models with the same vectorLength + var key = $"{client.GetType().Name}::{indexName}::{vectorLength}"; + + if (!force && createdIndexes.Contains(key)) { return; } + + this._log.LogTrace("Creating index '{0}'", indexName); + await client.CreateIndexAsync(indexName, vectorLength, cancellationToken).ConfigureAwait(false); + createdIndexes.Add(key); + } + private async Task GetSourceUrlAsync( DataPipeline pipeline, DataPipeline.FileDetails file, diff --git a/service/Core/MemoryStorage/DevTools/SimpleVectorDb.cs b/service/Core/MemoryStorage/DevTools/SimpleVectorDb.cs index ec2d506f0..0bcb575c0 100644 --- a/service/Core/MemoryStorage/DevTools/SimpleVectorDb.cs +++ b/service/Core/MemoryStorage/DevTools/SimpleVectorDb.cs @@ -84,6 +84,7 @@ public Task DeleteIndexAsync(string index, CancellationToken cancellationToken = /// public async Task UpsertAsync(string index, MemoryRecord record, CancellationToken cancellationToken = default) { + // Note: if the index doesn't exist, it's automatically created (the index is just a folder) index = NormalizeIndexName(index); await this._fileSystem.WriteFileAsync(index, "", EncodeId(record.Id), JsonSerializer.Serialize(record), cancellationToken).ConfigureAwait(false); return record.Id;