From c631a647a4817986701a39c13af4206aea0a7c96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luis=20Ma=C3=B1ez?= Date: Thu, 25 Apr 2024 05:45:10 +0200 Subject: [PATCH] Add Azure AI Search hybrid search support (#428) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation and Context (Why the change? What's the scenario?) Already described in issue #159 The main idea is to support Azure AI search Hybrid search. ## High level description (Approach, Design) The idea is to have a new Config property in the AzureAISearchConfig class, so Hybrid is only enabled explicitly. When enabled, the CosineSimilarity is not calculated and the minDistance is set to the minRelevance parameter (passed from the top SearchAsync method). --------- Co-authored-by: “luismanez” <“luis.manez@outlook.com”> Co-authored-by: Devis Lucato --- KernelMemory.sln | 7 +- .../111-dotnet-azure-ai-hybrid-search.csproj | 15 ++++ .../Program.cs | 83 +++++++++++++++++++ .../appsettings.json | 77 +++++++++++++++++ .../AzureAISearch/AzureAISearchConfig.cs | 6 ++ .../AzureAISearch/AzureAISearchMemory.cs | 14 +++- service/Core/Search/SearchClient.cs | 12 +++ service/Service/appsettings.json | 5 +- .../Services/AzureAISearch.cs | 2 + .../UI/DictionaryExtensions.cs | 18 ++++ tools/InteractiveSetup/UI/SetupUI.cs | 26 +++--- 11 files changed, 247 insertions(+), 18 deletions(-) create mode 100644 examples/111-dotnet-azure-ai-hybrid-search/111-dotnet-azure-ai-hybrid-search.csproj create mode 100644 examples/111-dotnet-azure-ai-hybrid-search/Program.cs create mode 100644 examples/111-dotnet-azure-ai-hybrid-search/appsettings.json create mode 100644 tools/InteractiveSetup/UI/DictionaryExtensions.cs diff --git a/KernelMemory.sln b/KernelMemory.sln index 3e0d29a69..8161dee99 100644 --- a/KernelMemory.sln +++ b/KernelMemory.sln @@ -258,6 +258,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "modules", "modules", "{C2D3 EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Service.AspNetCore", "service\Service.AspNetCore\Service.AspNetCore.csproj", "{A46B0BE1-03F2-4520-A3DA-FD845BA1FD69}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "111-dotnet-azure-ai-hybrid-search", "examples\111-dotnet-azure-ai-hybrid-search\111-dotnet-azure-ai-hybrid-search.csproj", "{28534545-CB39-446A-9EB9-A5ABBFE0CFD3}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -332,6 +334,7 @@ Global {8FB12876-013D-44CB-9F0D-E926D9F0F4E3} = {0A43C65C-6007-4BB4-B3FE-8D439FC91841} {C2D3A947-B6F9-4306-BD42-21D8D1F42750} = {B488168B-AD86-4CC5-9D89-324B6EB743D9} {A46B0BE1-03F2-4520-A3DA-FD845BA1FD69} = {87DEAE8D-138C-4FDD-B4C9-11C3A7817E8F} + {28534545-CB39-446A-9EB9-A5ABBFE0CFD3} = {0A43C65C-6007-4BB4-B3FE-8D439FC91841} EndGlobalSection GlobalSection(ProjectConfigurationPlatforms) = postSolution {8A9FA587-7EBA-4D43-BE47-38D798B1C74C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU @@ -532,10 +535,12 @@ Global {8FB12876-013D-44CB-9F0D-E926D9F0F4E3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {8FB12876-013D-44CB-9F0D-E926D9F0F4E3}.Debug|Any CPU.Build.0 = Debug|Any CPU {8FB12876-013D-44CB-9F0D-E926D9F0F4E3}.Release|Any CPU.ActiveCfg = Release|Any CPU - {8FB12876-013D-44CB-9F0D-E926D9F0F4E3}.Release|Any CPU.Build.0 = Release|Any CPU {A46B0BE1-03F2-4520-A3DA-FD845BA1FD69}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {A46B0BE1-03F2-4520-A3DA-FD845BA1FD69}.Debug|Any CPU.Build.0 = Debug|Any CPU {A46B0BE1-03F2-4520-A3DA-FD845BA1FD69}.Release|Any CPU.ActiveCfg = Release|Any CPU {A46B0BE1-03F2-4520-A3DA-FD845BA1FD69}.Release|Any CPU.Build.0 = Release|Any CPU + {28534545-CB39-446A-9EB9-A5ABBFE0CFD3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {28534545-CB39-446A-9EB9-A5ABBFE0CFD3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {28534545-CB39-446A-9EB9-A5ABBFE0CFD3}.Release|Any CPU.ActiveCfg = Release|Any CPU EndGlobalSection EndGlobal diff --git a/examples/111-dotnet-azure-ai-hybrid-search/111-dotnet-azure-ai-hybrid-search.csproj b/examples/111-dotnet-azure-ai-hybrid-search/111-dotnet-azure-ai-hybrid-search.csproj new file mode 100644 index 000000000..01734d87f --- /dev/null +++ b/examples/111-dotnet-azure-ai-hybrid-search/111-dotnet-azure-ai-hybrid-search.csproj @@ -0,0 +1,15 @@ + + + + net8.0 + enable + enable + false + $(NoWarn);CA1050;CA2000;CA1707;CA1303;CA2007;CA1724;CA1861;CA1859; + + + + + + + diff --git a/examples/111-dotnet-azure-ai-hybrid-search/Program.cs b/examples/111-dotnet-azure-ai-hybrid-search/Program.cs new file mode 100644 index 000000000..dfd07972f --- /dev/null +++ b/examples/111-dotnet-azure-ai-hybrid-search/Program.cs @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft. All rights reserved. + +// ReSharper disable InconsistentNaming + +using Microsoft.KernelMemory; +using Microsoft.KernelMemory.AI.OpenAI; + +public static class Program +{ + private const string indexName = "acronyms"; + + public static async Task Main() + { + var azureOpenAITextConfig = new AzureOpenAIConfig(); + var azureOpenAIEmbeddingConfig = new AzureOpenAIConfig(); + var azureAISearchConfigWithHybridSearch = new AzureAISearchConfig(); + var azureAISearchConfigWithoutHybridSearch = new AzureAISearchConfig(); + + new ConfigurationBuilder() + .AddJsonFile("appsettings.json") + .AddJsonFile("appsettings.Development.json", optional: true) + .Build() + .BindSection("KernelMemory:Services:AzureOpenAIText", azureOpenAITextConfig) + .BindSection("KernelMemory:Services:AzureOpenAIEmbedding", azureOpenAIEmbeddingConfig) + .BindSection("KernelMemory:Services:AzureAISearch", azureAISearchConfigWithHybridSearch) + .BindSection("KernelMemory:Services:AzureAISearch", azureAISearchConfigWithoutHybridSearch); + + azureAISearchConfigWithHybridSearch.UseHybridSearch = true; + azureAISearchConfigWithoutHybridSearch.UseHybridSearch = false; + + var memoryNoHybridSearch = new KernelMemoryBuilder() + .WithAzureOpenAITextGeneration(azureOpenAITextConfig, new DefaultGPTTokenizer()) + .WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig, new DefaultGPTTokenizer()) + .WithAzureAISearchMemoryDb(azureAISearchConfigWithoutHybridSearch) + .WithSearchClientConfig(new SearchClientConfig { MaxMatchesCount = 2, Temperature = 0, TopP = 0 }) + .Build(); + + var memoryWithHybridSearch = new KernelMemoryBuilder() + .WithAzureOpenAITextGeneration(azureOpenAITextConfig, new DefaultGPTTokenizer()) + .WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig, new DefaultGPTTokenizer()) + .WithAzureAISearchMemoryDb(azureAISearchConfigWithHybridSearch) + .WithSearchClientConfig(new SearchClientConfig { MaxMatchesCount = 2, Temperature = 0, TopP = 0 }) + .Build(); + + await CreateIndexAndImportData(memoryWithHybridSearch); + + const string question = "abc"; + + Console.WriteLine("Answer without hybrid search:"); + await AskQuestion(memoryNoHybridSearch, question); + // Output: INFO NOT FOUND + + Console.WriteLine("Answer using hybrid search:"); + await AskQuestion(memoryWithHybridSearch, question); + // Output: 'Aliens Brewing Coffee' + } + + private static async Task AskQuestion(IKernelMemory memory, string question) + { + var answer = await memory.AskAsync(question, index: indexName); + Console.WriteLine(answer.Result); + } + + private static async Task CreateIndexAndImportData(IKernelMemory memory) + { + await memory.DeleteIndexAsync(indexName); + + var data = """ + aaa bbb ccc 000000000 + C B A ....... + ai bee cee Something else + XY. abc means 'Aliens Brewing Coffee' + abeec abecedario + A B C D first 4 letters + """; + + var rows = data.Split("\n"); + foreach (var acronym in rows) + { + await memory.ImportTextAsync(acronym, index: indexName); + } + } +} diff --git a/examples/111-dotnet-azure-ai-hybrid-search/appsettings.json b/examples/111-dotnet-azure-ai-hybrid-search/appsettings.json new file mode 100644 index 000000000..32766e715 --- /dev/null +++ b/examples/111-dotnet-azure-ai-hybrid-search/appsettings.json @@ -0,0 +1,77 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Warning", + // Examples: how to handle logs differently by class + // "Microsoft.KernelMemory.Handlers.TextExtractionHandler": "Information", + // "Microsoft.KernelMemory.Handlers.TextPartitioningHandler": "Information", + // "Microsoft.KernelMemory.Handlers.GenerateEmbeddingsHandler": "Information", + // "Microsoft.KernelMemory.Handlers.SaveEmbeddingsHandler": "Information", + // "Microsoft.KernelMemory.ContentStorage.AzureBlobs": "Information", + // "Microsoft.KernelMemory.Pipeline.Queue.AzureQueues": "Information", + "Microsoft.AspNetCore": "Warning" + } + }, + "KernelMemory": { + "Services": { + "AzureAISearch": { + // "ApiKey" or "AzureIdentity". For other options see . + // AzureIdentity: use automatic AAD authentication mechanism. You can test locally + // using the env vars AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET. + "Auth": "AzureIdentity", + "Endpoint": "https://<...>", + "APIKey": "" + }, + "AzureOpenAIText": { + // "ApiKey" or "AzureIdentity" + // AzureIdentity: use automatic AAD authentication mechanism. You can test locally + // using the env vars AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET. + "Auth": "AzureIdentity", + "Endpoint": "https://<...>.openai.azure.com/", + "APIKey": "", + "Deployment": "", + // The max number of tokens supported by model deployed + // See https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models + "MaxTokenTotal": 16384, + // "ChatCompletion" or "TextCompletion" + "APIType": "ChatCompletion", + "MaxRetries": 10 + }, + "AzureOpenAIEmbedding": { + // "ApiKey" or "AzureIdentity" + // AzureIdentity: use automatic AAD authentication mechanism. You can test locally + // using the env vars AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET. + "Auth": "AzureIdentity", + "Endpoint": "https://<...>.openai.azure.com/", + "APIKey": "", + "Deployment": "", + // The max number of tokens supported by model deployed + // See https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models + "MaxTokenTotal": 8191 + }, + "OpenAI": { + // Name of the model used to generate text (text completion or chat completion) + "TextModel": "gpt-3.5-turbo-16k", + // The max number of tokens supported by the text model. + "TextModelMaxTokenTotal": 16384, + // What type of text generation, by default autodetect using the model name. + // Possible values: "Auto", "TextCompletion", "Chat" + "TextGenerationType": "Auto", + // Name of the model used to generate text embeddings + "EmbeddingModel": "text-embedding-ada-002", + // The max number of tokens supported by the embedding model + // See https://platform.openai.com/docs/guides/embeddings/what-are-embeddings + "EmbeddingModelMaxTokenTotal": 8191, + // OpenAI API Key + "APIKey": "", + // OpenAI Organization ID (usually empty, unless you have multiple accounts on different orgs) + "OrgId": "", + // Endpoint to use. By default the system uses 'https://api.openai.com/v1'. + // Change this to use proxies or services compatible with OpenAI HTTP protocol like LM Studio. + "Endpoint": "", + // How many times to retry in case of throttling + "MaxRetries": 10 + } + } + } +} \ No newline at end of file diff --git a/extensions/AzureAISearch/AzureAISearch/AzureAISearchConfig.cs b/extensions/AzureAISearch/AzureAISearch/AzureAISearchConfig.cs index 37e49c565..946b2e302 100644 --- a/extensions/AzureAISearch/AzureAISearch/AzureAISearchConfig.cs +++ b/extensions/AzureAISearch/AzureAISearch/AzureAISearchConfig.cs @@ -25,6 +25,12 @@ public enum AuthTypes public string Endpoint { get; set; } = string.Empty; public string APIKey { get; set; } = string.Empty; + /// + /// Important: when using hybrid search, relevance scores a very + /// different from when using just vector search. + /// + public bool UseHybridSearch { get; set; } = false; + public void SetCredential(TokenCredential credential) { this.Auth = AuthTypes.ManualTokenCredential; diff --git a/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs b/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs index fba4855a1..2c6be3274 100644 --- a/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs +++ b/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs @@ -34,6 +34,7 @@ public class AzureAISearchMemory : IMemoryDb { private readonly ITextEmbeddingGenerator _embeddingGenerator; private readonly ILogger _log; + private readonly bool _useHybridSearch; /// /// Create a new instance @@ -48,6 +49,7 @@ public AzureAISearchMemory( { this._embeddingGenerator = embeddingGenerator; this._log = log ?? DefaultLogger.Instance; + this._useHybridSearch = config.UseHybridSearch; if (string.IsNullOrEmpty(config.Endpoint)) { @@ -184,8 +186,9 @@ await client.IndexDocumentsAsync( Response>? searchResult = null; try { + var keyword = this._useHybridSearch ? text : null; searchResult = await client - .SearchAsync(null, options, cancellationToken: cancellationToken) + .SearchAsync(keyword, options, cancellationToken: cancellationToken) .ConfigureAwait(false); } catch (RequestFailedException e) when (e.Status == 404) @@ -196,14 +199,19 @@ await client.IndexDocumentsAsync( if (searchResult == null) { yield break; } - var minDistance = CosineSimilarityToScore(minRelevance); + var minDistance = this._useHybridSearch ? minRelevance : CosineSimilarityToScore(minRelevance); + var count = 0; await foreach (SearchResult? doc in searchResult.Value.GetResultsAsync().ConfigureAwait(false)) { if (doc == null || doc.Score < minDistance) { continue; } + // In cases where Azure Search is returning too many records + if (++count > limit) { break; } + MemoryRecord memoryRecord = doc.Document.ToMemoryRecord(withEmbeddings); - yield return (memoryRecord, ScoreToCosineSimilarity(doc.Score ?? 0)); + var documentScore = this._useHybridSearch ? doc.Score ?? 0 : ScoreToCosineSimilarity(doc.Score ?? 0); + yield return (memoryRecord, documentScore); } } diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs index b999e2639..e4c217352 100644 --- a/service/Core/Search/SearchClient.cs +++ b/service/Core/Search/SearchClient.cs @@ -163,6 +163,12 @@ public async Task SearchAsync( LastUpdate = memory.GetLastUpdate(), Tags = memory.Tags, }); + + // In cases where a buggy storage connector is returning too many records + if (result.Results.Count >= this._config.MaxMatchesCount) + { + break; + } } if (result.Results.Count == 0) @@ -284,6 +290,12 @@ public async Task AskAsync( LastUpdate = memory.GetLastUpdate(), Tags = memory.Tags, }); + + // In cases where a buggy storage connector is returning too many records + if (factsUsedCount >= this._config.MaxMatchesCount) + { + break; + } } if (factsAvailableCount > 0 && factsUsedCount == 0) diff --git a/service/Service/appsettings.json b/service/Service/appsettings.json index f812f3f21..f536cf1fc 100644 --- a/service/Service/appsettings.json +++ b/service/Service/appsettings.json @@ -210,7 +210,10 @@ // using the env vars AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET. "Auth": "AzureIdentity", "Endpoint": "https://<...>", - "APIKey": "" + "APIKey": "", + // Hybrid search is not enabled by default. Note that when using hybrid search + // relevance scores are different, usually lower, than when using just vector search + "UseHybridSearch": false }, "AzureAIDocIntel": { // "APIKey" or "AzureIdentity". diff --git a/tools/InteractiveSetup/Services/AzureAISearch.cs b/tools/InteractiveSetup/Services/AzureAISearch.cs index 071e02990..b6e58e364 100644 --- a/tools/InteractiveSetup/Services/AzureAISearch.cs +++ b/tools/InteractiveSetup/Services/AzureAISearch.cs @@ -21,6 +21,7 @@ public static void Setup(Context ctx, bool force = false) { "Auth", "ApiKey" }, { "Endpoint", "" }, { "APIKey", "" }, + { "UseHybridSearch", false }, }; } @@ -29,6 +30,7 @@ public static void Setup(Context ctx, bool force = false) { "Auth", "ApiKey" }, { "Endpoint", SetupUI.AskOpenQuestion("Azure AI Search ", config["Endpoint"].ToString()) }, { "APIKey", SetupUI.AskPassword("Azure AI Search ", config["APIKey"].ToString()) }, + { "UseHybridSearch", SetupUI.AskBoolean("Use hybrid search (yes/no)?", (bool)config["UseHybridSearch"]) }, }); } } diff --git a/tools/InteractiveSetup/UI/DictionaryExtensions.cs b/tools/InteractiveSetup/UI/DictionaryExtensions.cs new file mode 100644 index 000000000..3d509c875 --- /dev/null +++ b/tools/InteractiveSetup/UI/DictionaryExtensions.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; + +namespace Microsoft.KernelMemory.InteractiveSetup.UI; + +internal static class DictionaryExtensions +{ + public static string TryGet(this Dictionary data, string key) + { + return data.TryGetValue(key, out object? value) ? value.ToString() ?? string.Empty : string.Empty; + } + + public static string TryGetOr(this Dictionary data, string key, string fallbackValue) + { + return data.TryGetValue(key, out object? value) ? value.ToString() ?? string.Empty : fallbackValue; + } +} diff --git a/tools/InteractiveSetup/UI/SetupUI.cs b/tools/InteractiveSetup/UI/SetupUI.cs index 1fb64b0d8..1c4c8ddee 100644 --- a/tools/InteractiveSetup/UI/SetupUI.cs +++ b/tools/InteractiveSetup/UI/SetupUI.cs @@ -1,28 +1,28 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; +using System.Linq; namespace Microsoft.KernelMemory.InteractiveSetup.UI; -public static class DictionaryExtensions +internal static class SetupUI { - public static string TryGet(this Dictionary data, string key) + public static string AskPassword(string question, string? defaultValue, bool trim = true, bool optional = false) { - return data.TryGetValue(key, out object? value) ? value.ToString() ?? string.Empty : string.Empty; + return AskOpenQuestion(question: question, defaultValue: defaultValue, trim: trim, optional: optional, isPassword: true); } - public static string TryGetOr(this Dictionary data, string key, string fallbackValue) + public static bool AskBoolean(string question, bool defaultValue) { - return data.TryGetValue(key, out object? value) ? value.ToString() ?? string.Empty : fallbackValue; - } -} + string[] yes = { "YES", "Y" }; + string[] no = { "NO", "N" }; + while (true) + { + var answer = AskOpenQuestion(question: question, defaultValue: defaultValue ? "Yes" : "No", optional: false).ToUpperInvariant(); + if (yes.Contains(answer)) { return true; } -public static class SetupUI -{ - public static string AskPassword(string question, string? defaultValue, bool trim = true, bool optional = false) - { - return AskOpenQuestion(question: question, defaultValue: defaultValue, trim: trim, optional: optional, isPassword: true); + if (no.Contains(answer)) { return false; } + } } public static string AskOptionalOpenQuestion(string question, string? defaultValue)